[
  {
    "path": ".gitignore",
    "content": "\n.DS_Store\n\n# Compiled python modules.\n*.pyc\n\n# Setuptools distribution folder.\n/dist/\n\n# Python egg metadata, regenerated from source files by setuptools.\n/*.egg-info\n/*.egg\nsgdml/_bmark_cache.npz\n"
  },
  {
    "path": "LICENSE.txt",
    "content": "MIT License\n\nCopyright (c) 2018-2022 Stefan Chmiela\n\nPermission is hereby granted, free of charge, to any person obtaining a copy\nof this software and associated documentation files (the \"Software\"), to deal\nin the Software without restriction, including without limitation the rights\nto use, copy, modify, merge, publish, distribute, sublicense, and/or sell\ncopies of the Software, and to permit persons to whom the Software is\nfurnished to do so, subject to the following conditions:\n\nThe above copyright notice and this permission notice shall be included in all\ncopies or substantial portions of the Software.\n\nTHE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\nIMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\nFITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\nAUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\nLIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\nOUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\nSOFTWARE."
  },
  {
    "path": "README.md",
    "content": "# Symmetric Gradient Domain Machine Learning (sGDML)\n\nFor more details visit: [sgdml.org](http://sgdml.org/)  \nDocumentation can be found here: [docs.sgdml.org](http://docs.sgdml.org/)\n\n#### Requirements:\n- Python 3.7+\n- PyTorch (>=1.8)\n- NumPy (>=1.19)\n- SciPy (>=1.1)\n\n#### Optional:\n- ASE (>=3.16.2) (to run atomistic simulations)\n\n## Getting started\n\n### Stable release\n\nMost systems come with the default package manager for Python ``pip`` already preinstalled. Install ``sgdml`` by simply calling:\n\n```\n$ pip install sgdml\n```\n\nThe ``sgdml`` command-line interface and the corresponding Python API can now be used from anywhere on the system.\n\n### Development version\n\n#### (1) Clone the repository\n\n```\n$ git clone https://github.com/stefanch/sGDML.git\n$ cd sGDML\n```\n\n...or update your existing local copy with\n\n```\n$ git pull origin master\n```\n\n#### (2) Install\n\n```\n$ pip install -e .\n```\n\nUsing the flag ``--user``, you can tell ``pip`` to install the package to the current users's home directory, instead of system-wide. This option might require you to update your system's ``PATH`` variable accordingly.\n\n\n### Optional dependencies\n\nSome functionality of this package relies on third-party libraries that are not installed by default. 
These optional dependencies (or \"package extras\") are specified during installation using the \"square bracket syntax\":\n\n```\n$ pip install sgdml[<optional1>]\n```\n\n#### Atomic Simulation Environment (ASE)\n\nIf you are interested in interfacing with [ASE](https://wiki.fysik.dtu.dk/ase/) to perform atomistic simulations (see [here](http://docs.sgdml.org/applications.html) for examples), use the ``ase`` keyword:\n\n```\n$ pip install sgdml[ase]\n```\n\n## Reconstruct your first force field\n\nDownload one of the example datasets:\n\n```\n$ sgdml-get dataset ethanol_dft\n```\n\nTrain a force field model:\n\n```\n$ sgdml all ethanol_dft.npz 200 1000 5000\n```\n\n## Query a force field\n\n```python\nimport numpy as np\nfrom sgdml.predict import GDMLPredict\nfrom sgdml.utils import io\n\nr,_ = io.read_xyz('geometries/ethanol.xyz') # 9 atoms\nprint(r.shape) # (1,27)\n\nmodel = np.load('models/ethanol.npz')\ngdml = GDMLPredict(model)\ne,f = gdml.predict(r)\nprint(e.shape) # (1,)\nprint(f.shape) # (1,27)\n```\n\n## Authors\n\n* Stefan Chmiela\n* Jan Hermann\n\nWe appreciate and welcome contributions and would like to thank the following people for participating in this project:\n\n* Huziel Sauceda\n* Igor Poltavsky\n* Luis Gálvez\n* Danny Panknin\n* Grégory Fonseca\n* Anton Charkin-Gorbulin\n\n## References\n\n* [1] Chmiela, S., Tkatchenko, A., Sauceda, H. E., Poltavsky, I., Schütt, K. T., Müller, K.-R.,\n*Machine Learning of Accurate Energy-conserving Molecular Force Fields.*\nScience Advances, 3(5), e1603015 (2017)   \n[10.1126/sciadv.1603015](http://dx.doi.org/10.1126/sciadv.1603015)\n\n* [2] Chmiela, S., Sauceda, H. E., Müller, K.-R., Tkatchenko, A.,\n*Towards Exact Molecular Dynamics Simulations with Machine-Learned Force Fields.*\nNature Communications, 9(1), 3887 (2018)   \n[10.1038/s41467-018-06169-2](https://doi.org/10.1038/s41467-018-06169-2)\n\n* [3] Chmiela, S., Sauceda, H. 
E., Poltavsky, I., Müller, K.-R., Tkatchenko, A.,\n*sGDML: Constructing Accurate and Data Efficient Molecular Force Fields Using Machine Learning.*\nComputer Physics Communications, 240, 38-45 (2019)\n[10.1016/j.cpc.2019.02.007](https://doi.org/10.1016/j.cpc.2019.02.007)\n\n* [4] Chmiela, S., Vassilev-Galindo, V., Unke, O. T., Kabylda, A., Sauceda, H. E., Tkatchenko, A., Müller, K.-R.,\n*Accurate Global Machine Learning Force Fields for Molecules With Hundreds of Atoms.*\nScience Advances, 9(2), eadf0873 (2023)\n[10.1126/sciadv.adf0873](https://doi.org/10.1126/sciadv.adf0873)"
  },
  {
    "path": "pyproject.toml",
    "content": "[tool.black]\nskip-string-normalization = true\nskip-numeric-underscore-normalization = true\n"
  },
  {
    "path": "scripts/sgdml_dataset_from_aims.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport argparse\nimport os\nimport sys\n\nimport numpy as np\n\nfrom sgdml.utils import io, ui\n\n\ndef read_reference_data(f):  # noqa C901\n    eV_to_kcalmol = 0.036749326 / 0.0015946679\n\n    e_next, f_next, geo_next = False, False, False\n    n_atoms = None\n    R, z, E, F = [], [], [], []\n\n    geo_idx = 0\n    for line in f:\n        if n_atoms:\n            cols = line.split()\n            if e_next:\n                E.append(float(cols[5]))\n                e_next = False\n            elif f_next:\n                a = int(cols[1]) - 1\n                F.append(list(map(float, cols[2:5])))\n                if a == n_atoms - 1:\n                    f_next = False\n            elif geo_next:\n                if 'atom' in cols:\n                    a_count += 1  # 
noqa: F821\n                    R.append(list(map(float, cols[1:4])))\n\n                    if geo_idx == 0:\n                        z.append(io._z_str_to_z_dict[cols[4]])\n\n                    if a_count == n_atoms:\n                        geo_next = False\n                        geo_idx += 1\n            elif 'Energy and forces in a compact form:' in line:\n                e_next = True\n            elif 'Total atomic forces (unitary forces cleaned) [eV/Ang]:' in line:\n                f_next = True\n            elif (\n                'Atomic structure (and velocities) as used in the preceding time step:'\n                in line\n            ):\n                geo_next = True\n                a_count = 0\n        elif 'The structure contains' in line and 'atoms,  and a total of' in line:\n            n_atoms = int(line.split()[3])\n            print('Number atoms per geometry:      {:>7d}'.format(n_atoms))\n            continue\n\n        if geo_idx > 0 and geo_idx % 1000 == 0:\n            sys.stdout.write(\"\\rNumber geometries found so far: {:>7d}\".format(geo_idx))\n            sys.stdout.flush()\n    sys.stdout.write(\"\\rNumber geometries found so far: {:>7d}\".format(geo_idx))\n    sys.stdout.flush()\n    print(\n        '\\n'\n        + ui.color_str('[INFO]', bold=True)\n        + ' Energies and forces have been converted from eV to kcal/mol(/Ang)'\n    )\n\n    R = np.array(R).reshape(-1, n_atoms, 3)\n    z = np.array(z)\n    E = np.array(E) * eV_to_kcalmol\n    F = np.array(F).reshape(-1, n_atoms, 3) * eV_to_kcalmol\n\n    f.close()\n    return (R, z, E, F)\n\n\nparser = argparse.ArgumentParser(description='Creates a dataset from FHI-aims format.')\nparser.add_argument(\n    'dataset',\n    metavar='<dataset>',\n    type=argparse.FileType('r'),\n    help='path to xyz dataset file',\n)\nparser.add_argument(\n    '-o',\n    '--overwrite',\n    dest='overwrite',\n    action='store_true',\n    help='overwrite existing dataset file',\n)\nargs = 
parser.parse_args()\ndataset = args.dataset\n\nname = os.path.splitext(os.path.basename(dataset.name))[0]\ndataset_file_name = name + '.npz'\n\ndataset_exists = os.path.isfile(dataset_file_name)\nif dataset_exists and args.overwrite:\n    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')\nif not dataset_exists or args.overwrite:\n    print('Writing dataset to \\'%s\\'...' % dataset_file_name)\nelse:\n    sys.exit(\n        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \\'%s\\' already exists.' % dataset_file_name\n    )\n\nR, z, E, F = read_reference_data(dataset)\n\n# Prune all arrays to same length.\nn_mols = min(min(R.shape[0], F.shape[0]), E.shape[0])\nif n_mols != R.shape[0] or n_mols != F.shape[0] or n_mols != E.shape[0]:\n    print(\n        ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True)\n        + ' Incomplete output detected: Final dataset was pruned to %d points.' % n_mols\n    )\nR = R[:n_mols, :, :]\nF = F[:n_mols, :, :]\nE = E[:n_mols]\n\n# Base variables contained in every dataset file.\nbase_vars = {\n    'type': 'd',\n    'R': R,\n    'z': z,\n    'E': E[:, None],\n    'F': F,\n    'e_unit': 'kcal/mol',\n    'r_unit': 'Ang',\n    'name': name,\n    'theory': 'unknown',\n}\n\nbase_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel())\nbase_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel())\n\nbase_vars['E_min'], base_vars['E_max'] = np.min(E), np.max(E)\nbase_vars['E_mean'], base_vars['E_var'] = np.mean(E), np.var(E)\n\nbase_vars['md5'] = io.dataset_md5(base_vars)\n\nnp.savez_compressed(dataset_file_name, **base_vars)\nprint(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))\n"
  },
  {
    "path": "scripts/sgdml_dataset_from_extxyz.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport argparse\nimport os\nimport sys\n\ntry:\n    from ase.io import read\nexcept ImportError:\n    raise ImportError('Optional ASE dependency not found! 
Please run \\'pip install sgdml[ase]\\' to install it.')\n\nimport numpy as np\n\nfrom sgdml import __version__\nfrom sgdml.utils import io, ui\n\nif sys.version[0] == '3':\n    raw_input = input\n\n\n# Note: assumes that the atoms in each molecule are in the same order.\ndef read_nonstd_ext_xyz(f):\n    n_atoms = None\n\n    R, z, E, F = [], [], [], []\n    for i, line in enumerate(f):\n        line = line.strip()\n        if not n_atoms:\n            n_atoms = int(line)\n            print('Number atoms per geometry: {:,}'.format(n_atoms))\n\n        file_i, line_i = divmod(i, n_atoms + 2)\n\n        if line_i == 1:\n            try:\n                e = float(line)\n            except ValueError:\n                pass\n            else:\n                E.append(e)\n\n        cols = line.split()\n        if line_i >= 2:\n            R.append(list(map(float, cols[1:4])))\n            if file_i == 0:  # first molecule\n                z.append(io._z_str_to_z_dict[cols[0]])\n            F.append(list(map(float, cols[4:7])))\n\n        if file_i % 1000 == 0:\n            sys.stdout.write('\\rNumber geometries found so far: {:,}'.format(file_i))\n            sys.stdout.flush()\n    sys.stdout.write('\\rNumber geometries found so far: {:,}'.format(file_i))\n    sys.stdout.flush()\n    print()\n\n    R = np.array(R).reshape(-1, n_atoms, 3)\n    z = np.array(z)\n    E = None if not E else np.array(E)\n    F = np.array(F).reshape(-1, n_atoms, 3)\n\n    if F.shape[0] != R.shape[0]:\n        sys.exit(\n            ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)\n            + ' Force labels are missing from dataset or are incomplete!'\n        )\n\n    f.close()\n    return (R, z, E, F)\n\n# Extracts info string for each frame.\ndef extract_info_from_extxyz(file_path):\n    infos = []\n\n    with open(file_path) as f:\n        lines = f.readlines()\n\n    i = 0\n    while i < len(lines):\n        try:\n            n_atoms = int(lines[i])\n        except ValueError:\n   
         raise ValueError(f\"Invalid atom count at line {i + 1}\")\n\n        if i + 1 >= len(lines):\n            break\n\n        comment_line = lines[i + 1].strip()\n        info = {}\n        for token in comment_line.split():\n            if \"=\" in token:\n                key, val = token.split(\"=\", 1)\n                val = val.strip('\"')\n                try:\n                    val = float(val)\n                except ValueError:\n                    pass\n                info[key] = val\n        infos.append(info)\n\n        i += 2 + n_atoms\n\n    return infos\n\n\nparser = argparse.ArgumentParser(\n    description='Creates a dataset from extended XYZ format.'\n)\nparser.add_argument(\n    'dataset',\n    metavar='<dataset>',\n    type=argparse.FileType('r'),\n    help='path to extended xyz dataset file',\n)\nparser.add_argument(\n    '-o',\n    '--overwrite',\n    dest='overwrite',\n    action='store_true',\n    help='overwrite existing dataset file',\n)\nargs = parser.parse_args()\ndataset = args.dataset\n\n\nname = os.path.splitext(os.path.basename(dataset.name))[0]\ndataset_file_name = name + '.npz'\n\ndataset_exists = os.path.isfile(dataset_file_name)\nif dataset_exists and args.overwrite:\n    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')\nif not dataset_exists or args.overwrite:\n    print('Writing dataset to \\'{}\\'...'.format(dataset_file_name))\nelse:\n    sys.exit(\n        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)\n        + ' Dataset \\'{}\\' already exists.'.format(dataset_file_name)\n    )\n\nlattice, R, z, E, F = None, None, None, None, None\n\nmols = read(dataset.name, format='extxyz', index=':')\n# Note: mols[0].get_calculator() is deprecated; use the 'calc' attribute instead.\ncalc = mols[0].calc\nis_extxyz = calc is not None\nif is_extxyz:\n\n    print(\"\\rNumber geometries found: {:,}\\n\".format(len(mols)))\n\n    if 'forces' not in calc.results:\n        sys.exit(\n            ui.color_str('[FAIL]', 
fore_color=ui.RED, bold=True)\n            + ' Forces are missing in the input file!'\n        )\n\n    lattice = np.array(mols[0].get_cell().T)\n    if not np.any(lattice): # all zeros\n        print(\n            ui.color_str('[INFO]', bold=True)\n            + ' No lattice vectors specified in extended XYZ file.'\n        )\n        lattice = None\n\n    Z = np.array([mol.get_atomic_numbers() for mol in mols])\n    all_z_the_same = (Z == Z[0]).all()\n    if not all_z_the_same:\n        sys.exit(\n            ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)\n            + ' Order of atoms changes across dataset.'\n        )\n\n    R = np.array([mol.get_positions() for mol in mols])\n    z = Z[0]\n\n    # If ASE did not parse the info string, try to parse it manually.\n    if not mols[0].info:\n\n        print(\n            ui.color_str('[INFO]', bold=True)\n            + ' ASE did not parse the info string completely. Trying to parse it manually.'\n        )\n\n        infos = extract_info_from_extxyz(dataset.name)\n        for mol, info in zip(mols, infos):\n            mol.info.update(info)\n\n    if 'Energy' in mols[0].info:\n        E = np.array([mol.info['Energy'] for mol in mols])\n    if 'energy' in mols[0].info:\n        E = np.array([mol.info['energy'] for mol in mols])\n    F = np.array([mol.get_forces() for mol in mols])\n\nelse:  # legacy non-standard XYZ format\n\n    with open(dataset.name) as f:\n        R, z, E, F = read_nonstd_ext_xyz(f)\n\n# Base variables contained in every dataset file.\nbase_vars = {\n    'type': 'd',\n    'code_version': __version__,\n    'name': name,\n    'theory': 'unknown',\n    'R': R,\n    'z': z,\n    'F': F,\n}\n\nbase_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel())\nbase_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel())\n\nprint('Please provide a description of the length unit used in your input file, e.g. 
\\'Ang\\' or \\'au\\': ')\nprint('Note: This string will be stored in the dataset file and passed on to model files for later reference.')\nr_unit = raw_input('> ').strip()\nif r_unit != '':\n    base_vars['r_unit'] = r_unit\n\nprint('Please provide a description of the energy unit used in your input file, e.g. \\'kcal/mol\\' or \\'eV\\': ')\nprint('Note: This string will be stored in the dataset file and passed on to model files for later reference.')\ne_unit = raw_input('> ').strip()\nif e_unit != '':\n    base_vars['e_unit'] = e_unit\n\nif E is not None:\n    base_vars['E'] = E\n    base_vars['E_min'], base_vars['E_max'] = np.min(E), np.max(E)\n    base_vars['E_mean'], base_vars['E_var'] = np.mean(E), np.var(E)\nelse:\n    print(ui.color_str('[INFO]', bold=True) + ' No energy labels found in dataset.')\n\nif lattice is not None:\n    base_vars['lattice'] = lattice\n\nbase_vars['md5'] = io.dataset_md5(base_vars)\nnp.savez_compressed(dataset_file_name, **base_vars)\nprint(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))\n"
  },
  {
    "path": "scripts/sgdml_dataset_from_ipi.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport argparse\nimport os\nimport sys\n\nimport numpy as np\n\nfrom sgdml.utils import io, ui\n\n\ndef raw_input_float(prompt):\n    while True:\n        try:\n            return float(input(prompt))\n        except ValueError:\n            print(ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' That is not a valid float.')\n\n\n# Assumes that the atoms in each molecule are in the same order.\ndef read_concat_xyz(f):\n    n_atoms = None\n\n    R, z = [], []\n    for i, line in enumerate(f):\n        line = line.strip()\n        if not n_atoms:\n            n_atoms = int(line)\n            print('Number atoms per geometry:      {:>7d}'.format(n_atoms))\n\n        file_i, line_i = divmod(i, n_atoms + 2)\n\n        cols = line.split()\n        if line_i >= 2:\n          
  if file_i == 0:  # first molecule\n                z.append(io._z_str_to_z_dict[cols[0]])\n            R.append(list(map(float, cols[1:4])))\n\n        if file_i % 1000 == 0:\n            sys.stdout.write(\"\\rNumber geometries found so far: {:>7d}\".format(file_i))\n            sys.stdout.flush()\n    sys.stdout.write(\"\\rNumber geometries found so far: {:>7d}\\n\".format(file_i))\n    sys.stdout.flush()\n\n    # Only keep complete entries.\n    R = R[: int(n_atoms * np.floor(len(R) / float(n_atoms)))]\n\n    R = np.array(R).reshape(-1, n_atoms, 3)\n    z = np.array(z)\n\n    f.close()\n    return (R, z)\n\n\ndef read_out_file(f, col):\n\n    E = []\n    for i, line in enumerate(f):\n        line = line.strip()\n        if line and line[0] != '#':  # Ignore blank lines and comments.\n            E.append(float(line.split()[col]))\n        if i % 1000 == 0:\n            sys.stdout.write(\"\\rNumber lines processed so far:  {:>7d}\".format(len(E)))\n            sys.stdout.flush()\n    sys.stdout.write(\"\\rNumber lines processed so far:  {:>7d}\\n\".format(len(E)))\n    sys.stdout.flush()\n\n    return np.array(E)\n\n\nparser = argparse.ArgumentParser(\n    description='Creates a dataset from extended [TODO] format.'\n)\nparser.add_argument(\n    'geometries',\n    metavar='<geometries>',\n    type=argparse.FileType('r'),\n    help='path to XYZ geometry file',\n)\nparser.add_argument(\n    'forces',\n    metavar='<forces>',\n    type=argparse.FileType('r'),\n    help='path to XYZ force file',\n)\nparser.add_argument(\n    'energies',\n    metavar='<energies>',\n    type=argparse.FileType('r'),\n    help='path to CSV energy file',\n)\nparser.add_argument(\n    'energy_col',\n    metavar='<energy_col>',\n    type=lambda x: io.is_strict_pos_int(x),\n    help='which column to parse from energy file (zero based)',\n    nargs='?',\n    default=0,\n)\nparser.add_argument(\n    '-o',\n    '--overwrite',\n    dest='overwrite',\n    action='store_true',\n    help='overwrite existing dataset 
file',\n)\nargs = parser.parse_args()\ngeometries = args.geometries\nforces = args.forces\nenergies = args.energies\nenergy_col = args.energy_col\n\nname = os.path.splitext(os.path.basename(geometries.name))[0]\ndataset_file_name = name + '.npz'\n\ndataset_exists = os.path.isfile(dataset_file_name)\nif dataset_exists and args.overwrite:\n    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')\nif not dataset_exists or args.overwrite:\n    print('Writing dataset to \\'%s\\'...' % dataset_file_name)\nelse:\n    sys.exit(\n        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \\'%s\\' already exists.' % dataset_file_name\n    )\n\n\nprint('Reading geometries...')\nR, z = read_concat_xyz(geometries)\n\nprint('Reading forces...')\nF, _ = read_concat_xyz(forces)\n\nprint('Reading energies from column %d...' % energy_col)\nE = read_out_file(energies, energy_col)\n\n# Prune all arrays to same length.\nn_mols = min(min(R.shape[0], F.shape[0]), E.shape[0])\nif n_mols != R.shape[0] or n_mols != F.shape[0] or n_mols != E.shape[0]:\n    print(\n        ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True)\n        + ' Incomplete output detected: Final dataset was pruned to %d points.' 
% n_mols\n    )\nR = R[:n_mols, :, :]\nF = F[:n_mols, :, :]\nE = E[:n_mols]\n\nprint(\n    ui.color_str('[INFO]', bold=True)\n    + ' Geometries, forces and energies must have consistent units.'\n)\nR_conv_fact = raw_input_float('Unit conversion factor for geometries: ')\nR = R * R_conv_fact\nF_conv_fact = raw_input_float('Unit conversion factor for forces: ')\nF = F * F_conv_fact\nE_conv_fact = raw_input_float('Unit conversion factor for energies: ')\nE = E * E_conv_fact\n\n# Base variables contained in every dataset file.\nbase_vars = {\n    'type': 'd',\n    'R': R,\n    'z': z,\n    'E': E[:, None],\n    'F': F,\n    'name': name,\n    'theory': 'unknown',\n}\nbase_vars['md5'] = io.dataset_md5(base_vars)\n\nnp.savez_compressed(dataset_file_name, **base_vars)\nprint(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))\n"
  },
  {
    "path": "scripts/sgdml_dataset_to_extxyz.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2019 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport argparse\nimport os\nimport sys\n\nimport numpy as np\n\nfrom sgdml.utils import io, ui\n\n\nparser = argparse.ArgumentParser(\n    description='Converts a native dataset file to extended XYZ format.'\n)\nparser.add_argument(\n    'dataset',\n    metavar='<dataset>',\n    type=lambda x: io.is_file_type(x, 'dataset'),\n    help='path to dataset file',\n)\nparser.add_argument(\n    '-o',\n    '--overwrite',\n    dest='overwrite',\n    action='store_true',\n    help='overwrite existing xyz dataset file',\n)\n\nargs = parser.parse_args()\ndataset_path, dataset = args.dataset\n\nname = os.path.splitext(os.path.basename(dataset_path))[0]\ndataset_file_name = name + '.xyz'\n\nxyz_exists = os.path.isfile(dataset_file_name)\nif xyz_exists and args.overwrite:\n    
print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing xyz dataset file.')\nif not xyz_exists or args.overwrite:\n    print(ui.color_str('[INFO]', bold=True) + ' Writing dataset to \\'{}\\'...'.format(dataset_file_name))\nelse:\n    sys.exit(\n        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \\'{}\\' already exists.'.format(dataset_file_name)\n    )\n\nR = dataset['R']\nz = dataset['z']\nF = dataset['F']\n\nlattice = dataset['lattice'] if 'lattice' in dataset else None\n\ntry:\n    with open(dataset_file_name, 'w') as file:\n\n        n = R.shape[0]\n        for i, r in enumerate(R):\n\n            # Energy labels are optional in dataset files.\n            e = np.squeeze(dataset['E'][i]) if 'E' in dataset else None\n            f = F[i, :, :]\n            ext_xyz_str = io.generate_xyz_str(r, z, e=e, f=f, lattice=lattice) + '\\n'\n\n            file.write(ext_xyz_str)\n\n            ui.callback(i, n - 1, disp_str='Exporting %d data points...' % n)\n\nexcept IOError:\n    sys.exit(\"ERROR: Writing xyz file failed.\")\n\nprint()\n"
  },
  {
    "path": "scripts/sgdml_dataset_via_ase.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport argparse\nimport os\nimport sys\n\ntry:\n    from ase.io import read\nexcept ImportError:\n    raise ImportError('Optional ASE dependency not found! 
Please run \\'pip install sgdml[ase]\\' to install it.')\n\nimport numpy as np\n\nfrom sgdml import __version__\nfrom sgdml.utils import io, ui\n\nif sys.version[0] == '3':\n    raw_input = input\n\n\nparser = argparse.ArgumentParser(\n    description='Creates a dataset from any input format supported by ASE.'\n)\nparser.add_argument(\n    'dataset',\n    metavar='<dataset>',\n    type=argparse.FileType('r'),\n    help='path to input dataset file',\n)\nparser.add_argument(\n    '-o',\n    '--overwrite',\n    dest='overwrite',\n    action='store_true',\n    help='overwrite existing dataset file',\n)\nargs = parser.parse_args()\ndataset = args.dataset\n\n\nname = os.path.splitext(os.path.basename(dataset.name))[0]\ndataset_file_name = name + '.npz'\n\ndataset_exists = os.path.isfile(dataset_file_name)\nif dataset_exists and args.overwrite:\n    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')\nif not dataset_exists or args.overwrite:\n    print('Writing dataset to \\'{}\\'...'.format(dataset_file_name))\nelse:\n    sys.exit(\n        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)\n        + ' Dataset \\'{}\\' already exists.'.format(dataset_file_name)\n    )\n\nmols = read(dataset.name, index=':')\n\n# Filter incomplete outputs from the trajectory ('get_calculator()' is deprecated in ASE).\nmols = [mol for mol in mols if mol.calc is not None]\nif not mols:\n    sys.exit(\n        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)\n        + ' No geometries with calculator results found in the input file!'\n    )\n\nlattice, R, z, E, F = None, None, None, None, None\n\ncalc = mols[0].calc\n\nprint(\"\\rNumber geometries: {:,}\".format(len(mols)))\n#print(\"\\rAvailable properties: \" + ', '.join(calc.results))\nprint()\n\nif 'forces' not in calc.results:\n    sys.exit(\n        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)\n        + ' Forces are missing in the input file!'\n    )\n\nlattice = np.array(mols[0].get_cell().T)\nif not np.any(lattice):\n    print(\n        ui.color_str('[INFO]', bold=True)\n        + ' No lattice vectors specified.'\n    )\n    lattice = None\n\nZ = np.array([mol.get_atomic_numbers() 
for mol in mols])\nall_z_the_same = (Z == Z[0]).all()\nif not all_z_the_same:\n    sys.exit(\n        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)\n        + ' Order of atoms changes across dataset.'\n    )\n\nR = np.array([mol.get_positions() for mol in mols])\nz = Z[0]\n\nif 'Energy' in mols[0].info:\n    E = np.array([float(mol.info['Energy']) for mol in mols])\nelse:\n    E = np.array([mol.get_potential_energy() for mol in mols])\nF = np.array([mol.get_forces() for mol in mols])\n\nprint('Please provide a name for this dataset. Otherwise the original filename will be reused.')\ncustom_name = raw_input('> ').strip()\nif custom_name != '':\n    name = custom_name\n\nprint('Please provide a descriptor for the level of theory used to create this dataset.')\ntheory = raw_input('> ').strip()\nif theory == '':\n    theory = 'unknown'\n\n# Base variables contained in every dataset file.\nbase_vars = {\n    'type': 'd',\n    'code_version': __version__,\n    'name': name,\n    'theory': theory,\n    'R': R,\n    'z': z,\n    'F': F,\n}\n\nbase_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel())\nbase_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel())\n\nprint('If you want to convert your original length unit, please provide a conversion factor (default: 1.0): ')\nR_to_new_unit = raw_input('> ').strip()\nif R_to_new_unit != '':\n    R_to_new_unit = float(R_to_new_unit)\nelse:\n    R_to_new_unit = 1.0\n\nprint('If you want to convert your original energy unit, please provide a conversion factor (default: 1.0): ')\nE_to_new_unit = raw_input('> ').strip()\nif E_to_new_unit != '':\n    E_to_new_unit = float(E_to_new_unit)\nelse:\n    E_to_new_unit = 1.0\n\nprint('Please provide a description of the length unit, e.g. 
\\'Ang\\' or \\'au\\': ')\nprint('Note: This string will be stored in the dataset file and passed on to model files for later reference.')\nr_unit = raw_input('> ').strip()\nif r_unit != '':\n    base_vars['r_unit'] = r_unit\n\nprint('Please provide a description of the energy unit, e.g. \\'kcal/mol\\' or \\'eV\\': ')\nprint('Note: This string will be stored in the dataset file and passed on to model files for later reference.')\ne_unit = raw_input('> ').strip()\nif e_unit != '':\n    base_vars['e_unit'] = e_unit\n\nif E is not None:\n    # Convert first, so that the statistics below match the stored energies.\n    E = E * E_to_new_unit\n    base_vars['E'] = E\n    base_vars['E_min'], base_vars['E_max'] = np.min(E), np.max(E)\n    base_vars['E_mean'], base_vars['E_var'] = np.mean(E), np.var(E)\nelse:\n    print(ui.color_str('[INFO]', bold=True) + ' No energy labels found in dataset.')\n\nbase_vars['R'] *= R_to_new_unit\nbase_vars['F'] *= E_to_new_unit / R_to_new_unit\n\n# Recompute the force statistics so that they match the converted forces.\nF_conv = base_vars['F'].ravel()\nbase_vars['F_min'], base_vars['F_max'] = np.min(F_conv), np.max(F_conv)\nbase_vars['F_mean'], base_vars['F_var'] = np.mean(F_conv), np.var(F_conv)\n\nif lattice is not None:\n    base_vars['lattice'] = lattice\n\nbase_vars['md5'] = io.dataset_md5(base_vars)\nnp.savez_compressed(dataset_file_name, **base_vars)\nprint(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))\n"
  },
  {
    "path": "scripts/sgdml_datasets_from_model.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport argparse\nimport os\nimport sys\n\nimport numpy as np\n\nfrom sgdml.utils import io, ui\n\nparser = argparse.ArgumentParser(\n    description='Extracts the training and test data subsets from a dataset that were used to construct a model.'\n)\nparser.add_argument(\n    'model',\n    metavar='<model_file>',\n    type=lambda x: io.is_file_type(x, 'model'),\n    help='path to model file',\n)\nparser.add_argument(\n    'dataset',\n    metavar='<dataset_file>',\n    type=lambda x: io.is_file_type(x, 'dataset'),\n    help='path to dataset file referenced in model',\n)\nparser.add_argument(\n    '-o',\n    '--overwrite',\n    dest='overwrite',\n    action='store_true',\n    help='overwrite existing files',\n)\nargs = parser.parse_args()\n\nmodel_path, model = 
args.model\ndataset_path, dataset = args.dataset\n\n\nfor s in ['train', 'valid']:\n\n    if dataset['md5'] != model['md5_' + s]:\n        sys.exit(\n            ui.fail_str('[FAIL]')\n            + ' Dataset fingerprint does not match the one referenced in model for \\'%s\\'.'\n            % s\n        )\n\n    idxs = model['idxs_' + s]\n    R = dataset['R'][idxs, :, :]\n    E = dataset['E'][idxs]\n    F = dataset['F'][idxs, :, :]\n\n    base_vars = {\n        'type': 'd',\n        'name': dataset['name'].astype(str),\n        'theory': dataset['theory'].astype(str),\n        'z': dataset['z'],\n        'R': R,\n        'E': E,\n        'F': F,\n    }\n    base_vars['md5'] = io.dataset_md5(base_vars)\n\n    subset_file_name = '%s_%s.npz' % (\n        os.path.splitext(os.path.basename(dataset_path))[0],\n        s,\n    )\n    file_exists = os.path.isfile(subset_file_name)\n    if file_exists and args.overwrite:\n        print(ui.info_str('[INFO]') + ' Overwriting existing dataset file.')\n    if not file_exists or args.overwrite:\n        np.savez_compressed(subset_file_name, **base_vars)\n        ui.callback(1, disp_str='Extracted %s dataset saved to \\'%s\\'' % (s, subset_file_name)) # DONE\n    else:\n        print(\n            ui.warn_str('[WARN]')\n            + ' %s dataset \\'%s\\' already exists.' % (s.capitalize(), subset_file_name)\n            + '\\n       Run \\'python %s -o %s %s\\' to overwrite.\\n'\n            % (os.path.basename(__file__), model_path, dataset_path)\n        )\n        sys.exit()\n"
  },
  {
    "path": "setup.cfg",
    "content": "[flake8]\nmax-complexity = 12\nignore = E501,W503,E741\nselect = C,E,F,W\n\n[isort]\nmulti_line_output = 3\ninclude_trailing_comma = 1\nline_length = 85\nsections = FUTURE,STDLIB,TYPING,THIRDPARTY,FIRSTPARTY,LOCALFOLDER\nknown_typing = typing, typing_extensions\nno_lines_before = TYPING\n"
  },
  {
    "path": "setup.py",
    "content": "import os\nimport re\nfrom io import open\nfrom setuptools import setup, find_packages\n\n\ndef get_property(property, package):\n    result = re.search(\n        r'{}\\s*=\\s*[\\'\"]([^\\'\"]*)[\\'\"]'.format(property),\n        open(package + '/__init__.py').read(),\n    )\n    return result.group(1)\n\n\nfrom os import path\n\nthis_dir = path.abspath(path.dirname(__file__))\nwith open(path.join(this_dir, 'README.md'), encoding='utf8') as f:\n    long_description = f.read()\n\n# Scripts\nscripts = []\nfor dirname, dirnames, filenames in os.walk('scripts'):\n    for filename in filenames:\n        if filename.endswith('.py'):\n            scripts.append(os.path.join(dirname, filename))\n\nsetup(\n    name='sgdml',\n    version=get_property('__version__', 'sgdml'),\n    description='Reference implementation of the GDML and sGDML force field models.',\n    long_description=long_description,\n    long_description_content_type='text/markdown',\n    classifiers=[\n        'Development Status :: 4 - Beta',\n        'Environment :: Console',\n        'Intended Audience :: Science/Research',\n        'Intended Audience :: Education',\n        'Intended Audience :: Developers',\n        'License :: OSI Approved :: MIT License',\n        'Operating System :: MacOS :: MacOS X',\n        'Operating System :: POSIX :: Linux',\n        'Programming Language :: Python :: 3.7',\n        'Topic :: Scientific/Engineering :: Chemistry',\n        'Topic :: Scientific/Engineering :: Physics',\n        'Topic :: Software Development :: Libraries :: Python Modules',\n    ],\n    url='http://www.sgdml.org',\n    author='Stefan Chmiela',\n    author_email='sgdml@chmiela.com',\n    license='LICENSE.txt',\n    packages=find_packages(),\n    install_requires=['torch >= 1.8', 'numpy >= 1.19.0', 'scipy >= 1.1.0', 'psutil', 'future'],\n    entry_points={\n        'console_scripts': ['sgdml=sgdml.cli:main', 'sgdml-get=sgdml.get:main']\n    },\n    extras_require={'ase': ['ase >= 
3.16.2']},\n    scripts=scripts,\n    include_package_data=True,\n    zip_safe=False,\n)\n"
  },
  {
    "path": "sgdml/__init__.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2019-2025 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\n__version__ = '1.0.3'\n\nMAX_PRINT_WIDTH = 100\nLOG_LEVELNAME_WIDTH = 7  # do not modify\n\n# more descriptive callback status\nDONE = 1\nNOT_DONE = 0\n\n\n# Logging\n\nimport copy\nimport logging\nimport re\nimport textwrap\n\nfrom .utils import ui\n\n\nclass ColoredFormatter(logging.Formatter):\n\n    LEVEL_COLORS = {\n        'DEBUG': (ui.CYAN, ui.BLACK),\n        'INFO': (ui.WHITE, ui.BLACK),\n        'DONE': (ui.GREEN, ui.BLACK),\n        'WARNING': (ui.YELLOW, ui.BLACK),\n        'ERROR': (ui.RED, ui.BLACK),\n        'CRITICAL': (ui.BLACK, ui.RED),\n    }\n\n    LEVEL_NAMES = {\n        'DEBUG': '[DEBG]',\n        'INFO': '[INFO]',\n        'DONE': '[DONE]',\n        'WARNING': '[WARN]',\n        'ERROR': '[FAIL]',\n        'CRITICAL': '[CRIT]',\n    }\n\n    def __init__(self, msg, use_color=True):\n\n  
      logging.Formatter.__init__(self, msg)\n        self.use_color = use_color\n\n    def format(self, record):\n\n        _record = copy.copy(record)\n        levelname = _record.levelname\n        msg = _record.msg\n\n        levelname = ui.color_str(\n            self.LEVEL_NAMES[levelname],\n            self.LEVEL_COLORS[levelname][0],\n            self.LEVEL_COLORS[levelname][1],\n            bold=True,\n        )\n\n        if _record.levelname != 'CRITICAL':\n            # wrap long messages (except for critical [i.e. exceptions, since they print a formatted traceback string])\n            msg = ui.wrap_str(msg)\n\n        # indent multiline strings after the first line\n        msg = ui.indent_str(msg, LOG_LEVELNAME_WIDTH)[LOG_LEVELNAME_WIDTH:]\n\n        _record.levelname = levelname\n        _record.msg = msg\n        return logging.Formatter.format(self, _record)\n\n\nclass ColoredLogger(logging.Logger):\n    def __init__(self, name):\n\n        logging.Logger.__init__(self, name, logging.DEBUG)\n\n        # add 'DONE' logging level\n        logging.DONE = logging.INFO + 1\n        logging.addLevelName(logging.DONE, 'DONE')\n\n        # only display levelname and message\n        formatter = ColoredFormatter('%(levelname)s %(message)s')\n\n        # this handler will write to sys.stderr by default\n        hd = logging.StreamHandler()\n        hd.setFormatter(formatter)\n        hd.setLevel(\n            logging.INFO\n        ) # control logging level here\n\n        self.addHandler(hd)\n        return\n\n    def done(self, msg, *args, **kwargs):\n\n        if self.isEnabledFor(logging.DONE):\n            self._log(logging.DONE, msg, args, **kwargs)\n\n\nlogging.setLoggerClass(ColoredLogger)\n"
  },
  {
    "path": "sgdml/cli.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport logging\nimport multiprocessing as mp\nimport argparse\nimport os\nimport shutil\nimport psutil\nimport sys\nimport traceback\nimport time\nfrom functools import partial\n\nimport numpy as np\nimport scipy as sp\n\ntry:\n    import torch\nexcept ImportError:\n    _has_torch = False\nelse:\n    _has_torch = True\n\ntry:\n    _torch_mps_is_available = torch.backends.mps.is_available()\nexcept AttributeError:\n    _torch_mps_is_available = False\n_torch_mps_is_available = False\n\ntry:\n    _torch_cuda_is_available = torch.cuda.is_available()\nexcept AttributeError:\n    _torch_cuda_is_available = False\n\ntry:\n    import ase\nexcept ImportError:\n    _has_ase = False\nelse:\n    _has_ase = True\n\nfrom . 
import __version__, DONE, NOT_DONE, MAX_PRINT_WIDTH\nfrom .predict import GDMLPredict\nfrom .train import GDMLTrain\nfrom .utils import io, ui\n\n# BASE_DIR = os.path.dirname(os.path.abspath(__file__))\nPACKAGE_NAME = 'sgdml'\n\nlog = logging.getLogger(__name__)\n\n\nclass AssistantError(Exception):\n    pass\n\n\ndef _print_splash(max_memory, max_processes, use_torch):\n\n    logo_str = r\"\"\"         __________  __  _____\n   _____/ ____/ __ \\/  |/  / /\n  / ___/ / __/ / / / /|_/ / /\n (__  ) /_/ / /_/ / /  / / /___\n/____/\\____/_____/_/  /_/_____/\"\"\"\n\n    can_update, latest_version = _check_update()\n\n    version_str = __version__\n    version_str += (\n        ' '\n        + ui.color_str(\n            ' Latest: ' + latest_version + ' ',\n            fore_color=ui.BLACK,\n            back_color=ui.YELLOW,\n            bold=True,\n        )\n        if can_update\n        else ''\n    )\n\n    max_memory_str = '{:d} GB(s) memory'.format(max_memory)\n    max_processes_str = '{:d} CPU(s)'.format(max_processes)\n    hardware_str = 'using {}, {}'.format(max_memory_str, max_processes_str)\n\n    if use_torch and _has_torch:\n\n        if _torch_cuda_is_available:\n            num_gpu = torch.cuda.device_count()\n            if num_gpu > 0:\n                hardware_str += ', {:d} GPU(s)'.format(num_gpu)\n        elif _torch_mps_is_available:\n            hardware_str += ', MPS enabled'\n\n    logo_str_split = logo_str.splitlines()\n    print('\\n'.join(logo_str_split[:-1]))\n    ui.print_two_column_str(logo_str_split[-1] + '  ' + version_str, hardware_str)\n\n    # Print update notice.\n    if can_update:\n        print(\n            '\\n'\n            + ui.color_str(\n                ' UPDATE AVAILABLE ',\n                fore_color=ui.BLACK,\n                back_color=ui.YELLOW,\n                bold=True,\n            )\n            + '\\n'\n            + '-' * MAX_PRINT_WIDTH\n        )\n        print(\n            'A new stable release version {} of 
this software is available.'.format(\n                latest_version\n            )\n        )\n        print(\n            'You can update your installation by running \\'pip install sgdml --upgrade\\'.'\n        )\n\n    _print_billboard()\n\n\ndef _check_update():\n\n    try:\n        from urllib.request import urlopen\n    except ImportError:\n        from urllib2 import urlopen\n\n    base_url = 'http://api.sgdml.org/'\n    url = '{}update.php?v={}'.format(base_url, __version__)\n\n    can_update, must_update = '0', '0'\n    latest_version = ''\n    try:\n        response = urlopen(url, timeout=1)\n        can_update, must_update, latest_version = response.read().decode().split(',')\n        response.close()\n    except:\n        pass\n\n    return can_update == '1', latest_version\n\n\ndef _print_billboard():\n\n    try:\n        from urllib.request import urlopen\n    except ImportError:\n        from urllib2 import urlopen\n\n    base_url = 'http://api.sgdml.org/'\n    url = '{}billboard.php'.format(base_url)\n\n    resp_str = ''\n    try:\n        response = urlopen(url, timeout=1)\n        resp_str = response.read().decode()\n        response.close()\n    except:\n        pass\n\n    bbs = None\n    try:\n        import json\n\n        bbs = json.loads(resp_str)\n    except:\n        pass\n\n    if bbs is not None:\n\n        for bb in bbs:\n\n            back_color = ui.WHITE\n            if bb['color'] == 'YELLOW':\n                back_color = ui.YELLOW\n            elif bb['color'] == 'GREEN':\n                back_color = ui.GREEN\n            elif bb['color'] == 'RED':\n                back_color = ui.RED\n            elif bb['color'] == 'CYAN':\n                back_color = ui.CYAN\n\n            print(\n                '\\n'\n                + ui.color_str(\n                    ' {} '.format(bb['title']),\n                    fore_color=ui.BLACK,\n                    back_color=back_color,\n                    bold=True,\n                )\n       
         + '\\n'\n                + '-' * MAX_PRINT_WIDTH\n            )\n\n            print(ui.wrap_str(bb['text'], width=MAX_PRINT_WIDTH - 2))\n\n\ndef _print_dataset_properties(dataset, title_str='Dataset properties'):\n\n    print(ui.color_str(title_str, bold=True))\n\n    n_mols, n_atoms, _ = dataset['R'].shape\n    print('  {:<18} \\'{}\\''.format('Name:', ui.unicode_str(dataset['name'])))\n    print('  {:<18} \\'{}\\''.format('Theory level:', ui.unicode_str(dataset['theory'])))\n    print('  {:<18} {:<d}'.format('Atoms:', n_atoms))\n\n    print('  {:<18} {:,} data points'.format('Size:', n_mols))\n\n    ui.print_lattice(dataset['lattice'] if 'lattice' in dataset else None)\n\n    if 'perms' in dataset:\n        ui.print_two_column_str(\n            '  {:<18} {}'.format('Symmetries:', len(dataset['perms'])),\n            'This dataset contains precomputed permutations.',\n        )\n\n    if 'E' in dataset:\n\n        e_unit = 'unknown unit'\n        if 'e_unit' in dataset:\n            e_unit = ui.unicode_str(dataset['e_unit'])\n\n        print('  Energies [{}]'.format(e_unit))\n        if 'E_min' in dataset and 'E_max' in dataset:\n            E_min, E_max = dataset['E_min'], dataset['E_max']\n        else:\n            E_min, E_max = np.min(dataset['E']), np.max(dataset['E'])\n        E_range_str = ui.gen_range_str(E_min, E_max)\n        ui.print_two_column_str(\n            '    {:<16} {}'.format('Range:', E_range_str), 'min |-- range --| max'\n        )\n\n        E_mean = dataset['E_mean'] if 'E_mean' in dataset else np.mean(dataset['E'])\n        print('    {:<16} {:<.3f}'.format('Mean:', E_mean))\n\n        E_var = dataset['E_var'] if 'E_var' in dataset else np.var(dataset['E'])\n        print('    {:<16} {:<.3f}'.format('Variance:', E_var))\n    else:\n        print('  {:<18} {}'.format('Energies:', 'n/a'))\n\n    f_unit = 'unknown unit'\n    if 'r_unit' in dataset and 'e_unit' in dataset:\n        f_unit = (\n            
ui.unicode_str(dataset['e_unit']) + '/' + ui.unicode_str(dataset['r_unit'])\n        )\n\n    print('  Forces [{}]'.format(f_unit))\n\n    if 'F_min' in dataset and 'F_max' in dataset:\n        F_min, F_max = dataset['F_min'], dataset['F_max']\n    else:\n        F_min, F_max = np.min(dataset['F'].ravel()), np.max(dataset['F'].ravel())\n    F_range_str = ui.gen_range_str(F_min, F_max)\n    ui.print_two_column_str(\n        '    {:<16} {}'.format('Range:', F_range_str), 'min |-- range --| max'\n    )\n\n    F_mean = dataset['F_mean'] if 'F_mean' in dataset else np.mean(dataset['F'].ravel())\n    print('    {:<16} {:<.3f}'.format('Mean:', F_mean))\n\n    F_var = dataset['F_var'] if 'F_var' in dataset else np.var(dataset['F'].ravel())\n    print('    {:<16} {:<.3f}'.format('Variance:', F_var))\n\n    print('  {:<18} {}'.format('Fingerprint:', ui.unicode_str(dataset['md5'])))\n\n    # if 'code_version' in dataset:\n    #    print('  {:<18} sGDML {}'.format('Created with:', ui.unicode_str(dataset['code_version'])))\n\n    idx = np.random.choice(n_mols, 1)[0]\n    r = dataset['R'][idx, :, :]\n    e = np.squeeze(dataset['E'][idx]) if 'E' in dataset else None\n    f = dataset['F'][idx, :, :]\n    lattice = dataset['lattice'] if 'lattice' in dataset else None\n\n    print(\n        '\\n'\n        + ui.color_str('Example geometry', fore_color=ui.WHITE, bold=True)\n        + ' (point no. {:,}, chosen randomly)'.format(idx + 1)\n    )\n\n    xyz_info_str = 'Copy & paste the string below into Jmol (www.jmol.org), Avogadro (www.avogadro.cc), etc. to visualize one of the geometries from this dataset. 
A new example will be drawn on each run.'\n    xyz_info_str = ui.wrap_str(xyz_info_str, width=MAX_PRINT_WIDTH - 2)\n    xyz_info_str = ui.indent_str(xyz_info_str, 2)\n    print(xyz_info_str + '\\n')\n\n    xyz_str = io.generate_xyz_str(r, dataset['z'], e=e, f=f, lattice=lattice)\n    xyz_str = ui.indent_str(xyz_str, 2)\n\n    cut_str = '---- COPY HERE '\n    cut_str_reps = int(np.floor((MAX_PRINT_WIDTH - 6) / len(cut_str)))\n    cutline_str = ui.color_str(\n        '  -' + cut_str * cut_str_reps + '-----', fore_color=ui.GRAY\n    )\n\n    print(cutline_str)\n    print(xyz_str)\n    print(cutline_str)\n\n\ndef _print_task_properties_reduced(\n    use_sym, use_E, use_E_cstr, title_str='Task properties'\n):\n\n    print(ui.color_str(title_str, bold=True))\n\n    energy_fix_str = (\n        (\n            'pointwise energy constraints'\n            if use_E_cstr\n            else 'global integration constant'\n        )\n        if use_E\n        else 'none'\n    )\n    print('  {:<16} {}'.format('Energy offset:', energy_fix_str))\n\n    print(\n        '  {:<16} {}'.format(\n            'Symmetries:', 'include (sGDML)' if use_sym else 'ignore (GDML)'\n        )\n    )\n\n\ndef _print_task_properties(task, title_str='Task properties'):\n\n    print(ui.color_str(title_str, bold=True))\n\n    print('  {:<18}'.format('Dataset'))\n    print('    {:<16} \\'{}\\''.format('Name:', ui.unicode_str(task['dataset_name'])))\n    print(\n        '    {:<16} \\'{}\\''.format(\n            'Theory level:', ui.unicode_str(task['dataset_theory'])\n        )\n    )\n\n    n_atoms = len(task['z'])\n    print('    {:<16} {:<d}'.format('Atoms:', n_atoms))\n\n    ui.print_lattice(task['lattice'] if 'lattice' in task else None, inset=True)\n\n    print('  {:<18} {:<d}'.format('Symmetries:', len(task['perms'])))\n\n    print('  {:<18}'.format('Hyper-parameters'))\n    print('    {:<16} {:<d}'.format('Length scale:', task['sig']))\n\n    if 'lam' in task:\n        print('    {:<16} 
{:<.0e}'.format('Regularization:', task['lam']))\n\n    # if 'solver_name' in task:\n    #     print('  {:<18}'.format('Solver configuration'))\n    #     print('    {:<16} \\'{}\\''.format('Type:', task['solver_name']))\n\n    #     if task['solver_name'] == 'cg':\n\n    #         if 'solver_tol' in task:\n    #             print('    {:<16} {:<.0e}'.format('Tolerance:', task['solver_tol']))\n\n    #         if 'n_inducing_pts_init' in task:\n    #             print(\n    #                 '    {:<16} {:<d}'.format(\n    #                     'Inducing points:', task['n_inducing_pts_init']\n    #                 )\n    #             )\n    # else:\n    #     print('  {:<18} {}'.format('Solver:', 'unknown'))\n\n    n_train = len(task['idxs_train'])\n    ui.print_two_column_str(\n        '  {:<18} {:,} points'.format('Train on:', n_train),\n        'from \\'' + ui.unicode_str(task['md5_train']) + '\\'',\n    )\n\n    n_valid = len(task['idxs_valid'])\n    ui.print_two_column_str(\n        '  {:<18} {:,} points'.format('Validate on:', n_valid),\n        'from \\'' + ui.unicode_str(task['md5_valid']) + '\\'',\n    )\n\n    # print('  {:<18}'.format('Estimated memory requirement (min.)'))\n\n    # mem_kernel_mat_const = 0\n    # mem_precond_const = 0\n    # print(\n    #    '    {:<16} {}'.format(\n    #        'CPU:', ui.gen_memory_str(mem_kernel_mat_const + mem_precond_const)\n    #    )\n    # )\n    # print('      {:<14} {}'.format('Kernel matrix:', ui.gen_memory_str(mem_kernel_mat_const))\n    # print('      {:<14} {}'.format('Precond. 
factor:', ui.gen_memory_str(mem_precond_const)))\n\n    # mem_torch_assemble = 0\n    # mem_torch_eval = 0\n    # print(\n    #    '    {:<16} {}'.format(\n    #        'GPU:', ui.gen_memory_str(mem_torch_assemble + mem_torch_eval)\n    #    )\n    # )\n    # print('      {:<14} {}'.format('Kernel matrix assembly:', ui.gen_memory_str(mem_torch_assemble)))\n    # print('      {:<14} {}'.format('Model evaluation:', ui.gen_memory_str(mem_torch_eval)))\n\n\ndef _print_model_properties(model, title_str='Model properties'):\n\n    print(ui.color_str(title_str, bold=True))\n\n    print('  {:<18}'.format('Dataset'))\n    print('    {:<16} \\'{}\\''.format('Name:', ui.unicode_str(model['dataset_name'])))\n    print(\n        '    {:<16} \\'{}\\''.format(\n            'Theory level:', ui.unicode_str(model['dataset_theory'])\n        )\n    )\n\n    n_atoms = len(model['z'])\n    print('    {:<16} {:<d}'.format('Atoms:', n_atoms))\n\n    ui.print_lattice(model['lattice'] if 'lattice' in model else None, inset=True)\n\n    print('  {:<18} {:<d}'.format('Symmetries:', len(model['perms'])))\n\n    print('  {:<18}'.format('Hyper-parameters'))\n    print('    {:<16} {:<d}'.format('Length scale:', model['sig']))\n\n    if 'lam' in model:\n        print('    {:<16} {:<.0e}'.format('Regularization:', model['lam']))\n\n    if 'solver_name' in model:\n        print('  {:<18}'.format('Solver'))\n        print('    {:<16} \\'{}\\''.format('Type:', model['solver_name']))\n\n        if model['solver_name'] == 'cg':\n\n            if 'solver_tol' in model:\n                ui.print_two_column_str(\n                    '    {:<16} {:<.0e}'.format('Tolerance:', model['solver_tol']),\n                    'iterate until: norm(K*alpha - y) <= tol*norm(y) = {:<.0e}'.format(\n                        model['solver_tol'] * model['norm_y_train']\n                    ),\n                )\n\n                if 'solver_resid' in model:\n                    is_conv = (\n                        
model['solver_resid']\n                        <= model['solver_tol'] * model['norm_y_train']\n                    )\n                    print(\n                        '    {:<16} {:<.0e}{}'.format(\n                            'Converged to:',\n                            model['solver_resid'],\n                            '' if is_conv else ' (NOT CONVERGED)',\n                        )\n                    )\n\n            if 'solver_iters' in model:\n                print('    {:<16} {:<d}'.format('Iterations:', model['solver_iters']))\n\n            if 'inducing_pts_idxs' in model:\n                n_inducing_pts = len(model['inducing_pts_idxs']) // (3 * n_atoms)\n                ui.print_two_column_str(\n                    '    {:<16} {:<d}'.format('Inducing points:', n_inducing_pts),\n                    'inducing columns: {:<d} (multiplied by DOF)'.format(\n                        n_inducing_pts * n_atoms * 3\n                    ),\n                )\n    else:\n        print('  {:<18} {}'.format('Solver:', 'unknown'))\n\n    n_train = len(model['idxs_train'])\n    ui.print_two_column_str(\n        '  {:<18} {:,} points'.format('Trained on:', n_train),\n        'from \\'' + ui.unicode_str(model['md5_train']) + '\\'',\n    )\n\n    use_E_cstr = 'alphas_E' in model\n    print(\n        '    {:<16} {}'.format(\n            'Energy offset',\n            '[{}] global integration constant'.format('x' if not use_E_cstr else ' '),\n        )\n    )\n    ui.print_two_column_str(\n        '                     {:<16}'.format(\n            '[{}] pointwise energy constraints'.format('x' if use_E_cstr else ' ')\n        ),\n        'using \\'--E_cstr\\'',\n    )\n\n    if model['use_E']:\n        e_err = model['e_err'].item()\n    f_err = model['f_err'].item()\n\n    n_valid = len(model['idxs_valid'])\n    is_valid = not np.isnan(f_err['mae']) and not np.isnan(f_err['rmse'])\n    ui.print_two_column_str(\n        '  {:<18} {}{:,} points'.format(\n            
'Validated on:', '' if is_valid else '[pending] ', n_valid\n        ),\n        'from \\'' + ui.unicode_str(model['md5_valid']) + '\\'',\n    )\n\n    n_test = int(model['n_test'])\n    is_test = n_test > 0\n    if is_test:\n        ui.print_two_column_str(\n            '  {:<18} {:,} points'.format('Tested on:', n_test),\n            'from \\'' + ui.unicode_str(model['md5_test']) + '\\'',\n        )\n    else:\n        print('  {:<18} {}'.format('Test:', '[pending]'))\n\n    e_unit = 'unknown unit'\n    f_unit = 'unknown unit'\n    if 'r_unit' in model and 'e_unit' in model:\n        e_unit = model['e_unit']\n        f_unit = ui.unicode_str(model['e_unit']) + '/' + ui.unicode_str(model['r_unit'])\n\n    if is_valid:\n        action_str = 'Validation' if not is_valid else 'Expected test'\n        print('  {:<18}'.format('{} errors (MAE/RMSE)'.format(action_str)))\n        if model['use_E']:\n            print(\n                '    {:<16} {:>.4f}/{:>.4f} [{}]'.format(\n                    'Energy:', e_err['mae'], e_err['rmse'], e_unit\n                )\n            )\n        print(\n            '    {:<16} {:>.4f}/{:>.4f} [{}]'.format(\n                'Forces:', f_err['mae'], f_err['rmse'], f_unit\n            )\n        )\n\n\ndef _print_next_step(\n    prev_step, task_dir=None, model_dir=None, model_files=None, dataset_path=None\n):\n\n    if prev_step == 'create':\n\n        assert task_dir is not None\n\n        ui.print_step_title(\n            'NEXT STEP',\n            '{} train {} <valid_dataset_file>'.format(PACKAGE_NAME, task_dir),\n            underscore=False,\n        )\n\n    elif prev_step == 'train' or prev_step == 'validate' or prev_step == 'resume':\n\n        assert model_dir is not None and model_files is not None\n\n        if dataset_path is None:\n            dataset_path = '<test_dataset_file>'\n\n        n_models = len(model_files)\n        if n_models == 1:\n            model_file_path = os.path.join(model_dir, model_files[0])\n          
  ui.print_step_title(\n                'NEXT STEP',\n                '{} test {} {} [<n_test>]'.format(\n                    PACKAGE_NAME, model_file_path, dataset_path\n                ),\n                underscore=False,\n            )\n        else:\n            ui.print_step_title(\n                'NEXT STEP',\n                '{} select {}'.format(PACKAGE_NAME, model_dir),\n                underscore=False,\n            )\n\n    elif prev_step == 'select':\n\n        assert model_files is not None\n\n        ui.print_step_title(\n            'NEXT STEP',\n            '{} test {} <test_dataset_file> [<n_test>]'.format(\n                PACKAGE_NAME, model_files[0]\n            ),\n            underscore=False,\n        )\n\n    else:\n        raise AssistantError('Unexpected previous step string.')\n\n\ndef all(\n    dataset,\n    valid_dataset,\n    test_dataset,\n    n_train,\n    n_valid,\n    n_test,\n    sigs,\n    gdml,\n    use_E,\n    use_E_cstr,\n    lazy_training,\n    overwrite,\n    max_memory,\n    max_processes,\n    use_torch,\n    task_dir=None,\n    model_file=None,\n    perms_from_arg=None,\n    **kwargs\n):\n\n    print(\n        '\\n'\n        + ui.color_str(' STEP 0 ', fore_color=ui.BLACK, back_color=ui.WHITE, bold=True)\n        + ' Dataset(s)\\n'\n        + '-' * MAX_PRINT_WIDTH\n    )\n\n    _, dataset_extracted = dataset\n    _print_dataset_properties(dataset_extracted, title_str='Properties')\n\n    if valid_dataset is None:\n        valid_dataset = dataset\n    else:\n        _, valid_dataset_extracted = valid_dataset\n        print()\n        _print_dataset_properties(\n            valid_dataset_extracted, title_str='Properties (validation dataset)'\n        )\n\n        if not np.array_equal(dataset_extracted['z'], valid_dataset_extracted['z']):\n            raise AssistantError(\n                'Atom composition or order in validation dataset does not match the one in bulk dataset.'\n            )\n\n    if test_dataset is 
None:\n        test_dataset = dataset\n    else:\n        _, test_dataset_extracted = test_dataset\n        _print_dataset_properties(\n            test_dataset_extracted, title_str='Properties (test dataset)'\n        )\n\n        if not np.array_equal(dataset_extracted['z'], test_dataset_extracted['z']):\n            raise AssistantError(\n                'Atom composition or order in test dataset does not match the one in bulk dataset.'\n            )\n\n    ui.print_step_title('STEP 1', 'Cross-validation task creation')\n    task_dir = create(\n        dataset,\n        valid_dataset,\n        n_train,\n        n_valid,\n        sigs,\n        gdml,\n        use_E,\n        use_E_cstr,\n        overwrite,\n        task_dir,\n        perms_from_arg=perms_from_arg,\n        **kwargs\n    )\n\n    ui.print_step_title('STEP 2', 'Training and validation')\n    task_dir_arg = io.is_dir_with_file_type(task_dir, 'task')\n    model_dir_or_file_path = train(\n        task_dir_arg,\n        valid_dataset,\n        lazy_training,\n        overwrite,\n        max_memory,\n        max_processes,\n        use_torch,\n        **kwargs\n    )\n\n    model_dir_arg = io.is_dir_with_file_type(\n        model_dir_or_file_path, 'model', or_file=True\n    )\n\n    _, model_file_names = model_dir_arg\n    if len(model_file_names) == 0:\n        raise AssistantError(\n            'No trained models found!'\n            + ('\\nTry turning off \\'--lazy\\'-mode.' if lazy_training else '')\n        )\n\n    ui.print_step_title('STEP 3', 'Hyper-parameter selection')\n    model_file_name = select(model_dir_arg, overwrite, model_file, **kwargs)\n\n    # Have all tasks been trained?\n    _, task_file_names = task_dir_arg\n    if len(task_file_names) > len(model_file_names):\n        log.warning(\n            'Not all training tasks have been completed! The model selected here might not be optimal.'\n            + ('\\nTry turning off \\'--lazy\\'-mode.' 
if lazy_training else '')\n        )\n\n    ui.print_step_title('STEP 4', 'Testing')\n    model_dir_arg = io.is_dir_with_file_type(model_file_name, 'model', or_file=True)\n    test(\n        model_dir_arg,\n        test_dataset,\n        n_test,\n        overwrite=False,\n        max_memory=max_memory,\n        max_processes=max_processes,\n        use_torch=use_torch,\n        **kwargs\n    )\n\n    print(\n        '\\n'\n        + ui.color_str('  DONE  ', fore_color=ui.BLACK, back_color=ui.GREEN, bold=True)\n        + ' Training assistant finished successfully.'\n    )\n    print('         This is your model file: \\'{}\\''.format(model_file_name))\n\n\n# if training job exists and is a subset of the requested cv range, add new tasks\n# otherwise, if new range is different or smaller, fail\ndef create(  # noqa: C901\n    dataset,\n    valid_dataset,\n    n_train,\n    n_valid,\n    sigs,\n    gdml,\n    use_E,\n    use_E_cstr,\n    overwrite,\n    task_dir=None,\n    perms_from_arg=None,\n    command=None,\n    **kwargs\n):\n\n    has_valid_dataset = not (valid_dataset is None or valid_dataset == dataset)\n\n    dataset_path, dataset = dataset\n    n_data = dataset['F'].shape[0]\n\n    func_called_directly = (\n        command == 'create'\n    )  # has this function been called from command line or from 'all'?\n    if func_called_directly:\n        ui.print_step_title('TASK CREATION')\n        _print_dataset_properties(dataset)\n        print()\n\n    _print_task_properties_reduced(use_sym=not gdml, use_E=use_E, use_E_cstr=use_E_cstr)\n    print()\n\n    if n_data < n_train:\n        raise AssistantError(\n            'Dataset only contains {} points, can not train on {}.'.format(\n                n_data, n_train\n            )\n        )\n\n    if not has_valid_dataset:\n        valid_dataset_path, valid_dataset = dataset_path, dataset\n        if n_data - n_train < n_valid:\n            raise AssistantError(\n                'Dataset only contains {} points, can 
not train on {} and validate on {}.'.format(\n                    n_data, n_train, n_valid\n                )\n            )\n    else:\n        valid_dataset_path, valid_dataset = valid_dataset\n        n_valid_data = valid_dataset['R'].shape[0]\n        if n_valid_data < n_valid:\n            raise AssistantError(\n                'Validation dataset only contains {} points, can not validate on {}.'.format(\n                    n_valid_data, n_valid\n                )\n            )\n\n    if sigs is None:\n        log.info(\n            'Kernel hyper-parameter sigma (length scale) was automatically set to range \\'10:10:100\\'.'\n        )\n        sigs = list(range(10, 100, 10))  # default range\n\n    if task_dir is None:\n        task_dir = io.train_dir_name(\n            dataset,\n            n_train,\n            use_sym=not gdml,\n            use_E=use_E,\n            use_E_cstr=use_E_cstr,\n        )\n\n    task_file_names = []\n    if os.path.exists(task_dir):\n        if overwrite:\n            log.info('Overwriting existing training directory')\n            shutil.rmtree(task_dir, ignore_errors=True)\n            os.makedirs(task_dir)\n        else:\n            if io.is_task_dir_resumeable(\n                task_dir, dataset, valid_dataset, n_train, n_valid, sigs, gdml\n            ):\n                log.info(\n                    'Resuming existing hyper-parameter search in \\'{}\\'.'.format(\n                        task_dir\n                    )\n                )\n\n                # Get all task file names.\n                try:\n                    _, task_file_names = io.is_dir_with_file_type(task_dir, 'task')\n                except Exception:\n                    pass\n            else:\n                raise AssistantError(\n                    'Unfinished hyper-parameter search found in \\'{}\\'.\\n'.format(\n                        task_dir\n                    )\n                    + 'Run \\'%s %s -o %s %d %d -s %s\\' to overwrite.'\n        
            % (\n                        PACKAGE_NAME,\n                        command,\n                        dataset_path,\n                        n_train,\n                        n_valid,\n                        ' '.join(str(s) for s in sigs),\n                    )\n                )\n    else:\n        os.makedirs(task_dir)\n\n    if task_file_names:\n\n        with np.load(\n            os.path.join(task_dir, task_file_names[0]), allow_pickle=True\n        ) as task:\n            tmpl_task = dict(task)\n    else:\n        if not use_E:\n            log.info(\n                'Energy labels will be ignored for training.\\n'\n                + 'Note: If available in the dataset file, the energy labels will however still be used to generate stratified training, test and validation datasets. Otherwise a random sampling is used.'\n            )\n\n        if 'E' not in dataset:\n            log.warning(\n                'Training dataset will be sampled with no guidance from energy labels (i.e. randomly)!'\n            )\n\n        if 'E' not in valid_dataset:\n            log.warning(\n                'Validation dataset will be sampled with no guidance from energy labels (i.e. 
randomly)!\\n'\n                + 'Note: Larger validation datasets are recommended due to slower convergence of the error.'\n            )\n\n        if ('lattice' in dataset) ^ ('lattice' in valid_dataset):\n            log.error('One of the datasets specifies lattice vectors and one does not!')\n            # TODO: stop program?\n\n        if 'lattice' in dataset or 'lattice' in valid_dataset:\n            log.info(\n                'Lattice vectors found in dataset: applying periodic boundary conditions.'\n            )\n\n        perms = None\n        if perms_from_arg is not None:\n\n            _, perms_from = perms_from_arg\n            if 'perms' in perms_from:\n                perms = perms_from['perms']\n            else:\n                raise AssistantError(\n                    'Provided permutation file does not contain any (looking for \\'perms\\'-key).'\n                )\n\n        gdml_train = (\n            GDMLTrain()\n        )  # No process number or memory restrictions necessary here.\n        try:\n            tmpl_task = gdml_train.create_task(\n                dataset,\n                n_train,\n                valid_dataset,\n                n_valid,\n                sig=1,\n                perms=perms,\n                use_sym=not gdml,\n                use_E=use_E,\n                use_E_cstr=use_E_cstr,\n                callback=ui.callback,\n            )  # template task\n        except:\n            print()\n            log.critical(traceback.format_exc())\n            print()\n            os._exit(1)\n\n    n_written = 0\n    for sig in sigs:\n        tmpl_task['sig'] = sig\n        task_file_name = io.task_file_name(tmpl_task)\n        task_path = os.path.join(task_dir, task_file_name)\n\n        if os.path.isfile(task_path):\n            log.info('Skipping existing task \\'{}\\'.'.format(task_file_name))\n        else:\n            np.savez_compressed(task_path, **tmpl_task)\n            n_written += 1\n    if n_written > 0:\n   
     log.done(\n            'Writing {:d}/{:d} task(s) with m={} training points each'.format(\n                n_written, len(sigs), tmpl_task['R_train'].shape[0]\n            )\n        )\n\n    if func_called_directly:\n        _print_next_step('create', task_dir=task_dir)\n\n    return task_dir\n\n\ndef train(\n    task_dir,\n    valid_dataset,\n    lazy_training,\n    overwrite,\n    max_memory,\n    max_processes,\n    use_torch,\n    command=None,\n    **kwargs\n):\n\n    task_dir, task_file_names = task_dir\n    n_tasks = len(task_file_names)\n\n    func_called_directly = (\n        command == 'train'\n    )  # Has this function been called from command line or from 'all'?\n    if func_called_directly:\n        ui.print_step_title('MODEL TRAINING')\n\n    def save_progr_callback(\n        unconv_model, unconv_model_path=None\n    ):  # Saves current (unconverged) model during iterative training\n\n        if unconv_model_path is None:\n            log.critical(\n                'Path for unconverged model not set in \\'save_progr_callback\\'.'\n            )\n            print()\n            os._exit(1)\n\n        np.savez_compressed(unconv_model_path, **unconv_model)\n\n    try:\n        gdml_train = GDMLTrain(\n            max_memory=max_memory, max_processes=max_processes, use_torch=use_torch\n        )\n    except:\n        print()\n        log.critical(traceback.format_exc())\n        print()\n        os._exit(1)\n\n    prev_valid_err = -1\n    has_converged_once = False\n\n    for i, task_file_name in enumerate(task_file_names):\n\n        task_file_path = os.path.join(task_dir, task_file_name)\n        with np.load(task_file_path, allow_pickle=True) as task:\n\n            if n_tasks > 1:\n                if i > 0:\n                    print()\n\n                n_train = len(task['idxs_train'])\n                n_valid = len(task['idxs_valid'])\n                ui.print_two_column_str(\n                    ui.color_str('Task {:d} of {:d}'.format(i + 
1, n_tasks), bold=True),\n                    '{:,} + {:,} points (training + validation), sigma (length scale): {}'.format(\n                        n_train, n_valid, task['sig']\n                    ),\n                )\n\n            model_file_name = io.model_file_name(task, is_extended=False)\n            model_file_path = os.path.join(task_dir, model_file_name)\n\n            # is_conv = True\n            # valid_errs = None\n            # is_model_validated = False\n            if not overwrite and os.path.isfile(\n                model_file_path\n            ):  # Train model found, validate if necessary\n                log.info(\n                    'Model \\'{}\\' already exists.'.format(model_file_name)\n                    + (\n                        '\\nRun \\'{} train -o {}\\' to overwrite.'.format(\n                            PACKAGE_NAME, task_file_path\n                        )\n                        if func_called_directly\n                        else ''\n                    )\n                )\n\n                model_path = os.path.join(task_dir, model_file_name)\n                _, model = io.is_file_type(model_path, 'model')\n\n                e_err = {'mae': 0.0, 'rmse': 0.0}\n                if model['use_E']:\n                    e_err = model['e_err'].item()\n                f_err = model['f_err'].item()\n\n                is_conv = True\n                if 'solver_resid' in model:\n                    is_conv = (\n                        model['solver_resid']\n                        <= model['solver_tol'] * model['norm_y_train']\n                    )\n\n                is_model_validated = not (\n                    np.isnan(f_err['mae']) or np.isnan(f_err['rmse'])\n                )\n                if is_model_validated:\n\n                    disp_str = (\n                        'energy %.3f/%.3f, ' % (e_err['mae'], e_err['rmse'])\n                        if model['use_E']\n                        else ''\n                  
  )\n                    disp_str += 'forces %.3f/%.3f' % (f_err['mae'], f_err['rmse'])\n                    disp_str = 'Validation errors (MAE/RMSE): ' + disp_str\n                    ui.callback(1, 1, disp_str=disp_str)\n\n                    valid_errs = [f_err['rmse']]\n\n            else:  # Train and validate model\n\n                # Check if training this task has been attempted before.\n                if lazy_training and n_tasks > 1:\n                    if 'tried_training' in task and task['tried_training']:\n                        log.warning(\n                            'Skipping task, because it has been tried before (without success).'\n                        )\n                        continue\n\n                # Record in task file that there was a training attempt.\n                task = dict(task)\n                task['tried_training'] = True\n                np.savez_compressed(task_file_path, **task)\n\n                n_train, n_atoms = task['R_train'].shape[:2]\n\n                unconv_model_file = '_unconv_{}'.format(model_file_name)\n                unconv_model_path = os.path.join(task_dir, unconv_model_file)\n\n                try:\n                    model = gdml_train.train(\n                        task,\n                        partial(\n                            save_progr_callback, unconv_model_path=unconv_model_path\n                        ),\n                        ui.callback,\n                    )\n                except:\n                    print()\n                    log.critical(traceback.format_exc())\n                    print()\n                    os._exit(1)\n                else:\n                    if func_called_directly:\n                        log.done('Writing model to file \\'{}\\''.format(model_file_path))\n                    np.savez_compressed(model_file_path, **model)\n\n                    # Delete temporary model, if one exists.\n                    unconv_model_exists = 
os.path.isfile(unconv_model_path)\n                    if unconv_model_exists:\n                        os.remove(unconv_model_path)\n\n                is_model_validated = False\n\n            if not is_model_validated:\n\n                if (\n                    n_tasks == 1\n                ):  # Only validate if there is more than one training task.\n                    log.info(\n                        'Skipping validation step as there is only one model to validate.'\n                    )\n                    break\n\n                # Validate model.\n                model_dir = (task_dir, [model_file_name])\n                valid_errs = test(\n                    model_dir,\n                    valid_dataset,\n                    -1,  # n_test = -1 -> validation mode\n                    overwrite,\n                    max_memory,\n                    max_processes,\n                    use_torch,\n                    command,\n                    **kwargs\n                )\n\n                is_conv = True\n                if 'solver_resid' in model:\n                    is_conv = (\n                        model['solver_resid']\n                        <= model['solver_tol'] * model['norm_y_train']\n                    )\n\n            has_converged_once = has_converged_once or is_conv\n            if (\n                has_converged_once\n                and prev_valid_err != -1\n                and prev_valid_err < valid_errs[0]\n            ):\n                print()\n                log.info(\n                    'Skipping remaining training tasks, as validation error is rising again.'\n                )\n                break\n\n            prev_valid_err = valid_errs[0]\n\n    model_dir_or_file_path = model_file_path if n_tasks == 1 else task_dir\n    if func_called_directly:\n\n        model_dir_arg = io.is_dir_with_file_type(\n            model_dir_or_file_path, 'model', or_file=True\n        )\n        model_dir, model_files = 
model_dir_arg\n        _print_next_step('train', model_dir=model_dir, model_files=model_files)\n\n    return model_dir_or_file_path  # model directory or file\n\n\ndef _batch(iterable, n=1):\n    l = len(iterable)\n    for ndx in range(0, l, n):\n        yield iterable[ndx : min(ndx + n, l)]\n\n\ndef _online_err(err, size, n, mae_n_sum, rmse_n_sum):\n\n    err = np.abs(err)\n\n    mae_n_sum += np.sum(err) / size\n    mae = mae_n_sum / n\n\n    rmse_n_sum += np.sum(err**2) / size\n    rmse = np.sqrt(rmse_n_sum / n)\n\n    return mae, mae_n_sum, rmse, rmse_n_sum\n\n\ndef resume(\n    model,\n    dataset,\n    valid_dataset,\n    overwrite,\n    max_memory,\n    max_processes,\n    use_torch,\n    command=None,\n    **kwargs\n):\n\n    model_path, model = model\n    dataset_path, dataset = dataset\n\n    valid_dataset_arg = valid_dataset\n    valid_dataset_path, valid_dataset = valid_dataset\n\n    ui.print_step_title('RESUME TRAINING')\n    _print_model_properties(model, title_str='Model properties (initial)')\n    print()\n\n    if dataset['md5'] != model['md5_train']:\n        raise AssistantError(\n            'Fingerprint of provided training dataset does not match the one specified in model file.'\n        )\n    if valid_dataset['md5'] != model['md5_valid']:\n        raise AssistantError(\n            'Fingerprint of provided validation dataset does not match the one specified in model file.'\n        )\n\n    if model['solver_name'] == 'analytic':\n        raise AssistantError(\n            'This model was trained using a matrix decomposition method and thus already converged to the highest possible accuracy! 
It does not make sense to resume training in this case.'\n        )\n    elif 'solver_resid' in model and 'solver_tol' in model:\n        if model['solver_resid'] > model['solver_tol'] * model['norm_y_train']:\n\n            gdml_train = GDMLTrain(\n                max_memory=max_memory, max_processes=max_processes, use_torch=use_torch\n            )\n            try:\n                task = gdml_train.create_task_from_model(\n                    model,\n                    dataset,\n                )\n            except:\n                print()\n                log.critical(traceback.format_exc())\n                print()\n                os._exit(1)\n            del gdml_train\n\n            def save_progr_callback(\n                unconv_model,\n            ):  # saves current (unconverged) model during iterative training\n                np.savez_compressed(model_path, **unconv_model)\n\n            try:\n                gdml_train = GDMLTrain(\n                    max_memory=max_memory,\n                    max_processes=max_processes,\n                    use_torch=use_torch,\n                )\n            except:\n                print()\n                log.critical(traceback.format_exc())\n                print()\n                os._exit(1)\n\n            try:\n                model = gdml_train.train(\n                    task, save_progr_callback=save_progr_callback, callback=ui.callback\n                )\n            except:\n                print()\n                log.critical(traceback.format_exc())\n                print()\n                os._exit(1)\n            else:\n                log.done('Model parameters have been updated.')\n                np.savez_compressed(model_path, **model)\n\n        else:\n            log.warning('Model is already converged to the specified tolerance.')\n\n    # Validate model.\n    model_dir, model_file_name = os.path.split(model_path)\n    model_dir_arg = (model_dir, [model_file_name])\n\n    valid_errs = 
test(\n        model_dir_arg,\n        valid_dataset_arg,\n        -1,  # n_test = -1 -> validation mode\n        overwrite,\n        max_memory,\n        max_processes,\n        use_torch,\n        command,\n        **kwargs\n    )\n\n    _print_next_step('resume', model_dir=model_dir, model_files=[model_file_name])\n\n\ndef validate(\n    model_dir,\n    valid_dataset,\n    overwrite,\n    max_memory,\n    max_processes,\n    use_torch,\n    command=None,\n    **kwargs\n):\n\n    dataset_path_extracted, dataset_extracted = valid_dataset\n\n    func_called_directly = (\n        command == 'validate'\n    )  # has this function been called from command line or from 'all'?\n    if func_called_directly:\n        ui.print_step_title('MODEL VALIDATION')\n        _print_dataset_properties(dataset_extracted)\n\n    test(\n        model_dir,\n        valid_dataset,\n        -1,  # n_test = -1 -> validation mode\n        overwrite,\n        max_memory,\n        max_processes,\n        use_torch,\n        command,\n        **kwargs\n    )\n\n    if func_called_directly:\n\n        model_dir, model_files = model_dir\n        n_models = len(model_files)\n        _print_next_step('validate', model_dir=model_dir, model_files=model_files)\n\n\ndef test(\n    model_dir,\n    test_dataset,\n    n_test,\n    overwrite,\n    max_memory,\n    max_processes,\n    use_torch,\n    command=None,\n    **kwargs\n):  # noqa: C901\n\n    # NOTE: this function runs a validation if n_test < 0 and test with all points if n_test == 0\n\n    model_dir, model_file_names = model_dir\n    n_models = len(model_file_names)\n\n    n_test = 0 if n_test is None else n_test\n    is_validation = n_test < 0\n    is_test = n_test >= 0\n\n    dataset_path, dataset = test_dataset\n\n    func_called_directly = (\n        command == 'test'\n    )  # has this function been called from command line or from 'all'?\n    if func_called_directly:\n        ui.print_step_title('MODEL TEST')\n        
_print_dataset_properties(dataset)\n\n    F_rmse = []\n\n    # NEW\n\n    DEBUG_WRITE = False\n\n    if DEBUG_WRITE:\n        if os.path.exists('test_pred.xyz'):\n            os.remove('test_pred.xyz')\n        if os.path.exists('test_ref.xyz'):\n            os.remove('test_ref.xyz')\n        if os.path.exists('test_diff.xyz'):\n            os.remove('test_diff.xyz')\n\n    # NEW\n\n    num_workers, batch_size = -1, -1\n    gdml_train = None\n    for i, model_file_name in enumerate(model_file_names):\n\n        model_path = os.path.join(model_dir, model_file_name)\n        _, model = io.is_file_type(model_path, 'model')\n\n        if i == 0 and command != 'all':\n            print()\n            _print_model_properties(model)\n            print()\n\n        if not np.array_equal(model['z'], dataset['z']):\n            raise AssistantError(\n                'Atom composition or order in dataset does not match the one in model.'\n            )\n\n        if ('lattice' in model) is not ('lattice' in dataset):\n            if 'lattice' in model:\n                raise AssistantError(\n                    'Model contains lattice vectors, but dataset does not.'\n                )\n            elif 'lattice' in dataset:\n                raise AssistantError(\n                    'Dataset contains lattice vectors, but model does not.'\n                )\n\n        if model['use_E']:\n            e_err = model['e_err'].item()\n        f_err = model['f_err'].item()\n\n        is_model_validated = not (np.isnan(f_err['mae']) or np.isnan(f_err['rmse']))\n\n        if n_models > 1:\n            if i > 0:\n                print()\n            print(\n                ui.color_str(\n                    '%s model %d of %d'\n                    % ('Testing' if is_test else 'Validating', i + 1, n_models),\n                    bold=True,\n                )\n            )\n\n        if is_validation:\n            if is_model_validated and not overwrite:\n                log.info(\n     
               'Skipping already validated model \\'{}\\'.'.format(model_file_name)\n                    + (\n                        '\\nRun \\'{} validate -o {} {}\\' to overwrite.'.format(\n                            PACKAGE_NAME, model_path, dataset_path\n                        )\n                        if command == 'test'\n                        else ''\n                    )\n                )\n                continue\n\n            if dataset['md5'] != model['md5_valid']:\n                raise AssistantError(\n                    'Fingerprint of provided validation dataset does not match the one specified in model file.'\n                )\n\n        test_idxs = model['idxs_valid']\n        if is_test:\n\n            # exclude training and/or test sets from validation set if necessary\n            excl_idxs = np.empty((0,), dtype=np.uint)\n            if dataset['md5'] == model['md5_train']:\n                excl_idxs = np.concatenate([excl_idxs, model['idxs_train']]).astype(\n                    np.uint\n                )\n            if dataset['md5'] == model['md5_valid']:\n                excl_idxs = np.concatenate([excl_idxs, model['idxs_valid']]).astype(\n                    np.uint\n                )\n\n            n_data = dataset['F'].shape[0]\n            n_data_eff = n_data - len(excl_idxs)\n\n            if (\n                n_test == 0 and n_data_eff != 0\n            ):  # test on all data points that have not been used for training or testing\n                n_test = n_data_eff\n                log.info(\n                    'Test set size was automatically set to {:,} points.'.format(n_test)\n                )\n\n            if n_test == 0 or n_data_eff == 0:\n                log.warning('Skipping! No unused points for test in provided dataset.')\n                return\n            elif n_data_eff < n_test:\n                n_test = n_data_eff\n                log.warning(\n                    'Test size reduced to {:d}. 
Not enough unused points in provided dataset.'.format(\n                        n_test\n                    )\n                )\n\n            if 'E' in dataset:\n                if gdml_train is None:\n                    gdml_train = GDMLTrain(\n                        max_memory=max_memory, max_processes=max_processes\n                    )\n                test_idxs = gdml_train.draw_strat_sample(\n                    dataset['E'], n_test, excl_idxs=excl_idxs\n                )\n            else:\n                test_idxs = np.delete(np.arange(n_data), excl_idxs)\n\n                log.warning(\n                    'Test dataset will be sampled with no guidance from energy labels (randomly)!\\n'\n                    + 'Note: Larger test datasets are recommended due to slower convergence of the error.'\n                )\n        # shuffle to improve convergence of online error\n        np.random.shuffle(test_idxs)\n\n        # NEW\n        if DEBUG_WRITE:\n            test_idxs = np.sort(test_idxs)\n\n        z = dataset['z']\n        R = dataset['R'][test_idxs, :, :]\n        F = dataset['F'][test_idxs, :, :]\n\n        if model['use_E']:\n            E = dataset['E'][test_idxs]\n\n        try:\n            gdml_predict = GDMLPredict(\n                model,\n                max_memory=max_memory,\n                max_processes=max_processes,\n                use_torch=use_torch,\n            )\n        except:\n            print()\n            log.critical(traceback.format_exc())\n            print()\n            os._exit(1)\n\n        b_size = min(1000, len(test_idxs))\n\n        if not use_torch:\n            if num_workers == -1 or batch_size == -1:\n                ui.callback(NOT_DONE, disp_str='Optimizing parallelism')\n\n                gps, is_from_cache = gdml_predict.prepare_parallel(\n                    n_bulk=b_size, return_is_from_cache=True\n                )\n                num_workers, chunk_size, bulk_mp = (\n                    
gdml_predict.num_workers,\n                    gdml_predict.chunk_size,\n                    gdml_predict.bulk_mp,\n                )\n\n                sec_disp_str = 'no chunking'.format(chunk_size)\n                if chunk_size != gdml_predict.n_train:\n                    sec_disp_str = 'chunks of {:d}'.format(chunk_size)\n\n                if num_workers == 0:\n                    sec_disp_str = 'no workers / ' + sec_disp_str\n                else:\n                    sec_disp_str = (\n                        '{:d} workers {}/ '.format(\n                            num_workers, '[MP] ' if bulk_mp else ''\n                        )\n                        + sec_disp_str\n                    )\n\n                ui.callback(\n                    DONE,\n                    disp_str='Optimizing parallelism'\n                    + (' (from cache)' if is_from_cache else ''),\n                    sec_disp_str=sec_disp_str,\n                )\n            else:\n                gdml_predict._set_num_workers(num_workers)\n                gdml_predict._set_chunk_size(chunk_size)\n                gdml_predict._set_bulk_mp(bulk_mp)\n\n        n_atoms = z.shape[0]\n\n        if model['use_E']:\n            e_mae_sum, e_rmse_sum = 0, 0\n        f_mae_sum, f_rmse_sum = 0, 0\n        cos_mae_sum, cos_rmse_sum = 0, 0\n        mag_mae_sum, mag_rmse_sum = 0, 0\n\n        n_done = 0\n        t = time.time()\n        for b_range in _batch(list(range(len(test_idxs))), b_size):\n\n            n_done_step = len(b_range)\n            n_done += n_done_step\n\n            r = R[b_range].reshape(n_done_step, -1)\n            e_pred, f_pred = gdml_predict.predict(r)\n\n            # energy error\n            if model['use_E']:\n                e = E[b_range]\n                e_mae, e_mae_sum, e_rmse, e_rmse_sum = _online_err(\n                    np.squeeze(e) - e_pred, 1, n_done, e_mae_sum, e_rmse_sum\n                )\n\n                # import matplotlib.pyplot as plt\n            
    # plt.hist(np.squeeze(e) - e_pred)\n                # plt.show()\n\n            # force component error\n            f = F[b_range].reshape(n_done_step, -1)\n            f_mae, f_mae_sum, f_rmse, f_rmse_sum = _online_err(\n                f - f_pred, 3 * n_atoms, n_done, f_mae_sum, f_rmse_sum\n            )\n\n            # magnitude error\n            f_pred_mags = np.linalg.norm(f_pred.reshape(-1, 3), axis=1)\n            f_mags = np.linalg.norm(f.reshape(-1, 3), axis=1)\n            mag_mae, mag_mae_sum, mag_rmse, mag_rmse_sum = _online_err(\n                f_pred_mags - f_mags, n_atoms, n_done, mag_mae_sum, mag_rmse_sum\n            )\n\n            # normalized cosine error\n            f_pred_norm = f_pred.reshape(-1, 3) / f_pred_mags[:, None]\n            f_norm = f.reshape(-1, 3) / f_mags[:, None]\n            cos_err = (\n                np.arccos(np.clip(np.einsum('ij,ij->i', f_pred_norm, f_norm), -1, 1))\n                / np.pi\n            )\n            cos_mae, cos_mae_sum, cos_rmse, cos_rmse_sum = _online_err(\n                cos_err, n_atoms, n_done, cos_mae_sum, cos_rmse_sum\n            )\n\n            # NEW\n\n            if is_test and DEBUG_WRITE:\n\n                try:\n                    with open('test_pred.xyz', 'a') as file:\n\n                        n = r.shape[0]\n                        for i, ri in enumerate(r):\n\n                            r_out = ri.reshape(-1, 3)\n                            e_out = e_pred[i]\n                            f_out = f_pred[i].reshape(-1, 3)\n\n                            ext_xyz_str = (\n                                io.generate_xyz_str(r_out, model['z'], e=e_out, f=f_out)\n                                + '\\n'\n                            )\n\n                            file.write(ext_xyz_str)\n\n                except IOError:\n                    sys.exit(\"ERROR: Writing xyz file failed.\")\n\n                try:\n                    with open('test_ref.xyz', 'a') as file:\n\n     
                   n = r.shape[0]\n                        for i, ri in enumerate(r):\n\n                            r_out = ri.reshape(-1, 3)\n                            e_out = (\n                                None\n                                if not model['use_E']\n                                else np.squeeze(E[b_range][i])\n                            )\n                            f_out = f[i].reshape(-1, 3)\n\n                            ext_xyz_str = (\n                                io.generate_xyz_str(r_out, model['z'], e=e_out, f=f_out)\n                                + '\\n'\n                            )\n                            file.write(ext_xyz_str)\n\n                except IOError:\n                    sys.exit(\"ERROR: Writing xyz file failed.\")\n\n                try:\n                    with open('test_diff.xyz', 'a') as file:\n\n                        n = r.shape[0]\n                        for i, ri in enumerate(r):\n\n                            r_out = ri.reshape(-1, 3)\n                            e_out = (\n                                None\n                                if not model['use_E']\n                                else (np.squeeze(E[b_range][i]) - e_pred[i])\n                            )\n                            f_out = (f[i] - f_pred[i]).reshape(-1, 3)\n\n                            ext_xyz_str = (\n                                io.generate_xyz_str(r_out, model['z'], e=e_out, f=f_out)\n                                + '\\n'\n                            )\n                            file.write(ext_xyz_str)\n\n                except IOError:\n                    sys.exit(\"ERROR: Writing xyz file failed.\")\n\n            # NEW\n\n            sps = n_done / (time.time() - t)  # examples per second\n            disp_str = 'energy %.3f/%.3f, ' % (e_mae, e_rmse) if model['use_E'] else ''\n            disp_str += 'forces %.3f/%.3f' % (f_mae, f_rmse)\n            disp_str = (\n                '{} 
errors (MAE/RMSE): '.format('Test' if is_test else 'Validation')\n                + disp_str\n            )\n            sec_disp_str = '@ %.1f geo/s' % sps if b_range is not None else ''\n\n            ui.callback(\n                n_done,\n                len(test_idxs),\n                disp_str=disp_str,\n                sec_disp_str=sec_disp_str,\n                newline_when_done=False,\n            )\n\n        if is_test:\n            ui.callback(\n                DONE,\n                disp_str='Testing on {:,} points'.format(n_test),\n                sec_disp_str=sec_disp_str,\n            )\n        else:\n            ui.callback(DONE, disp_str=disp_str, sec_disp_str=sec_disp_str)\n\n        if model['use_E']:\n            e_rmse_pct = (e_rmse / e_err['rmse'] - 1.0) * 100\n        f_rmse_pct = (f_rmse / f_err['rmse'] - 1.0) * 100\n\n        if is_test and n_models == 1:\n            n_train = len(model['idxs_train'])\n            n_valid = len(model['idxs_valid'])\n            print()\n            ui.print_two_column_str(\n                ui.color_str('Test errors (MAE/RMSE)', bold=True),\n                '{:,} + {:,} points (training + validation), sigma (length scale): {}'.format(\n                    n_train, n_valid, model['sig']\n                ),\n            )\n\n            r_unit = 'unknown unit'\n            e_unit = 'unknown unit'\n            f_unit = 'unknown unit'\n            if 'r_unit' in dataset and 'e_unit' in dataset:\n                r_unit = dataset['r_unit']\n                e_unit = dataset['e_unit']\n                f_unit = str(dataset['e_unit']) + '/' + str(dataset['r_unit'])\n\n            format_str = '  {:<18} {:>.4f}/{:>.4f} [{}]'\n            if model['use_E']:\n                ui.print_two_column_str(\n                    format_str.format('Energy:', e_mae, e_rmse, e_unit),\n                    'relative to expected: {:+.1f}%'.format(e_rmse_pct),\n                )\n\n            ui.print_two_column_str(\n                
format_str.format('Forces:', f_mae, f_rmse, f_unit),\n                'relative to expected: {:+.1f}%'.format(f_rmse_pct),\n            )\n\n            print(format_str.format('  Magnitude:', mag_mae, mag_rmse, f_unit))\n            ui.print_two_column_str(\n                format_str.format('  Angle:', cos_mae, cos_rmse, '0-1'),\n                'lower is better',\n            )\n            print()\n\n        model_mutable = dict(model)\n        model.close()\n        model = model_mutable\n\n        model_needs_update = (\n            overwrite\n            or (is_test and model['n_test'] < len(test_idxs))\n            or (is_validation and not is_model_validated)\n        )\n        if model_needs_update:\n\n            if is_validation and overwrite:\n                model['n_test'] = 0  # flag the model as not tested\n\n            if is_test:\n                model['n_test'] = len(test_idxs)\n                model['md5_test'] = dataset['md5']\n\n            if model['use_E']:\n                model['e_err'] = {\n                    'mae': e_mae.item(),\n                    'rmse': e_rmse.item(),\n                }\n\n            model['f_err'] = {'mae': f_mae.item(), 'rmse': f_rmse.item()}\n            np.savez_compressed(model_path, **model)\n\n            if is_test and model['n_test'] > 0:\n                log.info('Expected errors were updated in model file.')\n\n        else:\n            add_info_str = (\n                'the same number of'\n                if model['n_test'] == len(test_idxs)\n                else 'only {:,}'.format(len(test_idxs))\n            )\n            log.warning(\n                'This model has previously been tested on {:,} points, which is why the errors for the current test run with {} points have NOT been used to update the model file.\\n'.format(\n                    model['n_test'], add_info_str\n                )\n                + 'Run \\'{} test -o {} {} {}\\' to overwrite.'.format(\n                    
PACKAGE_NAME, os.path.relpath(model_path), dataset_path, n_test\n                )\n            )\n\n        F_rmse.append(f_rmse)\n\n    return F_rmse\n\n\ndef select(model_dir, overwrite, model_file=None, command=None, **kwargs):  # noqa: C901\n\n    func_called_directly = (\n        command == 'select'\n    )  # has this function been called from command line or from 'all'?\n    if func_called_directly:\n        ui.print_step_title('MODEL SELECTION')\n\n    any_model_not_validated = False\n    any_model_is_tested = False\n\n    model_dir, model_file_names = model_dir\n    if len(model_file_names) > 1:\n\n        use_E = True\n\n        rows = []\n        data_names = ['sig', 'MAE', 'RMSE', 'MAE', 'RMSE']\n        for i, model_file_name in enumerate(model_file_names):\n            model_path = os.path.join(model_dir, model_file_name)\n            _, model = io.is_file_type(model_path, 'model')\n\n            use_E = model['use_E']\n\n            if i == 0:\n                idxs_train = set(model['idxs_train'])\n                md5_train = model['md5_train']\n                idxs_valid = set(model['idxs_valid'])\n                md5_valid = model['md5_valid']\n            else:\n                if (\n                    md5_train != model['md5_train']\n                    or md5_valid != model['md5_valid']\n                    or idxs_train != set(model['idxs_train'])\n                    or idxs_valid != set(model['idxs_valid'])\n                ):\n                    raise AssistantError(\n                        '{} contains models trained or validated on different datasets.'.format(\n                            model_dir\n                        )\n                    )\n\n            e_err = {'mae': 0.0, 'rmse': 0.0}\n            if model['use_E']:\n                e_err = model['e_err'].item()\n            f_err = model['f_err'].item()\n\n            is_model_validated = not (np.isnan(f_err['mae']) or np.isnan(f_err['rmse']))\n            if not 
is_model_validated:\n                any_model_not_validated = True\n\n            is_model_tested = model['n_test'] > 0\n            if is_model_tested:\n                any_model_is_tested = True\n\n            rows.append(\n                [model['sig'], e_err['mae'], e_err['rmse'], f_err['mae'], f_err['rmse']]\n            )\n\n            model.close()\n\n        if any_model_not_validated:\n            log.warning(\n                'One or more models in the given directory have not been validated.'\n            )\n            print()\n\n        if any_model_is_tested:\n            log.error(\n                'One or more models in the given directory have already been tested. This means that their recorded expected errors are test errors, not validation errors. However, one should never perform model selection based on the test error!\\n'\n                + 'Please run the validation command (again) with the overwrite option \\'-o\\', then this selection command.'\n            )\n            return\n\n        f_rmse_col = [row[4] for row in rows]\n        best_idx = f_rmse_col.index(min(f_rmse_col))  # idx of row with lowest f_rmse\n        best_sig = rows[best_idx][0]\n\n        rows = sorted(rows, key=lambda col: col[0])  # sort according to sigma\n        print(ui.color_str('Cross-validation errors', bold=True))\n        print(' ' * 7 + 'Energy' + ' ' * 6 + 'Forces')\n        print((' {:>3} ' + '{:>5} ' * 4).format(*data_names))\n        print(' ' + '-' * 27)\n        format_str = ' {:>3} ' + '{:5.2f} ' * 4\n        format_str_no_E = ' {:>3}     -     - ' + '{:5.2f} ' * 2\n        for row in rows:\n            if use_E:\n                row_str = format_str.format(*row)\n            else:\n                row_str = format_str_no_E.format(*[row[0], row[3], row[4]])\n\n            if row[0] != best_sig:\n                row_str = ui.color_str(row_str, fore_color=ui.GRAY)\n            print(row_str)\n        print()\n\n        sig_col = [row[0] for row in 
rows]\n        if best_sig == min(sig_col) or best_sig == max(sig_col):\n            log.warning(\n                'The optimal sigma (length scale) lies on the boundary of the search grid.\\n'\n                + 'Model performance might improve if the search grid is extended in direction sigma {} {:d}.'.format(\n                    '<' if best_idx == 0 else '>', best_sig\n                )\n            )\n\n    else:  # only one model available\n        log.info('Skipping model selection step as there is only one model to select.')\n\n        best_idx = 0\n\n    best_model_path = os.path.join(model_dir, model_file_names[best_idx])\n\n    if model_file is None:\n\n        # generate model file name based on model properties\n        best_model = np.load(best_model_path, allow_pickle=True)\n        model_file = io.model_file_name(best_model, is_extended=True)\n        best_model.close()\n\n    model_exists = os.path.isfile(model_file)\n    if model_exists and overwrite:\n        log.info('Overwriting existing model file.')\n\n    if not model_exists or overwrite:\n        if func_called_directly:\n            log.done('Writing model file \\'{}\\''.format(model_file))\n\n        shutil.copy(best_model_path, model_file)\n        shutil.rmtree(model_dir, ignore_errors=True)\n    else:\n        log.warning(\n            'Model \\'{}\\' already exists.\\n'.format(model_file)\n            + 'Run \\'{} select -o {}\\' to overwrite.'.format(\n                PACKAGE_NAME, os.path.relpath(model_dir)\n            )\n        )\n\n    if func_called_directly:\n        _print_next_step('select', model_files=[model_file])\n\n    return model_file\n\n\ndef show(file, command=None, **kwargs):\n\n    ui.print_step_title('SHOW DETAILS')\n    file_path, file = file\n\n    if file['type'].astype(str) == 'd':\n        _print_dataset_properties(file)\n\n    if file['type'].astype(str) == 't':\n        _print_task_properties(file)\n\n    if file['type'].astype(str) == 'm':\n        
_print_model_properties(file)\n\n\ndef reset(command=None, **kwargs):\n\n    if ui.yes_or_no('\\nDo you really want to purge all caches and temporary files?'):\n\n        pkg_dir = os.path.dirname(os.path.abspath(__file__))\n        bmark_file = '_bmark_cache.npz'\n        bmark_path = os.path.join(pkg_dir, bmark_file)\n\n        if os.path.exists(bmark_path):\n            try:\n                os.remove(bmark_path)\n            except OSError:\n                print()\n                log.critical('Exception: unable to delete benchmark cache.')\n                print()\n                os._exit(1)\n\n            log.done('Benchmark cache deleted.')\n        else:\n            log.info('Benchmark cache was already empty.')\n    else:\n        print(' Cancelled.')\n\n\ndef main():\n    def _add_argument_sample_size(parser, subset_str):\n        subparser.add_argument(\n            'n_%s' % subset_str,\n            metavar='<n_%s>' % subset_str,\n            type=io.is_strict_pos_int,\n            help='%s sample size' % subset_str,\n        )\n\n    def _add_argument_dir_with_file_type(parser, type, or_file=False):\n        parser.add_argument(\n            '%s_dir' % type,\n            metavar='<%s_dir%s>' % (type, '_or_file' if or_file else ''),\n            type=lambda x: io.is_dir_with_file_type(x, type, or_file=or_file),\n            help='path to %s directory%s' % (type, ' or file' if or_file else ''),\n        )\n\n    # Available resources\n    total_memory = psutil.virtual_memory().total // 2**30\n    total_cpus = mp.cpu_count()\n\n    parser = argparse.ArgumentParser()\n    parser.add_argument(\n        '--version',\n        action='version',\n        version='%(prog)s '\n        + __version__\n        + ' [Python {}, NumPy {}, SciPy {}'.format(\n            '.'.join(map(str, sys.version_info[:3])), np.__version__, sp.__version__\n        )\n        + ', PyTorch {}'.format(torch.__version__ if _has_torch else 'N/A')\n        + ', ASE 
{}'.format(ase.__version__ if _has_ase else 'N/A')\n        + ']',\n    )\n\n    parent_parser = argparse.ArgumentParser(add_help=False)\n\n    subparsers = parser.add_subparsers(title='commands', dest='command')\n    subparsers.required = True\n    parser_all = subparsers.add_parser(\n        'all',\n        help='reconstruct a force field from beginning to end',\n        parents=[parent_parser],\n    )\n    parser_create = subparsers.add_parser(\n        'create', help='create training task(s)', parents=[parent_parser]\n    )\n    parser_train = subparsers.add_parser(\n        'train', help='train model(s) from task(s)', parents=[parent_parser]\n    )\n    parser_resume = subparsers.add_parser(\n        'resume', help='resume training of a model', parents=[parent_parser]\n    )\n    parser_valid = subparsers.add_parser(\n        'validate', help='validate model(s)', parents=[parent_parser]\n    )\n    parser_select = subparsers.add_parser(\n        'select', help='select best performing model', parents=[parent_parser]\n    )\n    parser_test = subparsers.add_parser(\n        'test', help='test a model', parents=[parent_parser]\n    )\n    parser_show = subparsers.add_parser(\n        'show',\n        help='print details about a dataset, task or model file',\n        parents=[parent_parser],\n    )\n    subparsers.add_parser(\n        'reset', help='delete all caches and temporary files', parents=[parent_parser]\n    )\n\n    for subparser in [parser_all, parser_create]:\n\n        subparser.add_argument(\n            'dataset',\n            metavar='<dataset_file>',\n            type=lambda x: io.is_file_type(x, 'dataset'),\n            help='path to dataset file (train/validation/test subsets are sampled from here if no separate datasets are specified)',\n        )\n\n        _add_argument_sample_size(subparser, 'train')\n        _add_argument_sample_size(subparser, 'valid')\n        subparser.add_argument(\n            '-v',\n            
'--validation_dataset',\n            metavar='<valid_dataset_file>',\n            dest='valid_dataset',\n            type=lambda x: io.is_file_type(x, 'dataset'),\n            help='path to separate validation dataset file',\n        )\n        subparser.add_argument(\n            '-t',\n            '--test_dataset',\n            metavar='<test_dataset_file>',\n            dest='test_dataset',\n            type=lambda x: io.is_file_type(x, 'dataset'),\n            help='path to separate test dataset file',\n        )\n        subparser.add_argument(\n            '-s',\n            '--sig',\n            metavar=('<s1>', '<s2>'),\n            dest='sigs',\n            type=io.parse_list_or_range,\n            help='integer list and/or range <start>:[<step>:]<stop> for the kernel hyper-parameter sigma (length scale)',\n            nargs='+',\n        )\n\n        group = subparser.add_mutually_exclusive_group()\n        group.add_argument(\n            '--gdml',\n            action='store_true',\n            help='don\\'t include symmetries in the model (GDML)',\n        )\n\n        group.add_argument(\n            '--perms_from',\n            metavar='<file>',\n            dest='perms_from_arg',\n            type=lambda x: io.is_valid_file_type(x),\n            help='path to file to take permutations from (key: \\'perms\\')',\n        )\n\n        group = subparser.add_mutually_exclusive_group()\n        group.add_argument(\n            '--no_E',\n            dest='use_E',\n            action='store_false',\n            help='only reconstruct force field w/o potential energy surface',\n        )\n        group.add_argument(\n            '--E_cstr',\n            dest='use_E_cstr',\n            action='store_true',\n            help='include pointwise energy constraints',\n        )\n\n        subparser.add_argument(\n            '--task_dir',\n            metavar='<task_dir>',\n            dest='task_dir',\n            help='user-defined task output dir name',\n      
  )\n\n    for subparser in [parser_all, parser_select]:\n        subparser.add_argument(\n            '--model_file',\n            metavar='<model_file>',\n            dest='model_file',\n            help='user-defined model output file name',\n        )\n\n    for subparser in [parser_all, parser_train]:\n        subparser.add_argument(\n            '--lazy',\n            dest='lazy_training',\n            action='store_true',\n            help='give up on unfinished tasks (if more than one)',\n        )\n\n    for subparser in [parser_valid, parser_test]:\n        _add_argument_dir_with_file_type(subparser, 'model', or_file=True)\n\n    parser_valid.add_argument(\n        'valid_dataset',\n        metavar='<valid_dataset_file>',\n        type=lambda x: io.is_file_type(x, 'dataset'),\n        help='path to validation dataset file',\n    )\n    parser_test.add_argument(\n        'test_dataset',\n        metavar='<test_dataset_file>',\n        type=lambda x: io.is_file_type(x, 'dataset'),\n        help='path to test dataset file',\n    )\n\n    for subparser in [parser_all, parser_test]:\n        subparser.add_argument(\n            'n_test',\n            metavar='<n_test>',\n            type=io.is_strict_pos_int,\n            help='test sample size',\n            nargs='?',\n            default=None,\n        )\n\n    parser_resume.add_argument(\n        'model',\n        metavar='<model_file>',\n        type=lambda x: io.is_file_type(x, 'model'),\n        help='path to model file to complete training for',\n    )\n    parser_resume.add_argument(\n        'dataset',\n        metavar='<train_dataset_file>',\n        type=lambda x: io.is_file_type(x, 'dataset'),\n        help='path to original training dataset file',\n    )\n\n    _add_argument_dir_with_file_type(parser_train, 'task', or_file=True)\n\n    for subparser in [parser_train, parser_resume]:\n        subparser.add_argument(\n            'valid_dataset',\n            metavar='<valid_dataset_file>',\n       
     type=lambda x: io.is_file_type(x, 'dataset'),\n            help='path to validation dataset file',\n        )\n\n    _add_argument_dir_with_file_type(parser_select, 'model')\n\n    parser_show.add_argument(\n        'file',\n        metavar='<file>',\n        type=lambda x: io.is_valid_file_type(x),\n        help='path to dataset, task or model file',\n    )\n\n    for subparser in [\n        parser_all,\n        parser_train,\n        parser_resume,\n        parser_valid,\n        parser_test,\n    ]:\n\n        subparser.add_argument(\n            '-m',\n            '--max_memory',\n            metavar='<max_memory>',\n            type=int,\n            help='limit memory usage (whenever possible) [GB]',\n            choices=range(1, total_memory + 1),\n            default=total_memory,\n        )\n\n        subparser.add_argument(\n            '-p',\n            '--max_processes',\n            metavar='<max_processes>',\n            type=int,\n            help='limit number of processes',\n            choices=range(1, total_cpus + 1),\n            default=total_cpus,\n        )\n\n        subparser.add_argument(\n            '--cpu',\n            dest='use_torch',\n            action='store_false',\n            help='use CPU implementation (no PyTorch dependency)',\n        )\n\n    for subparser in [\n        parser_all,\n        parser_create,\n        parser_train,\n        parser_resume,\n        parser_valid,\n        parser_select,\n        parser_test,\n    ]:\n        subparser.add_argument(\n            '-o',\n            '--overwrite',\n            dest='overwrite',\n            action='store_true',\n            help='overwrite existing files',\n        )\n\n    args = parser.parse_args()\n\n    # Post-processing for optional sig argument\n    if 'sigs' in args and args.sigs is not None:\n        args.sigs = np.hstack(\n            args.sigs\n        ).tolist()  # Flatten list, if (part of it) was generated using the range syntax\n        
args.sigs = sorted(list(set(args.sigs)))  # remove potential duplicates\n\n    # Post-processing for optional model output file argument\n    if 'model_file' in args and args.model_file is not None:\n        if not args.model_file.endswith('.npz'):\n            args.model_file += '.npz'\n\n    # Check PyTorch GPU support.\n    if ('use_torch' in args and args.use_torch) or 'use_torch' not in args:\n        if _has_torch:\n            if not (_torch_cuda_is_available or _torch_mps_is_available):\n                print()  # TODO: print only if log level includes warning\n                log.warning(\n                    'Your PyTorch installation does not see any GPU(s) on your system and will thus run all calculations on the CPU! If this is what you want, we recommend bypassing PyTorch using \\'--cpu\\' for improved performance.'\n                )\n        else:\n            print()\n            log.critical(\n                'PyTorch dependency not found! Please install it or use \\'--cpu\\' to bypass PyTorch and run everything on the CPU.'\n            )\n            print()\n            os._exit(1)\n\n    args = vars(args)\n\n    _print_splash(\n        args['max_memory'] if 'max_memory' in args else total_memory,\n        args['max_processes'] if 'max_processes' in args else total_cpus,\n        args['use_torch'] if 'use_torch' in args else True,\n    )\n\n    try:\n        getattr(sys.modules[__name__], args['command'])(**args)\n    except AssistantError as err:\n        log.error(str(err))\n        print()\n        os._exit(1)\n    except:\n        log.critical(traceback.format_exc())\n        print()\n        os._exit(1)\n    print()\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "sgdml/get.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2023 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport argparse\nimport os\nimport re\nimport sys\n\nfrom . 
import __version__\nfrom .utils import ui\n\nif sys.version[0] == '3':\n    raw_input = input\n\ntry:\n    from urllib.request import urlopen\nexcept ImportError:\n    from urllib2 import urlopen\n\n\ndef download(command, file_name):\n\n    base_url = 'http://www.quantum-machine.org/gdml/' + (\n        'data/npz/' if command == 'dataset' else 'models/'\n    )\n    request = urlopen(base_url + file_name)\n    file = open(file_name, 'wb')\n    filesize = int(request.headers['Content-Length'])\n\n    size = 0\n    block_sz = 1024\n    while True:\n        buffer = request.read(block_sz)\n        if not buffer:\n            break\n        size += len(buffer)\n        file.write(buffer)\n\n        ui.callback(\n            size,\n            filesize,\n            disp_str='Downloading: {}'.format(file_name),\n            sec_disp_str='{:,} bytes'.format(filesize),\n        )\n    file.close()\n\n\ndef main():\n\n    base_url = 'http://www.quantum-machine.org/gdml/'\n\n    parser = argparse.ArgumentParser()\n\n    parent_parser = argparse.ArgumentParser(add_help=False)\n    parent_parser.add_argument(\n        '-o',\n        '--overwrite',\n        dest='overwrite',\n        action='store_true',\n        help='overwrite existing files',\n    )\n\n    subparsers = parser.add_subparsers(title='commands', dest='command')\n    subparsers.required = True\n    parser_dataset = subparsers.add_parser(\n        'dataset', help='download benchmark dataset', parents=[parent_parser]\n    )\n    parser_model = subparsers.add_parser(\n        'model', help='download pre-trained model', parents=[parent_parser]\n    )\n\n    for subparser in [parser_dataset, parser_model]:\n        subparser.add_argument(\n            'name',\n            metavar='<name>',\n            type=str,\n            help='item name',\n            nargs='?',\n            default=None,\n        )\n\n    args = parser.parse_args()\n\n    print(\"Contacting server (%s)...\" % base_url)\n\n    if args.name is not 
None:\n\n        url = '%sget.php?version=%s&%s=%s' % (\n            base_url,\n            __version__,\n            args.command,\n            args.name,\n        )\n        response = urlopen(url)\n        match, score = response.read().decode().split(',')\n        response.close()\n\n        if int(score) == 0 or ui.yes_or_no('Do you mean \\'%s\\'?' % match):\n            download(args.command, match + '.npz')\n            return\n\n    response = urlopen(\n        '%sget.php?version=%s&%s' % (base_url, __version__, args.command)\n    )\n    line = response.readlines()\n    response.close()\n\n    print()\n    print('Available %ss:' % args.command)\n\n    print('{:<2} {:<31}    {:>4}'.format('ID', 'Name', 'Size'))\n    print('-' * 42)\n\n    items = line[0].split(b';')\n    for i, item in enumerate(items):\n        name, size = item.split(b',')\n        size = int(size) / 1024**2  # Bytes to MBytes\n\n        print('{:>2d} {:<30} {:>5.1f} MB'.format(i, name.decode(\"utf-8\"), size))\n    print()\n\n    down_list = raw_input(\n        'Please list which %ss to download (e.g. 
0 1 2 6) or type \\'all\\': '\n        % args.command\n    )\n    down_idxs = []\n    if 'all' in down_list.lower():\n        down_idxs = list(range(len(items)))\n    elif re.match(\n        \"^ *[0-9][0-9 ]*$\", down_list\n    ):  # only digits and spaces, at least one digit\n        down_idxs = [int(idx) for idx in re.split(r'\\s+', down_list.strip())]\n        down_idxs = list(set(down_idxs))\n    else:\n        print(ui.color_str('ABORTED', fore_color=ui.RED, bold=True))\n\n    for idx in down_idxs:\n        if idx not in range(len(items)):\n            print(\n                ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True)\n                + ' Index '\n                + str(idx)\n                + ' out of range, skipping.'\n            )\n        else:\n            name = items[idx].split(b',')[0].decode(\"utf-8\")\n\n            # The file is stored as '<name>.npz', so check for that path.\n            file_name = name + '.npz'\n            if os.path.exists(file_name) and not args.overwrite:\n                print(\"'%s' exists, skipping.\" % file_name)\n                continue\n\n            download(args.command, file_name)\n\n\nif __name__ == \"__main__\":\n    main()\n"
  },
  {
    "path": "sgdml/intf/__init__.py",
    "content": ""
  },
  {
    "path": "sgdml/intf/ase_calc.py",
    "content": "# MIT License\n#\n# Copyright (c) 2018-2020 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nimport logging\nimport numpy as np\n\ntry:\n    from ase.calculators.calculator import Calculator\n    from ase.units import kcal, mol\nexcept ImportError:\n    raise ImportError(\n        'Optional ASE dependency not found! 
Please run \\'pip install sgdml[ase]\\' to install it.'\n    )\n\nfrom ..predict import GDMLPredict\n\n\nclass SGDMLCalculator(Calculator):\n\n    implemented_properties = ['energy', 'forces']\n\n    def __init__(\n        self,\n        model_path,\n        E_to_eV=kcal / mol,\n        F_to_eV_Ang=kcal / mol,\n        use_torch=False,\n        *args,\n        **kwargs\n    ):\n        \"\"\"\n        ASE calculator for the sGDML force field.\n\n        A calculator takes atomic numbers and atomic positions from an Atoms object and calculates the energy and forces.\n\n        Note\n        ----\n        ASE uses eV and Angstrom as energy and length unit, respectively. Unless the paramerters `E_to_eV` and `F_to_eV_Ang` are specified, the sGDML model is assumed to use kcal/mol and Angstorm and the appropriate conversion factors are set accordingly.\n        Here is how to find them: `ASE units <https://wiki.fysik.dtu.dk/ase/ase/units.html>`_.\n\n        Parameters\n        ----------\n                model_path : :obj:`str`\n                        Path to a sGDML model file\n                E_to_eV : float, optional\n                        Conversion factor from whatever energy unit is used by the model to eV. By default this parameter is set to convert from kcal/mol.\n                F_to_eV_Ang : float, optional\n                        Conversion factor from whatever length unit is used by the model to Angstrom. 
By default, the length unit is not converted (assumed to be in Angstrom)\n                use_torch : boolean, optional\n                        Use PyTorch to calculate predictions\n        \"\"\"\n\n        super(SGDMLCalculator, self).__init__(*args, **kwargs)\n\n        self.log = logging.getLogger(__name__)\n\n        model = np.load(model_path, allow_pickle=True)\n        self.gdml_predict = GDMLPredict(model, use_torch=use_torch)\n        self.gdml_predict.prepare_parallel(n_bulk=1)\n\n        self.log.warning(\n            'Please remember to specify the proper conversion factors, if your model does not use \\'kcal/mol\\' and \\'Ang\\' as units.'\n        )\n\n        # Converts energy from the unit used by the sGDML model to eV.\n        self.E_to_eV = E_to_eV\n\n        # Converts length from Ang to the unit used by the sGDML model.\n        self.Ang_to_R = F_to_eV_Ang / E_to_eV\n\n        # Converts force from the unit used by the sGDML model to eV/Ang.\n        self.F_to_eV_Ang = F_to_eV_Ang\n\n    def calculate(self, atoms=None, *args, **kwargs):\n\n        super(SGDMLCalculator, self).calculate(atoms, *args, **kwargs)\n\n        # convert positions from ASE default units (Ang) to model units\n        r = np.array(atoms.get_positions()) * self.Ang_to_R\n\n        e, f = self.gdml_predict.predict(r.ravel())\n\n        # convert model units to ASE default units (eV and Ang)\n        e *= self.E_to_eV\n        f *= self.F_to_eV_Ang\n\n        self.results = {'energy': e, 'forces': f.reshape(-1, 3)}\n"
  },
  {
    "path": "sgdml/predict.py",
    "content": "\"\"\"\nThis module contains all routines for evaluating GDML and sGDML models.\n\"\"\"\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela, Gregory Fonseca\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport sys\nimport logging\nimport os\nimport psutil\n\nimport multiprocessing as mp\n\nPool = mp.get_context('fork').Pool\n\nimport timeit\nfrom functools import partial\n\ntry:\n    import torch\nexcept ImportError:\n    _has_torch = False\nelse:\n    _has_torch = True\n\ntry:\n    _torch_mps_is_available = torch.backends.mps.is_available()\nexcept AttributeError:\n    _torch_mps_is_available = False\n_torch_mps_is_available = False\n\ntry:\n    _torch_cuda_is_available = torch.cuda.is_available()\nexcept AttributeError:\n    _torch_cuda_is_available = False\n\nimport numpy as np\n\nfrom . 
import __version__\nfrom .utils.desc import Desc\n\n\ndef share_array(arr_np):\n    \"\"\"\n    Return a ctypes array allocated from shared memory with data from a\n    NumPy array of type `float`.\n\n    Parameters\n    ----------\n            arr_np : :obj:`numpy.ndarray`\n                    NumPy array.\n\n    Returns\n    -------\n            array of :obj:`ctype`\n    \"\"\"\n\n    arr = mp.RawArray('d', arr_np.ravel())\n    return arr, arr_np.shape\n\n\ndef _predict_wkr(\n    r, r_desc_d_desc, lat_and_inv, glob_id, wkr_start_stop=None, chunk_size=None\n):\n    \"\"\"\n    Compute (part of) a prediction.\n\n    Every prediction is a linear combination involving the training points used for\n    this model. This function evaluates that combination for the range specified by\n    `wkr_start_stop`. This workload can optionally be processed in chunks,\n    which can be faster as it requires less memory to be allocated.\n\n    Note\n    ----\n        It is sufficient to provide either the parameter `r` or `r_desc_d_desc`.\n        The other one can be set to `None`.\n\n    Parameters\n    ----------\n            r : :obj:`numpy.ndarray`\n                    An array of size 3N containing the Cartesian\n                    coordinates of each atom in the molecule.\n            r_desc_d_desc : tuple of :obj:`numpy.ndarray`\n                    A tuple made up of:\n                        (1) An array of size D containing the descriptors\n                        of dimension D for the molecule.\n                        (2) An array of size D x 3N containing the\n                        descriptor Jacobian for the molecule. 
It has dimension\n                        D with 3N partial derivatives with respect to the 3N\n                        Cartesian coordinates of each atom.\n            lat_and_inv : tuple of :obj:`numpy.ndarray`\n                    Tuple of 3 x 3 matrix containing lattice vectors as columns and\n                    its inverse.\n            glob_id : int\n                    Identifier of the global namespace that this\n                    function is supposed to be using (zero if only one\n                    instance of this class exists at the same time).\n            wkr_start_stop : tuple of int, optional\n                    Range defined by the indices of first and last (exclusive)\n                    sum element. The full prediction is generated if this parameter\n                    is not specified.\n            chunk_size : int, optional\n                    Chunk size. The whole linear combination is evaluated in a large\n                    vector operation instead of looping over smaller chunks if this\n                    parameter is left unspecified.\n\n    Returns\n    -------\n            :obj:`numpy.ndarray`\n                    Partial prediction of all force components and\n                    energy (appended to array as last element).\n    \"\"\"\n\n    global globs\n    glob = globs[glob_id]\n    sig, n_perms = glob['sig'], glob['n_perms']\n\n    desc_func = glob['desc_func']\n\n    R_desc_perms = np.frombuffer(glob['R_desc_perms']).reshape(\n        glob['R_desc_perms_shape']\n    )\n    R_d_desc_alpha_perms = np.frombuffer(glob['R_d_desc_alpha_perms']).reshape(\n        glob['R_d_desc_alpha_perms_shape']\n    )\n\n    if 'alphas_E_lin' in glob:\n        alphas_E_lin = np.frombuffer(glob['alphas_E_lin']).reshape(\n            glob['alphas_E_lin_shape']\n        )\n\n    r_desc, r_d_desc = r_desc_d_desc or desc_func.from_R(\n        r, lat_and_inv, max_processes=1\n    )  # no additional forking during parallelization\n\n    n_train = 
int(R_desc_perms.shape[0] / n_perms)\n\n    wkr_start, wkr_stop = (0, n_train) if wkr_start_stop is None else wkr_start_stop\n    if chunk_size is None:\n        chunk_size = n_train\n\n    dim_d = desc_func.dim\n    dim_i = desc_func.dim_i\n    dim_c = chunk_size * n_perms\n\n    # Pre-allocate memory.\n    diff_ab_perms = np.empty((dim_c, dim_d))\n    a_x2 = np.empty((dim_c,))\n    mat52_base = np.empty((dim_c,))\n\n    # avoid divisions (slower)\n    sig_inv = 1.0 / sig\n    mat52_base_fact = 5.0 / (3 * sig**3)\n    diag_scale_fact = 5.0 / sig\n    sqrt5 = np.sqrt(5.0)\n\n    E_F = np.zeros((dim_d + 1,))\n    F = E_F[1:]\n\n    wkr_start *= n_perms\n    wkr_stop *= n_perms\n\n    b_start = wkr_start\n    for b_stop in list(range(wkr_start + dim_c, wkr_stop, dim_c)) + [wkr_stop]:\n\n        rj_desc_perms = R_desc_perms[b_start:b_stop, :]\n        rj_d_desc_alpha_perms = R_d_desc_alpha_perms[b_start:b_stop, :]\n\n        # Resize pre-allocated memory for last iteration, if chunk_size is not a divisor of the training set size.\n        # Note: It's faster to process equally sized chunks.\n        c_size = b_stop - b_start\n        if c_size < dim_c:\n            diff_ab_perms = diff_ab_perms[:c_size, :]\n            a_x2 = a_x2[:c_size]\n            mat52_base = mat52_base[:c_size]\n\n        np.subtract(\n            np.broadcast_to(r_desc, rj_desc_perms.shape),\n            rj_desc_perms,\n            out=diff_ab_perms,\n        )\n        norm_ab_perms = sqrt5 * np.linalg.norm(diff_ab_perms, axis=1)\n\n        np.exp(-norm_ab_perms * sig_inv, out=mat52_base)\n        mat52_base *= mat52_base_fact\n        np.einsum(\n            'ji,ji->j', diff_ab_perms, rj_d_desc_alpha_perms, out=a_x2\n        )  # colum wise dot product\n\n        F += (a_x2 * mat52_base).dot(diff_ab_perms) * diag_scale_fact\n        mat52_base *= norm_ab_perms + sig\n        F -= mat52_base.dot(rj_d_desc_alpha_perms)\n\n        # Note: Energies are automatically predicted with a flipped sign 
here (because -E are trained, instead of E)\n        E_F[0] += a_x2.dot(mat52_base)\n\n        # Note: Energies are automatically predicted with a flipped sign here (because -E are trained, instead of E)\n        if 'alphas_E_lin' in glob:\n\n            K_fe = diff_ab_perms * mat52_base[:, None]\n            F += alphas_E_lin[b_start:b_stop].dot(K_fe)\n\n            K_ee = (\n                1 + (norm_ab_perms * sig_inv) * (1 + norm_ab_perms / (3 * sig))\n            ) * np.exp(-norm_ab_perms * sig_inv)\n\n            E_F[0] += K_ee.dot(alphas_E_lin[b_start:b_stop])\n\n        b_start = b_stop\n\n    out = E_F[: dim_i + 1]\n\n    # Descriptor has less entries than 3N, need to extend size of the 'E_F' array.\n    if dim_d < dim_i:\n        out = np.empty((dim_i + 1,))\n        out[0] = E_F[0]\n\n    out[1:] = desc_func.vec_dot_d_desc(\n        r_d_desc,\n        F,\n    )  # 'r_d_desc.T.dot(F)' for our special representation of 'r_d_desc'\n\n    return out\n\n\nclass GDMLPredict(object):\n    def __init__(\n        self,\n        model,\n        batch_size=None,\n        num_workers=None,\n        max_memory=None,\n        max_processes=None,\n        use_torch=False,\n        log_level=None,\n    ):\n        \"\"\"\n        Query trained sGDML force fields.\n\n        This class is used to load a trained model and make energy and\n        force predictions for new geometries. GPU support is provided\n        through PyTorch (requires optional `torch` dependency to be\n        installed).\n\n        Note\n        ----\n                The parameters `batch_size` and `num_workers` are only\n                relevant if this code runs on a CPU. Both can be set\n                automatically via the function `prepare_parallel`.\n                Note: Running calculations via PyTorch is only\n                recommended with available GPU hardware. 
CPU calculations\n                are faster with our NumPy implementation.\n\n        Parameters\n        ----------\n                model : :obj:`dict`\n                        Data structure that holds all parameters of the\n                        trained model. This object is the output of\n                        `GDMLTrain.train`\n                batch_size : int, optional\n                        Chunk size for processing parallel tasks\n                num_workers : int, optional\n                        Number of parallel workers (in addition to the main\n                        process)\n                max_memory : int, optional\n                        Limit the max. memory usage [GB]. This is only a\n                        soft limit that cannot always be enforced.\n                max_processes : int, optional\n                        Limit the max. number of processes. Otherwise\n                        all CPU cores are used. This parameter has no\n                        effect if `use_torch=True`\n                use_torch : boolean, optional\n                        Use PyTorch to calculate predictions\n                log_level : optional\n                        Set custom logging level (e.g. `logging.CRITICAL`)\n        \"\"\"\n\n        global globs\n        if 'globs' not in globals():\n            globs = []\n\n        # Create a personal global space for this model at a new index\n        # Note: do not delete entries in this list, since 'self.glob_id' is\n        # static. 
Instead, setting them to None conserves positions while still\n        # freeing up memory.\n        globs.append({})\n        self.glob_id = len(globs) - 1\n        glob = globs[self.glob_id]\n\n        self.log = logging.getLogger(__name__)\n        if log_level is not None:\n            self.log.setLevel(log_level)\n\n        total_memory = psutil.virtual_memory().total // 2**30  # bytes to GB)\n        self.max_memory = (\n            min(max_memory, total_memory) if max_memory is not None else total_memory\n        )\n\n        total_cpus = mp.cpu_count()\n        self.max_processes = (\n            min(max_processes, total_cpus) if max_processes is not None else total_cpus\n        )\n\n        if 'type' not in model or not (model['type'] == 'm' or model['type'] == b'm'):\n            self.log.critical('The provided data structure is not a valid model.')\n            sys.exit()\n\n        self.n_atoms = model['z'].shape[0]\n\n        self.desc = Desc(self.n_atoms, max_processes=max_processes)\n        glob['desc_func'] = self.desc\n\n        # Cache for iterative training mode.\n        self.R_desc = None\n        self.R_d_desc = None\n\n        self.lat_and_inv = (\n            (model['lattice'], np.linalg.inv(model['lattice']))\n            if 'lattice' in model\n            else None\n        )\n\n        self.n_train = model['R_desc'].shape[1]\n        glob['sig'] = model['sig']\n\n        self.std = model['std'] if 'std' in model else 1.0\n        self.c = model['c']\n\n        n_perms = model['perms'].shape[0]\n        glob['n_perms'] = n_perms\n\n        self.tril_perms_lin = model['tril_perms_lin']\n\n        self.torch_predict = None\n        self.use_torch = use_torch\n        if use_torch:\n\n            if not _has_torch:\n                raise ImportError(\n                    'Optional PyTorch dependency not found! 
Please run \\'pip install sgdml[torch]\\' to install it or disable the PyTorch option.'\n                )\n\n            from .torchtools import GDMLTorchPredict\n\n            self.torch_predict = GDMLTorchPredict(\n                model,\n                self.lat_and_inv,\n                max_memory=max_memory,\n                max_processes=max_processes,\n                log_level=self.log.level,\n            )\n\n            # Enable data parallelism\n            n_gpu = torch.cuda.device_count()\n            if n_gpu > 1:\n                self.torch_predict = torch.nn.DataParallel(self.torch_predict)\n\n            # Send model to device\n            # self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'\n            if _torch_cuda_is_available:\n                self.torch_device = 'cuda'\n            elif _torch_mps_is_available:\n                self.torch_device = 'mps'\n            else:\n                self.torch_device = 'cpu'\n\n            while True:\n                try:\n                    self.torch_predict.to(self.torch_device)\n                except RuntimeError as e:\n                    if 'out of memory' in str(e):\n\n                        if _torch_cuda_is_available:\n                            torch.cuda.empty_cache()\n\n                        model = self.torch_predict\n                        if isinstance(self.torch_predict, torch.nn.DataParallel):\n                            model = model.module\n\n                        if (\n                            model.get_n_perm_batches() == 1\n                        ):  # model caches the permutations, this could be why it is too large\n                            model.set_n_perm_batches(\n                                model.get_n_perm_batches() + 1\n                            )  # uncache\n                            # self.torch_predict.to( # NOTE!\n                            #    self.torch_device\n                            # )  # try sending to device 
again\n                            pass\n                        else:\n                            self.log.critical(\n                                'Not enough memory on device (RAM or GPU memory). There is no hope!'\n                            )\n                            print()\n                            os._exit(1)\n                    else:\n                        raise e\n                else:\n                    break\n        else:\n\n            # Precompute permuted training descriptors and its first derivatives multiplied with the coefficients.\n\n            R_desc_perms = (\n                np.tile(model['R_desc'].T, n_perms)[:, self.tril_perms_lin]\n                .reshape(self.n_train, n_perms, -1, order='F')\n                .reshape(self.n_train * n_perms, -1)\n            )\n            glob['R_desc_perms'], glob['R_desc_perms_shape'] = share_array(R_desc_perms)\n\n            R_d_desc_alpha_perms = (\n                np.tile(model['R_d_desc_alpha'], n_perms)[:, self.tril_perms_lin]\n                .reshape(self.n_train, n_perms, -1, order='F')\n                .reshape(self.n_train * n_perms, -1)\n            )\n            (\n                glob['R_d_desc_alpha_perms'],\n                glob['R_d_desc_alpha_perms_shape'],\n            ) = share_array(R_d_desc_alpha_perms)\n\n            if 'alphas_E' in model:\n                alphas_E_lin = np.tile(model['alphas_E'][:, None], (1, n_perms)).ravel()\n                glob['alphas_E_lin'], glob['alphas_E_lin_shape'] = share_array(\n                    alphas_E_lin\n                )\n\n            # Parallel processing configuration\n\n            self.bulk_mp = False  # Bulk predictions with multiple processes?\n\n            self.pool = None\n\n            # How many workers in addition to main process?\n            num_workers = num_workers or (\n                self.max_processes - 1\n            )  # exclude main process\n            self._set_num_workers(num_workers, 
force_reset=True)\n\n            # Size of chunks in which each parallel task will be processed (unit: number of training samples)\n            # This parameter should be as large as possible, but it depends on the size of available memory.\n            self._set_chunk_size(batch_size)\n\n    def __del__(self):\n\n        global globs\n\n        try:\n            self.pool.terminate()\n            self.pool.join()\n            self.pool = None\n        except Exception:\n            pass\n\n        if 'globs' in globals() and globs is not None and self.glob_id < len(globs):\n            globs[self.glob_id] = None\n\n    ## Public ##\n\n    # def set_R(self, R):\n    #     \"\"\"\n    #     Store a reference to the training geometries.\n    #     This function is used to avoid unnecessary copies of the\n    #     training geometries when evaluating the training error\n    #     (= gradient of the model's loss function).\n\n    #     This routine is used during iterative model training.\n\n    #     Parameters\n    #     ----------\n    #     R : :obj:`numpy.ndarray`\n    #         Array containing the geometry for each training point.\n    #     \"\"\"\n\n    #     # Add singleton dimension if input is (,3N).\n    #     if R.ndim == 1:\n    #         R = R[None, :]\n\n    #     self.R = R\n\n    #     # if self.use_torch:\n    #     #     model = self.torch_predict\n    #     #     if isinstance(self.torch_predict, torch.nn.DataParallel):\n    #     #         model = model.module\n\n    #     #     R_torch = torch.from_numpy(R.reshape(-1, self.n_atoms, 3)).to(self.torch_device)\n    #     #     model.set_R(R_torch)\n\n    def set_R_desc(self, R_desc):\n        \"\"\"\n        Store a reference to the training geometry descriptors.\n\n        This can accelerate iterative model training.\n\n        Parameters\n        ----------\n            R_desc : :obj:`numpy.ndarray`, optional\n                    A 2D array of size M x D containing the\n                    descriptors 
of dimension D for M\n                    molecules.\n        \"\"\"\n\n        self.R_desc = R_desc\n\n    def set_R_d_desc(self, R_d_desc):\n        \"\"\"\n        Store a reference to the training geometry descriptor Jacobians.\n        This function must be called before `set_alphas()` can be used.\n\n        This routine is used during iterative model training.\n\n        Parameters\n        ----------\n            R_d_desc : :obj:`numpy.ndarray`, optional\n                    A 2D array of size M x D x 3N containing of the\n                    descriptor Jacobians for M molecules. The descriptor\n                    has dimension D with 3N partial derivatives with\n                    respect to the 3N Cartesian coordinates of each atom.\n        \"\"\"\n\n        self.R_d_desc = R_d_desc\n\n        if self.use_torch:\n            model = self.torch_predict\n            if isinstance(self.torch_predict, torch.nn.DataParallel):\n                model = model.module\n\n            model.set_R_d_desc(R_d_desc)\n\n    def set_alphas(self, alphas_F, alphas_E=None):\n        \"\"\"\n        Reconfigure the current model with a new set of regression parameters.\n        `R_d_desc` needs to be set for this function to work.\n\n        This routine is used during iterative model training.\n\n        Parameters\n        ----------\n                alphas_F : :obj:`numpy.ndarray`\n                    1D array containing the new model parameters.\n                alphas_E : :obj:`numpy.ndarray`, optional\n                    1D array containing the additional new model parameters, if\n                    energy constraints are used in the kernel (`use_E_cstr=True`)\n        \"\"\"\n\n        if self.use_torch:\n\n            model = self.torch_predict\n            if isinstance(self.torch_predict, torch.nn.DataParallel):\n                model = model.module\n\n            model.set_alphas(alphas_F, alphas_E=alphas_E)\n\n        else:\n\n            assert self.R_d_desc 
is not None\n\n            global globs\n            glob = globs[self.glob_id]\n\n            dim_i = self.desc.dim_i\n            R_d_desc_alpha = self.desc.d_desc_dot_vec(\n                self.R_d_desc, alphas_F.reshape(-1, dim_i)\n            )\n\n            R_d_desc_alpha_perms_new = np.tile(R_d_desc_alpha, glob['n_perms'])[\n                :, self.tril_perms_lin\n            ].reshape(self.n_train, glob['n_perms'], -1, order='F')\n\n            R_d_desc_alpha_perms = np.frombuffer(glob['R_d_desc_alpha_perms'])\n            np.copyto(R_d_desc_alpha_perms, R_d_desc_alpha_perms_new.ravel())\n\n            if alphas_E is not None:\n\n                alphas_E_lin_new = np.tile(\n                    alphas_E[:, None], (1, glob['n_perms'])\n                ).ravel()\n\n                alphas_E_lin = np.frombuffer(glob['alphas_E_lin'])\n                np.copyto(alphas_E_lin, alphas_E_lin_new)\n\n    def _set_num_workers(\n        self, num_workers=None, force_reset=False\n    ):  # TODO: complain if chunk or worker parameters do not fit training data (this causes issues with the caching)!!\n        \"\"\"\n        Set number of processes to use during prediction.\n\n        If bulk_mp == True, each worker handles the whole generation of a single prediction (this is for querying multiple geometries at once)\n        If bulk_mp == False, each worker may handle only a part of a prediction (chunks are defined in 'wkr_starts_stops'). 
In that scenario multiple processes\n        are used to distribute the work of generating a single prediction\n\n        This number should not exceed the number of available CPU cores.\n\n        Note\n        ----\n                This parameter can be optimally determined using\n                `prepare_parallel`.\n\n        Parameters\n        ----------\n                num_workers : int, optional\n                    Number of processes (maximum value is set if `None`).\n                force_reset : bool, optional\n                    Force applying the new setting.\n        \"\"\"\n\n        # Compare by value: identity checks on integers are unreliable.\n        if force_reset or self.num_workers != num_workers:\n\n            if self.pool is not None:\n                self.pool.terminate()\n                self.pool.join()\n                self.pool = None\n\n            self.num_workers = 0\n            if num_workers is None or num_workers > 0:\n                self.pool = Pool(num_workers)\n                self.num_workers = (\n                    self.pool._processes\n                )  # number of actual workers (not max_processes)\n\n        # Data ranges for processes\n        if self.bulk_mp or self.num_workers < 2:\n            # wkr_starts = [self.n_train]\n            wkr_starts = [0]\n        else:\n            wkr_starts = list(\n                range(\n                    0,\n                    self.n_train,\n                    int(np.ceil(float(self.n_train) / self.num_workers)),\n                )\n            )\n        wkr_stops = wkr_starts[1:] + [self.n_train]\n\n        self.wkr_starts_stops = list(zip(wkr_starts, wkr_stops))\n\n    def _set_chunk_size(self, chunk_size=None):\n\n        # TODO: complain if chunk or worker parameters do not fit training data (this causes issues with the caching)!!\n        \"\"\"\n        Set chunk size for each worker process.\n\n        Every prediction is generated as a linear combination of the training\n        points that the model is comprised of. 
If multiple workers are available\n        (and bulk mode is disabled), each one processes an (approximately equal)\n        part of those training points. Then, the chunk size determines how much of\n        a process's workload is passed to NumPy's underlying low-level routines at\n        once. If the chunk size is smaller than the number of points the worker is\n        supposed to process, it processes them in multiple steps using a loop. This\n        can sometimes be faster, depending on the available hardware.\n\n        Note\n        ----\n                This parameter can be optimally determined using\n                `prepare_parallel`.\n\n        Parameters\n        ----------\n                chunk_size : int\n                        Chunk size (maximum value is set if `None`).\n        \"\"\"\n\n        if chunk_size is None:\n            chunk_size = self.n_train\n\n        self.chunk_size = chunk_size\n\n    def _set_batch_size(self, batch_size=None):  # deprecated\n        \"\"\"\n\n        Warning\n        -------\n        Deprecated! Please use the function `_set_chunk_size` in future projects.\n\n        Set chunk size for each worker process. 
A chunk is a subset\n        of the training data points whose linear combination needs to\n        be evaluated in order to generate a prediction.\n\n        The chunk size determines how much of a process's workload will\n        be passed to NumPy's underlying low-level routines at once.\n        This parameter is highly hardware dependent.\n\n        Note\n        ----\n                This parameter can be optimally determined using\n                `prepare_parallel`.\n\n        Parameters\n        ----------\n                batch_size : int\n                        Chunk size (maximum value is set if `None`).\n        \"\"\"\n\n        self._set_chunk_size(batch_size)\n\n    def _set_bulk_mp(self, bulk_mp=False):\n        \"\"\"\n        Toggle bulk prediction mode.\n\n        If bulk prediction is enabled, the prediction is parallelized across\n        input geometries, i.e. each worker generates the complete prediction for\n        one query. Otherwise (depending on the number of available CPU cores) the\n        input geometries are processed sequentially, but every one of them may be\n        processed by multiple workers at once (in chunks).\n\n        Note\n        ----\n                This parameter can be optimally determined using\n                `prepare_parallel`.\n\n        Parameters\n        ----------\n                bulk_mp : bool, optional\n                        Enable or disable bulk prediction mode.\n        \"\"\"\n\n        bulk_mp = bool(bulk_mp)\n        if self.bulk_mp is not bulk_mp:\n            self.bulk_mp = bulk_mp\n\n            # Reset data ranges for processes stored in 'wkr_starts_stops'\n            self._set_num_workers(self.num_workers)\n\n    def set_opt_num_workers_and_batch_size_fast(self, n_bulk=1, n_reps=1):  # deprecated\n        \"\"\"\n        Warning\n        -------\n        Deprecated! 
Please use the function `prepare_parallel` in future projects.\n\n        Parameters\n        ----------\n                n_bulk : int, optional\n                        Number of geometries that will be passed to the\n                        `predict` function in each call (performance\n                        will be optimized for that exact use case).\n                n_reps : int, optional\n                        Number of repetitions (bigger value: more\n                        accurate, but also slower).\n\n        Returns\n        -------\n                int\n                        Force and energy prediction speed in geometries\n                        per second.\n        \"\"\"\n\n        # Forward the benchmark result so that the documented return value is honored.\n        return self.prepare_parallel(n_bulk, n_reps)\n\n    def prepare_parallel(\n        self, n_bulk=1, n_reps=1, return_is_from_cache=False\n    ):  # noqa: C901\n        \"\"\"\n        Find and set the optimal parallelization parameters for the\n        currently loaded model, running on a particular system. The result\n        also depends on the number of geometries `n_bulk` that will be\n        passed at once when calling the `predict` function.\n\n        This function runs a benchmark in which the prediction routine is\n        repeatedly called `n_reps`-times (default: 1) with varying parameter\n        configurations, while the runtime is measured for each one. The\n        optimal parameters are then cached for fast retrieval in future\n        calls of this function.\n\n        We recommend calling this function after initialization of this\n        class, as it will drastically increase the performance of the\n        `predict` function.\n\n        Note\n        ----\n                Depending on the parameter `n_reps`, this routine may take\n                some seconds/minutes to complete. 
However, once a\n                statistically significant number of benchmark results has\n                been gathered for a particular configuration, it starts\n                returning almost instantly.\n\n        Parameters\n        ----------\n                n_bulk : int, optional\n                        Number of geometries that will be passed to the\n                        `predict` function in each call (performance\n                        will be optimized for that exact use case).\n                n_reps : int, optional\n                        Number of repetitions (bigger value: more\n                        accurate, but also slower).\n                return_is_from_cache : bool, optional\n                        If enabled, this function returns a second value\n                        indicating if the returned results were obtained\n                        from cache.\n\n        Returns\n        -------\n                int\n                        Force and energy prediciton speed in geometries\n                        per second.\n                boolean, optional\n                        Return, whether this function obtained the results\n                        from cache.\n        \"\"\"\n\n        # global globs\n        # glob = globs[self.glob_id]\n        # n_perms = glob['n_perms']\n\n        # No benchmarking necessary if prediction is running on GPUs.\n        if self.use_torch:\n            self.log.info(\n                'Skipping multi-CPU benchmark, since torch is enabled.'\n            )  # TODO: clarity!\n            return\n\n        # Retrieve cached benchmark results, if available.\n        bmark_result = self._load_cached_bmark_result(n_bulk)\n        if bmark_result is not None:\n\n            num_workers, chunk_size, bulk_mp, gps = bmark_result\n\n            self._set_chunk_size(chunk_size)\n            self._set_num_workers(num_workers)\n            self._set_bulk_mp(bulk_mp)\n\n            if return_is_from_cache:\n  
              is_from_cache = True\n                return gps, is_from_cache\n            else:\n                return gps\n\n        warm_up_done = False\n\n        best_results = []\n        last_i = None\n\n        best_gps = 0\n        gps_min = 0.0\n\n        best_params = None\n\n        r_dummy = np.random.rand(n_bulk, self.n_atoms * 3)\n\n        def _dummy_predict():\n            self.predict(r_dummy)\n\n        bulk_mp_rng = [True, False] if n_bulk > 1 else [False]\n        for bulk_mp in bulk_mp_rng:\n            self._set_bulk_mp(bulk_mp)\n\n            if bulk_mp is False:\n                last_i = 0\n\n            num_workers_rng = list(range(0, self.max_processes))\n            if bulk_mp:\n                num_workers_rng.reverse()  # benchmark converges faster this way\n\n            # num_workers_rng_sizes = [batch_size for batch_size in batch_size_rng if min_batch_size % batch_size == 0]\n\n            # for num_workers in range(min_num_workers,self.max_processes+1):\n            for num_workers in num_workers_rng:\n                if not bulk_mp and num_workers != 0 and self.n_train % num_workers != 0:\n                    continue\n\n                self._set_num_workers(num_workers)\n\n                best_gps = 0\n                gps_rng = (np.inf, 0.0)  # min and max per num_workers\n\n                min_chunk_size = (\n                    min(self.n_train, n_bulk)\n                    if bulk_mp or num_workers < 2\n                    else int(np.ceil(self.n_train / num_workers))\n                )\n                chunk_size_rng = list(range(min_chunk_size, 0, -1))\n\n                chunk_size_rng_sizes = [\n                    chunk_size\n                    for chunk_size in chunk_size_rng\n                    if min_chunk_size % chunk_size == 0\n                ]\n\n                # print('batch_size_rng_sizes ' + str(bulk_mp))\n                # print(batch_size_rng_sizes)\n\n                i_done = 0\n                i_dir = 1\n  
              i = 0 if last_i is None else last_i\n                # i = 0\n\n                # print(batch_size_rng_sizes)\n                while i >= 0 and i < len(chunk_size_rng_sizes):\n\n                    chunk_size = chunk_size_rng_sizes[i]\n                    self._set_chunk_size(chunk_size)\n\n                    i_done += 1\n\n                    if warm_up_done == False:\n                        timeit.timeit(_dummy_predict, number=10)\n                        warm_up_done = True\n\n                    gps = n_bulk * n_reps / timeit.timeit(_dummy_predict, number=n_reps)\n\n                    # print(\n                    #  '{:2d}@{:d} {:d} | {:7.2f} gps'.format(\n                    #      num_workers, chunk_size, bulk_mp, gps\n                    #  )\n                    # )\n\n                    gps_rng = (\n                        min(gps_rng[0], gps),\n                        max(gps_rng[1], gps),\n                    )  # min and max per num_workers\n\n                    # gps_min_max = min(gps_min_max[0], gps), max(gps_min_max[1], gps)\n\n                    # print('     best_gps ' + str(best_gps))\n\n                    # NEW\n\n                    # if gps > best_gps and gps > gps_min: # gps is still going up, everything is good\n                    #     best_gps = gps\n                    #     best_params = num_workers, batch_size, bulk_mp\n                    # else:\n                    #     break\n\n                    # if gps > best_gps: # gps is still going up, everything is good\n                    #     best_gps = gps\n                    #     best_params = num_workers, batch_size, bulk_mp\n                    # else: # gps did not go up wrt. 
to previous step\n\n                    #     # can we switch the search direction?\n                    #     #   did we already?\n                    #     #   we checked two consecutive configurations\n                    #     #   are bigger batch sizes possible?\n\n                    #     print(batch_size_rng_sizes)\n\n                    #     turn_search_dir = i_dir > 0 and i_done == 2 and batch_size != batch_size_rng_sizes[1]\n\n                    #     # only turn, if the current gps is not lower than the lowest overall\n                    #     if turn_search_dir and gps >= gps_min:\n                    #         i -= 2 * i_dir\n                    #         i_dir = -1\n                    #         print('><')\n                    #         continue\n                    #     else:\n                    #         print('>>break ' + str(i_done))\n                    #         break\n\n                    # NEW\n\n                    # gps still going up?\n                    # AND: gps not lower than the lowest overall?\n                    # if gps < best_gps and gps >= gps_min:\n                    if gps < best_gps:\n                        if (\n                            i_dir > 0\n                            and i_done == 2\n                            and chunk_size\n                            != chunk_size_rng_sizes[\n                                1\n                            ]  # there is no point in turning if this is the second batch size in the range\n                        ):  # do we turn?\n                            i -= 2 * i_dir\n                            i_dir = -1\n                            # print('><')\n                            continue\n                        else:\n                            if chunk_size == chunk_size_rng_sizes[1]:\n                                i -= 1 * i_dir\n                            # print('>>break ' + str(i_done))\n                            break\n                    else:\n          
              best_gps = gps\n                        best_params = num_workers, chunk_size, bulk_mp\n\n                    if (\n                        not bulk_mp and n_bulk > 1\n                    ):  # stop search early when multiple cpus are available and the 1 cpu case is tested\n                        if (\n                            gps < gps_min\n                        ):  # if the batch size run is lower than the lowest overall, stop right here\n                            # print('breaking here')\n                            break\n\n                    i += 1 * i_dir\n\n                last_i = i - 1 * i_dir\n                i_dir = 1\n\n                if len(best_results) > 0:\n                    overall_best_gps = max(best_results, key=lambda x: x[1])[1]\n                    if best_gps < overall_best_gps:\n                        # print('breaking, because best of last test was worse than overall best so far')\n                        break\n\n                    # if best_gps < gps_min:\n                    #    print('breaking here3')\n                    #    break\n\n                gps_min = gps_rng[0]  # FIX me: is this the overall min?\n                # print ('gps_min ' + str(gps_min))\n\n                # print ('best_gps')\n                # print (best_gps)\n\n                best_results.append(\n                    (best_params, best_gps)\n                )  # best results per num_workers\n\n        (num_workers, chunk_size, bulk_mp), gps = max(best_results, key=lambda x: x[1])\n\n        # Cache benchmark results.\n        self._save_cached_bmark_result(n_bulk, num_workers, chunk_size, bulk_mp, gps)\n\n        self._set_chunk_size(chunk_size)\n        self._set_num_workers(num_workers)\n        self._set_bulk_mp(bulk_mp)\n\n        if return_is_from_cache:\n            is_from_cache = False\n            return gps, is_from_cache\n        else:\n            return gps\n\n    def _save_cached_bmark_result(self, n_bulk, 
num_workers, chunk_size, bulk_mp, gps):\n\n        pkg_dir = os.path.dirname(os.path.abspath(__file__))\n        bmark_file = '_bmark_cache.npz'\n        bmark_path = os.path.join(pkg_dir, bmark_file)\n\n        bkey = '{}-{}-{}-{}'.format(\n            self.n_atoms, self.n_train, n_bulk, self.max_processes\n        )\n\n        if os.path.exists(bmark_path):\n\n            with np.load(bmark_path, allow_pickle=True) as bmark:\n                bmark = dict(bmark)\n\n                bmark['runs'] = np.append(bmark['runs'], bkey)\n                bmark['num_workers'] = np.append(bmark['num_workers'], num_workers)\n                bmark['batch_size'] = np.append(bmark['batch_size'], chunk_size)\n                bmark['bulk_mp'] = np.append(bmark['bulk_mp'], bulk_mp)\n                bmark['gps'] = np.append(bmark['gps'], gps)\n        else:\n            bmark = {\n                'code_version': __version__,\n                'runs': [bkey],\n                'gps': [gps],\n                'num_workers': [num_workers],\n                'batch_size': [chunk_size],\n                'bulk_mp': [bulk_mp],\n            }\n\n        np.savez_compressed(bmark_path, **bmark)\n\n    def _load_cached_bmark_result(self, n_bulk):\n\n        pkg_dir = os.path.dirname(os.path.abspath(__file__))\n        bmark_file = '_bmark_cache.npz'\n        bmark_path = os.path.join(pkg_dir, bmark_file)\n\n        bkey = '{}-{}-{}-{}'.format(\n            self.n_atoms, self.n_train, n_bulk, self.max_processes\n        )\n\n        if not os.path.exists(bmark_path):\n            return None\n\n        with np.load(bmark_path, allow_pickle=True) as bmark:\n\n            # Keep collecting benchmark runs, until we have at least three.\n            run_idxs = np.where(bmark['runs'] == bkey)[0]\n            if len(run_idxs) >= 3:\n\n                config_keys = []\n                for run_idx in run_idxs:\n                    config_keys.append(\n                        '{}-{}-{}'.format(\n             
               bmark['num_workers'][run_idx],\n                            bmark['batch_size'][run_idx],\n                            bmark['bulk_mp'][run_idx],\n                        )\n                    )\n\n                values, uinverse = np.unique(config_keys, return_index=True)\n\n                best_mean = -1\n                best_gps = 0\n                for i, config_key in enumerate(zip(values, uinverse)):\n                    mean_gps = np.mean(\n                        bmark['gps'][\n                            np.where(np.array(config_keys) == config_key[0])[0]\n                        ]\n                    )\n\n                    if best_gps == 0 or best_gps < mean_gps:\n                        best_mean = i\n                        best_gps = mean_gps\n\n                best_idx = run_idxs[uinverse[best_mean]]\n                num_workers = bmark['num_workers'][best_idx]\n                chunk_size = bmark['batch_size'][best_idx]\n                bulk_mp = bmark['bulk_mp'][best_idx]\n\n                return num_workers, chunk_size, bulk_mp, best_gps\n\n        return None\n\n    def get_GPU_batch(self):\n        \"\"\"\n        Get batch size used by the GPU implementation to process bulk\n        predictions (predictions for multiple input geometries at once).\n\n        This value is determined on-the-fly depending on the available GPU\n        memory.\n        \"\"\"\n\n        if self.use_torch:\n\n            model = self.torch_predict\n            if isinstance(model, torch.nn.DataParallel):\n                model = model.module\n\n            return model._batch_size()\n\n    def predict(self, R=None, return_E=True):\n        \"\"\"\n        Predict energy and forces for multiple geometries. 
This function\n        can run on the GPU, if the optional PyTorch dependency is\n        installed and `use_torch=True` was specified during\n        initialization of this class.\n\n        Optionally, the descriptors and descriptor Jacobians for the\n        same geometries can be provided, if already available from some\n        previous calculations.\n\n        Note\n        ----\n                The order of the atoms in `R` is not arbitrary and must\n                be the same as used for training the model.\n\n        Parameters\n        ----------\n                R : :obj:`numpy.ndarray`, optional\n                        A 2D array of size M x 3N containing the\n                        Cartesian coordinates of each atom of M\n                        molecules. If this parameter is omitted, the training\n                        error is returned. Note that the training geometries\n                        need to be set right after initialization using\n                        `set_R()` for this to work.\n                return_E : boolean, optional\n                        If false (default: true), only the forces are returned.\n\n        Returns\n        -------\n                :obj:`numpy.ndarray`\n                        Energies stored in a 1D array of size M (unless `return_E == False`)\n                :obj:`numpy.ndarray`\n                        Forces stored in a 2D array of size M x 3N.\n        \"\"\"\n\n        # Add singleton dimension if input is (,3N).\n        if R is not None and R.ndim == 1:\n            R = R[None, :]\n\n        if self.use_torch:  # multi-GPU (or CPU if no GPUs are available)\n\n            R_torch = torch.arange(self.n_train)\n            if R is None:\n                if self.R_d_desc is None:\n                    self.log.critical(\n                        'A reference to the training geometry descriptors needs to be set (using \\'set_R_d_desc()\\') for this function to work without arguments (using 
PyTorch).'\n                    )\n                    print()\n                    os._exit(1)\n            else:\n                R_torch = (\n                    torch.from_numpy(R.reshape(-1, self.n_atoms, 3))\n                    .type(torch.float32)\n                    .to(self.torch_device)\n                )\n\n            model = self.torch_predict\n            if R_torch.shape[0] < torch.cuda.device_count() and isinstance(\n                model, torch.nn.DataParallel\n            ):\n                model = self.torch_predict.module\n            E_torch_F_torch = model.forward(R_torch, return_E=return_E)\n\n            if return_E:\n                E_torch, F_torch = E_torch_F_torch\n                E = E_torch.cpu().numpy()\n            else:\n                (F_torch,) = E_torch_F_torch\n\n            F = F_torch.cpu().numpy().reshape(-1, 3 * self.n_atoms)\n\n        else:  # multi-CPU\n\n            # Use precomputed descriptors in training mode.\n            is_desc_in_cache = self.R_desc is not None and self.R_d_desc is not None\n\n            if R is None and not is_desc_in_cache:\n                self.log.critical(\n                    'A reference to the training geometry descriptors and Jacobians needs to be set for this function to work without arguments.'\n                )\n                print()\n                os._exit(1)\n\n            assert is_desc_in_cache or R is not None\n\n            dim_i = 3 * self.n_atoms\n            n_pred = self.R_desc.shape[0] if R is None else R.shape[0]\n\n            E_F = np.empty((n_pred, dim_i + 1))\n\n            if (\n                self.bulk_mp and self.num_workers > 0\n            ):  # One whole prediction per worker (and multiple workers).\n\n                _predict_wo_r_or_desc = partial(\n                    _predict_wkr,\n                    lat_and_inv=self.lat_and_inv,\n                    glob_id=self.glob_id,\n                    wkr_start_stop=None,\n                    
chunk_size=self.chunk_size,\n                )\n\n                for i, e_f in enumerate(\n                    self.pool.imap(\n                        partial(_predict_wo_r_or_desc, None)\n                        if is_desc_in_cache\n                        else partial(_predict_wo_r_or_desc, r_desc_d_desc=None),\n                        zip(self.R_desc, self.R_d_desc) if is_desc_in_cache else R,\n                    )\n                ):\n                    E_F[i, :] = e_f\n\n            else:  # Multiple workers per prediction (or just one worker).\n\n                for i in range(n_pred):\n\n                    if is_desc_in_cache:\n                        r_desc, r_d_desc = self.R_desc[i], self.R_d_desc[i]\n                    else:\n                        r_desc, r_d_desc = self.desc.from_R(R[i], self.lat_and_inv)\n\n                    _predict_wo_wkr_starts_stops = partial(\n                        _predict_wkr,\n                        None,\n                        (r_desc, r_d_desc),\n                        self.lat_and_inv,\n                        self.glob_id,\n                        chunk_size=self.chunk_size,\n                    )\n\n                    if self.num_workers == 0:\n                        E_F[i, :] = _predict_wo_wkr_starts_stops()\n                    else:\n                        E_F[i, :] = sum(\n                            self.pool.imap_unordered(\n                                _predict_wo_wkr_starts_stops, self.wkr_starts_stops\n                            )\n                        )\n\n            E_F *= self.std\n            F = E_F[:, 1:]\n            E = E_F[:, 0] + self.c\n\n        ret = (F,)\n        if return_E:\n            ret = (E,) + ret\n\n        return ret\n"
  },
  {
    "path": "sgdml/solvers/__init__.py",
    "content": ""
  },
  {
    "path": "sgdml/solvers/analytic.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2020-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nimport os\nimport sys\nimport logging\nimport warnings\nfrom functools import partial\n\nimport numpy as np\nimport scipy as sp\nimport timeit\n\nfrom .. 
import DONE, NOT_DONE\n\n\nclass Analytic(object):\n    def __init__(self, gdml_train, desc, callback=None):\n\n        self.log = logging.getLogger(__name__)\n\n        self.gdml_train = gdml_train\n        self.desc = desc\n\n        self.callback = callback\n\n    # from memory_profiler import profile\n    # @profile\n    def solve(self, task, R_desc, R_d_desc, tril_perms_lin, y):\n\n        sig = task['sig']\n        lam = task['lam']\n        use_E_cstr = task['use_E_cstr']\n\n        n_train, dim_d = R_d_desc.shape[:2]\n        n_atoms = int((1 + np.sqrt(8 * dim_d + 1)) / 2)\n        dim_i = 3 * n_atoms\n\n        if self.callback is not None:\n            self.callback = partial(\n                self.callback,\n                disp_str='Assembling kernel matrix',\n            )\n\n        K = -self.gdml_train._assemble_kernel_mat(\n            R_desc,\n            R_d_desc,\n            tril_perms_lin,\n            sig,\n            self.desc,\n            use_E_cstr=use_E_cstr,\n            callback=self.callback,\n        )  # Flip sign to make convex\n\n        start = timeit.default_timer()\n\n        with warnings.catch_warnings():\n            warnings.simplefilter('ignore')\n\n            if K.shape[0] == K.shape[1]:\n\n                K[np.diag_indices_from(K)] += lam  # Regularize\n\n                if self.callback is not None:\n                    self.callback = partial(\n                        self.callback,\n                        disp_str='Solving linear system (Cholesky factorization)',\n                    )\n                    self.callback(NOT_DONE)\n\n                try:\n\n                    # Cholesky (do not overwrite K in case we need to retry)\n                    L, lower = sp.linalg.cho_factor(\n                        K, overwrite_a=False, check_finite=False\n                    )\n                    alphas = -sp.linalg.cho_solve(\n                        (L, lower), y, overwrite_b=False, check_finite=False\n                
    )\n\n                except np.linalg.LinAlgError:  # Try a solver that makes fewer assumptions\n\n                    if self.callback is not None:\n                        self.callback = partial(\n                            self.callback,\n                            disp_str='Solving linear system (LU factorization)      ',  # Keep whitespaces!\n                        )\n                        self.callback(NOT_DONE)\n\n                    try:\n                        # LU\n                        alphas = -sp.linalg.solve(\n                            K, y, overwrite_a=True, overwrite_b=True, check_finite=False\n                        )\n                    except MemoryError:\n                        self.log.critical(\n                            'Not enough memory to train this system using a closed form solver.'\n                        )\n                        print()\n                        os._exit(1)\n\n                except MemoryError:\n                    self.log.critical(\n                        'Not enough memory to train this system using a closed form solver.'\n                    )\n                    print()\n                    os._exit(1)\n            else:\n\n                if self.callback is not None:\n                    self.callback = partial(\n                        self.callback,\n                        disp_str='Solving over-determined linear system (least squares approximation)',\n                    )\n                    self.callback(NOT_DONE)\n\n                # Least squares for non-square K\n                alphas = -np.linalg.lstsq(K, y, rcond=-1)[0]\n\n        stop = timeit.default_timer()\n\n        if self.callback is not None:\n            dur_s = stop - start\n            sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else ''\n            self.callback(\n                DONE,\n                disp_str='Training on {:,} points'.format(n_train),\n                
sec_disp_str=sec_disp_str,\n            )\n\n        return alphas\n\n    @staticmethod\n    def est_memory_requirement(n_train, n_atoms):\n\n        est_bytes = 3 * (n_train * 3 * n_atoms) ** 2 * 8  # K + factor(s) of K\n        est_bytes += (n_train * 3 * n_atoms) * 8  # alpha\n\n        return est_bytes\n"
  },
  {
    "path": "sgdml/solvers/iterative.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2020-2025 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nimport os\nimport logging\nfrom functools import partial\nimport inspect\nimport multiprocessing as mp\n\nimport numpy as np\nimport scipy as sp\nimport timeit\nimport collections\n\nfrom .. 
import DONE, NOT_DONE\nfrom ..utils import ui\nfrom ..predict import GDMLPredict\n\ntry:\n    import torch\nexcept ImportError:\n    _has_torch = False\nelse:\n    _has_torch = True\n\n\nCG_STEPS_HIST_LEN = (\n    100  # number of past steps to consider when calculating solver effectiveness\n)\nEFF_RESTART_THRESH = 0  # if solver effectiveness is below this percentage after 'CG_STEPS_HIST_LEN' steps, a solver restart is triggered (with stronger preconditioner)\n\nMAX_NUM_RESTARTS = 6\n\n\nclass CGRestartException(Exception):\n    pass\n\n\nclass Iterative(object):\n    def __init__(\n        self,\n        gdml_train,\n        desc,\n        max_memory,\n        max_processes,\n        use_torch,\n        callback=None,\n    ):\n\n        self.log = logging.getLogger(__name__)\n\n        self.gdml_train = gdml_train\n        self.gdml_predict = None\n        self.desc = desc\n\n        self.callback = callback\n\n        self._max_memory = max_memory\n        self._max_processes = max_processes\n        self._use_torch = use_torch\n\n    def _init_precon_operator(\n        self, task, R_desc, R_d_desc, tril_perms_lin, inducing_pts_idxs, callback=None\n    ):\n\n        lam = task['lam']\n        lam_inv = 1.0 / lam\n\n        sig = task['sig']\n\n        use_E_cstr = task['use_E_cstr']\n\n        L_inv_K_mn = self._nystroem_cholesky_factor(\n            R_desc,\n            R_d_desc,\n            tril_perms_lin,\n            sig,\n            lam,\n            use_E_cstr=use_E_cstr,\n            col_idxs=inducing_pts_idxs,\n            callback=callback,\n        )\n\n        L_inv_K_mn = np.ascontiguousarray(L_inv_K_mn)\n\n        lev_scores = np.einsum(\n            'i...,i...->...', L_inv_K_mn, L_inv_K_mn\n        )  # compute leverage scores because it is basically free once we have the factor\n\n        m, n = L_inv_K_mn.shape\n\n        if self._use_torch and False:  # TURNED OFF!\n            _torch_device = 'cuda' if torch.cuda.is_available() else 
'cpu'\n            L_inv_K_mn_torch = torch.from_numpy(L_inv_K_mn).to(_torch_device)\n\n        global is_primed\n        is_primed = False\n\n        def _P_vec(v):\n\n            global is_primed\n            if not is_primed:\n                is_primed = True\n                return v\n\n            if self._use_torch and False:  # TURNED OFF!\n\n                v_torch = torch.from_numpy(v).to(_torch_device)[:, None]\n                return (\n                    L_inv_K_mn_torch.t().mm(L_inv_K_mn_torch.mm(v_torch)) - v_torch\n                ).cpu().numpy() * lam_inv\n\n            else:\n\n                ret = L_inv_K_mn.T.dot(L_inv_K_mn.dot(v))\n                ret -= v\n                ret *= lam_inv\n\n                return ret\n\n        return sp.sparse.linalg.LinearOperator((n, n), matvec=_P_vec), lev_scores\n\n    def _init_kernel_operator(\n        self, task, R_desc, R_d_desc, tril_perms_lin, lam, n, callback=None\n    ):\n\n        n_train = R_desc.shape[0]\n\n        # dummy alphas\n        v_F = np.zeros((n - n_train, 1)) if task['use_E_cstr'] else np.zeros((n, 1))\n        v_E = np.zeros((n_train, 1)) if task['use_E_cstr'] else None\n\n        # Note: The standard deviation is set to 1.0, because we are predicting normalized labels here.\n        model = self.gdml_train.create_model(\n            task, 'cg', R_desc, R_d_desc, tril_perms_lin, 1.0, v_F, alphas_E=v_E\n        )\n\n        self.gdml_predict = GDMLPredict(\n            model,\n            max_memory=self._max_memory,\n            max_processes=self._max_processes,\n            use_torch=self._use_torch,\n        )\n\n        self.gdml_predict.set_R_desc(R_desc)  # only needed on CPU\n        self.gdml_predict.set_R_d_desc(R_d_desc)\n\n        if not self._use_torch:\n\n            if callback is not None:\n                callback = partial(callback, disp_str='Optimizing CPU parallelization')\n                callback(NOT_DONE)\n\n            
self.gdml_predict.prepare_parallel(n_bulk=n_train)\n\n            if callback is not None:\n                callback(DONE)\n\n        global is_primed\n        is_primed = False\n\n        def _K_vec(v):\n\n            global is_primed\n            if not is_primed:\n                is_primed = True\n                return v\n\n            v_F, v_E = v, None\n            if task['use_E_cstr']:\n                v_F, v_E = v[:-n_train], v[-n_train:]\n\n            self.gdml_predict.set_alphas(v_F, alphas_E=v_E)\n\n            pred = self.gdml_predict.predict(return_E=task['use_E_cstr'])\n            if task['use_E_cstr']:\n                e_pred, f_pred = pred\n                pred = np.hstack((f_pred.ravel(), -e_pred))\n            else:\n                pred = pred[0].ravel()\n\n            pred -= lam * v\n            return pred\n\n        return sp.sparse.linalg.LinearOperator((n, n), matvec=_K_vec)\n\n    def _nystroem_cholesky_factor(\n        self,\n        R_desc,\n        R_d_desc,\n        tril_perms_lin,\n        sig,\n        lam,\n        use_E_cstr,\n        col_idxs,\n        callback_task_name='',\n        callback=None,\n    ):\n\n        if callback_task_name != '':\n            callback_task_name = ' ({})'.format(callback_task_name)\n\n        if callback is not None:\n            callback = partial(\n                callback,\n                disp_str='Assembling kernel [m x k]{}'.format(callback_task_name),\n            )\n\n        dim_d = R_desc.shape[1]\n        n_atoms = int((1 + np.sqrt(8 * dim_d + 1)) / 2)\n        n = R_desc.shape[0] * n_atoms * 3 + (R_desc.shape[0] if use_E_cstr else 0)\n        m = len(\n            range(*col_idxs.indices(n)) if isinstance(col_idxs, slice) else col_idxs\n        )\n\n        K_nmm = self.gdml_train._assemble_kernel_mat(\n            R_desc,\n            R_d_desc,\n            tril_perms_lin,\n            sig,\n            self.desc,\n            use_E_cstr=use_E_cstr,\n            col_idxs=col_idxs,\n  
          alloc_extra_rows=m,\n            callback=callback,\n        )\n\n        # Store (psd) copy of K_mm in lower part of this oversized K_(n+m)m matrix.\n        K_nmm[-m:, :] = -K_nmm[col_idxs, :]\n\n        K_nm = K_nmm[:-m, :]\n        K_mm = K_nmm[-m:, :]\n\n        if callback is not None:\n            callback = partial(\n                callback,\n                disp_str='Cholesky fact. (1/2) [k x k]{}'.format(callback_task_name),\n            )\n            callback(NOT_DONE)\n\n        # Additional regularization is almost always necessary here (hence pre_reg=True).\n        K_mm, lower = self._cho_factor_stable(K_mm, pre_reg=True)  # overwrites input!\n        L_mm = K_mm\n        # del K_mm\n\n        if callback is not None:\n            callback(DONE)\n            callback = partial(\n                callback,\n                disp_str='m tri. solves (1/2) [k x k]{}'.format(callback_task_name),\n            )\n            callback(0, n)\n\n        b_start, b_size = 0, int(n / 4)  # update in percentage steps of 25\n        for b_stop in list(range(b_size, n, b_size)) + [n]:\n\n            K_nm[b_start:b_stop, :] = sp.linalg.solve_triangular(\n                L_mm,\n                K_nm[b_start:b_stop, :].T,\n                lower=lower,\n                trans='T',\n                overwrite_b=True,\n                check_finite=False,\n            ).T\n            b_start = b_stop\n\n            if callback is not None:\n                callback(b_stop, n)\n\n        del L_mm\n\n        K_nmm[-m:, :] = K_nm.T.dot(K_nm)\n        K_nmm[-m:, :][np.diag_indices_from(K_nmm[-m:, :])] += lam\n        inner = K_nmm[-m:, :]\n\n        if callback is not None:\n            callback = partial(\n                callback,\n                disp_str='Cholesky fact. 
(2/2) [k x k]{}'.format(callback_task_name),\n            )\n            callback(NOT_DONE)\n\n        L_lower = self._cho_factor_stable(\n            inner, eps_mag_max=-14\n        )  # Do not regularize more than 1e-14.\n        if L_lower is not None:\n            K_nmm[-m:, :], lower = L_lower\n            L = K_nmm[-m:, :]\n            del inner\n        else:\n\n            callback = partial(\n                callback,\n                disp_str='QR fact. (alt.) [k x k]{}'.format(callback_task_name),\n            )\n            callback(NOT_DONE)\n\n            K_nmm[-m:, :] = 0\n            K_nmm[-m:, :][np.diag_indices(m)] = np.sqrt(lam)\n\n            K_nmm[-m:, :] = np.linalg.qr(K_nmm, mode='r')\n            L = K_nmm[-m:, :]\n            lower = False\n\n        if callback is not None:\n            callback(DONE)\n            callback = partial(\n                callback,\n                disp_str='m tri. solves (2/2) [k x k]{}'.format(callback_task_name),\n            )\n            callback(0, n)\n\n        b_start, b_size = 0, int(n / 4)  # update in percentage steps of 25\n        for b_stop in list(range(b_size, n, b_size)) + [n]:\n\n            K_nm[b_start:b_stop, :] = sp.linalg.solve_triangular(\n                L,\n                K_nm[b_start:b_stop, :].T,\n                lower=lower,\n                trans='T',\n                overwrite_b=True,\n                check_finite=False,\n            ).T  # Note: Overwrites K_nm to save memory\n            b_start = b_stop\n\n            if callback is not None:\n                callback(b_stop, n)\n        del L\n\n        return K_nm.T\n\n    def _lev_scores(\n        self,\n        R_desc,\n        R_d_desc,\n        tril_perms_lin,\n        sig,\n        lam,\n        use_E_cstr,\n        n_inducing_pts,\n        callback=None,\n    ):\n\n        n_train, dim_d = R_d_desc.shape[:2]\n        dim_i = 3 * int((1 + np.sqrt(8 * dim_d + 1)) / 2)\n\n        # Convert from training points to actual 
columns.\n        # dim_m = (\n        #    np.maximum(1, n_inducing_pts // 4) * dim_i\n        # )  # only use 1/4 of inducing points for leverage score estimate\n        dim_m = dim_i * min(n_inducing_pts, 10)\n\n        # Which columns to use for leverage score approximation?\n        lev_approx_idxs = np.sort(\n            np.random.choice(\n                n_train * dim_i + (n_train if use_E_cstr else 0), dim_m, replace=False\n            )\n        )  # random subset of columns\n        # lev_approx_idxs = np.sort(np.random.choice(n_train*dim_i, dim_m, replace=False)) # random subset of columns\n\n        # lev_approx_idxs = np.s_[\n        #    :dim_m\n        # ]  # first 'dim_m' columns (faster kernel construction)\n\n        L_inv_K_mn = self._nystroem_cholesky_factor(\n            R_desc,\n            R_d_desc,\n            tril_perms_lin,\n            sig,\n            lam,\n            use_E_cstr=use_E_cstr,\n            col_idxs=lev_approx_idxs,\n            callback_task_name='lev. 
scores',\n            callback=callback,\n        )\n\n        lev_scores = np.einsum('i...,i...->...', L_inv_K_mn, L_inv_K_mn)\n        return lev_scores\n\n    def inducing_pts_from_lev_scores(self, lev_scores, N):\n\n        # Sample 'N' columns with probabilities proportional to the leverage scores.\n        inducing_pts_idxs = np.random.choice(\n            np.arange(lev_scores.size),\n            N,\n            replace=False,\n            p=lev_scores / lev_scores.sum(),\n        )\n\n        return np.sort(inducing_pts_idxs)\n\n    # Performs a Cholesky decomposition of a matrix, but regularizes the matrix (if needed) until it is positive definite.\n    def _cho_factor_stable(self, M, pre_reg=False, eps_mag_max=1):\n        \"\"\"\n        Performs a Cholesky decomposition of a matrix, but regularizes\n        as needed until it is positive definite.\n\n        Parameters\n        ----------\n            M : :obj:`numpy.ndarray`\n                Matrix to factorize. Its diagonal is modified in place\n                whenever regularization is applied.\n            pre_reg : boolean, optional\n                Regularize M right away (machine precision), before\n                trying to factorize it (default: False).\n            eps_mag_max : int, optional\n                Largest regularization strength to try, given as a\n                power of ten (default: 1).\n\n        Returns\n        -------\n            :obj:`numpy.ndarray`\n                Matrix whose upper or lower triangle contains the Cholesky factor of M. 
Other parts of the matrix contain random data.\n            boolean\n                Flag indicating whether the factor is in the lower or upper triangle\n        \"\"\"\n\n        eps = np.finfo(float).eps\n        eps_mag = int(np.floor(np.log10(eps)))\n\n        if pre_reg:\n            M[np.diag_indices_from(M)] += eps\n            eps_mag += 1  # if additional regularization is necessary, start from the next order of magnitude\n\n        for reg in 10.0 ** np.arange(\n            eps_mag, eps_mag_max + 1\n        ):  # regularize more and more aggressively (strongest regularization: 1)\n            try:\n\n                L, lower = sp.linalg.cho_factor(\n                    M, overwrite_a=False, check_finite=False\n                )\n\n            except np.linalg.LinAlgError as e:\n\n                if 'not positive definite' in str(e):\n                    self.log.debug(\n                        'Cholesky solver needs more aggressive regularization (adding {} to diagonal)'.format(\n                            reg\n                        )\n                    )\n                    M[np.diag_indices_from(M)] += reg\n                else:\n                    raise e\n            else:\n                return L, lower\n\n        self.log.critical(\n            'Failed to factorize despite strong regularization (max: {})!\\nYou could try a larger sigma.'.format(\n                10.0**eps_mag_max\n            )\n        )\n        print()\n        os._exit(1)\n\n    def solve(\n        self,\n        task,\n        R_desc,\n        R_d_desc,\n        tril_perms_lin,\n        y,\n        y_std,\n        tol=1e-4,\n        save_progr_callback=None,\n    ):\n\n        global num_iters, start, resid, avg_tt, m  # , P_t\n\n        n_train, n_atoms = task['R_train'].shape[:2]\n        dim_i = 3 * n_atoms\n\n        sig = task['sig']\n        lam = task['lam']\n\n        # these keys are only present if the task was created from an existing model\n        
alphas0_F = task['alphas0_F'] if 'alphas0_F' in task else None\n        alphas0_E = task['alphas0_E'] if 'alphas0_E' in task else None\n        num_iters0 = task['solver_iters'] if 'solver_iters' in task else 0\n\n        # Number of inducing points to use for Nystrom approximation.\n        max_memory_bytes = self._max_memory * 1024**3\n        max_n_inducing_pts = Iterative.max_n_inducing_pts(\n            n_train, n_atoms, max_memory_bytes\n        )\n        n_inducing_pts = min(n_train, max_n_inducing_pts)\n        n_inducing_pts_init = (\n            len(task['inducing_pts_idxs']) // (3 * n_atoms)\n            if 'inducing_pts_idxs' in task\n            else None\n        )\n\n        if self.callback is not None:\n            self.callback = partial(\n                self.callback,\n                disp_str='Building preconditioner (k={} ind. point{})'.format(\n                    n_inducing_pts, 's' if n_inducing_pts > 1 else ''\n                ),\n            )\n        subtask_callback = (\n            partial(ui.sec_callback, main_callback=self.callback)\n            if self.callback is not None\n            else None\n        )\n\n        lev_scores = None\n        if n_inducing_pts_init is not None and n_inducing_pts_init == n_inducing_pts:\n            inducing_pts_idxs = task['inducing_pts_idxs']  # reuse old inducing points\n        else:\n            # Determine good inducing points.\n            lev_scores = self._lev_scores(\n                R_desc,\n                R_d_desc,\n                tril_perms_lin,\n                sig,\n                lam,\n                task['use_E_cstr'],\n                n_inducing_pts,\n                callback=subtask_callback,\n            )\n\n            dim_m = n_inducing_pts * dim_i\n            inducing_pts_idxs = self.inducing_pts_from_lev_scores(lev_scores, dim_m)\n\n        start = timeit.default_timer()\n        P_op, lev_scores = self._init_precon_operator(\n            task,\n            R_desc,\n  
          R_d_desc,\n            tril_perms_lin,\n            inducing_pts_idxs,\n            callback=subtask_callback,\n        )\n        stop = timeit.default_timer()\n\n        if self.callback is not None:\n            dur_s = stop - start\n            sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else ''\n            self.callback(DONE, sec_disp_str=sec_disp_str)\n\n            self.callback = partial(\n                self.callback,\n                disp_str='Initializing solver',\n            )\n        subtask_callback = (\n            partial(ui.sec_callback, main_callback=self.callback)\n            if self.callback is not None\n            else None\n        )\n\n        n = P_op.shape[0]\n        K_op = self._init_kernel_operator(\n            task, R_desc, R_d_desc, tril_perms_lin, lam, n, callback=subtask_callback\n        )\n\n        num_iters = int(num_iters0)\n\n        if self.callback is not None:\n\n            num_devices = (\n                mp.cpu_count() if self._max_processes is None else self._max_processes\n            )\n            if self._use_torch:\n                num_devices = (\n                    torch.cuda.device_count()\n                    if torch.cuda.is_available()\n                    else torch.get_num_threads()\n                )\n            hardware_str = '{:d} {}{}{}'.format(\n                num_devices,\n                'GPU' if self._use_torch and torch.cuda.is_available() else 'CPU',\n                's' if num_devices > 1 else '',\n                '[PyTorch]' if self._use_torch else '',\n            )\n\n            self.callback(NOT_DONE, sec_disp_str=None)\n\n        start = 0\n        resid = 0\n        avg_tt = 0\n\n        global alpha_t, eff, steps_hist, callback_disp_str\n\n        alpha_t = None\n        if alphas0_F is not None:  # TODO: improve me: this will not work with E_cstr\n            alpha_t = -alphas0_F\n\n        if alphas0_E is not None:\n            alpha_t = 
np.hstack((alpha_t, -alphas0_E))\n\n        steps_hist = collections.deque(\n            maxlen=CG_STEPS_HIST_LEN\n        )  # moving average window for step history\n\n        callback_disp_str = 'Initializing solver'\n\n        def _cg_status(xk):\n\n            global num_iters, start, resid, alpha_t, avg_tt, eff, steps_hist, callback_disp_str, P_t\n\n            stop = timeit.default_timer()\n            tt = 0.0 if start == 0 else (stop - start)\n            avg_tt += tt\n            start = timeit.default_timer()\n\n            old_resid = resid\n            try:\n\n                # Can we extract the residual from the solver?\n                f_locals = inspect.currentframe().f_back.f_locals\n                if 'resid' in f_locals:\n                    resid = f_locals['resid']\n                elif 'r' in f_locals:\n                    resid = np.linalg.norm(f_locals['r'])\n                else:\n                    raise KeyError\n\n            except KeyError:\n\n                # Fallback: compute residual from scratch (slower)\n                rk = y + K_op @ xk\n                resid = np.linalg.norm(rk)\n\n            step = 0 if num_iters == num_iters0 else resid - old_resid\n            steps_hist.append(step)\n\n            steps_hist_arr = np.array(steps_hist)\n            steps_hist_all = np.abs(steps_hist_arr).sum()\n            steps_hist_ratio = (\n                (-steps_hist_arr.clip(max=0).sum() / steps_hist_all)\n                if steps_hist_all > 0\n                else 1\n            )\n            eff = (\n                0 if num_iters == num_iters0 else (int(100 * steps_hist_ratio) - 50) * 2\n            )\n\n            if tt > 0.0 and num_iters % int(np.ceil(1.0 / tt)) == 0:  # once per second\n\n                train_rmse = resid / np.sqrt(len(y))\n                if self.callback is not None:\n                    callback_disp_str = 'Training error (RMSE): forces {:.4f}'.format(\n                        train_rmse\n             
       )\n                    self.callback(\n                        NOT_DONE,\n                        disp_str=callback_disp_str,\n                        sec_disp_str=(\n                            '{:d} iter @ {} iter/s [eff: {:d}%], k={:d}'.format(\n                                num_iters,\n                                '{:.2f}'.format(1.0 / tt),\n                                eff,\n                                n_inducing_pts,\n                            )\n                        ),\n                    )\n\n            # Write out current solution as a model file once every 2 minutes (give or take).\n            if (\n                tt > 0.0\n                and num_iters % int(np.ceil(2 * 60.0 / tt)) == 0\n                and num_iters % 10 == 0\n            ):\n\n                self.log.debug('Saving model checkpoint.')\n\n                # TODO: support for +E constraints (done?)\n                alphas_F, alphas_E = -xk, None\n                if task['use_E_cstr']:\n                    n_train = task['R_train'].shape[0]\n                    alphas_F, alphas_E = -xk[:-n_train], -xk[-n_train:]\n\n                unconv_model = self.gdml_train.create_model(\n                    task,\n                    'cg',\n                    R_desc,\n                    R_d_desc,\n                    tril_perms_lin,\n                    y_std,\n                    alphas_F,\n                    alphas_E=alphas_E,\n                )\n\n                solver_keys = {\n                    'solver_tol': tol,\n                    'solver_iters': num_iters\n                    + 1,  # number of iterations performed (cg solver)\n                    'solver_resid': resid,  # residual of solution\n                    'norm_y_train': np.linalg.norm(y),\n                    'inducing_pts_idxs': inducing_pts_idxs,\n                }\n\n                unconv_model.update(solver_keys)\n\n                # recover integration constant\n                
self.gdml_predict.set_alphas(alphas_F, alphas_E=alphas_E)\n                E_pred, _ = self.gdml_predict.predict()\n\n                E_pred *= y_std\n\n                unconv_model['c'] = 0\n                if 'E_train' in task:\n                    E_ref = np.squeeze(task['E_train'])\n                    unconv_model['c'] = np.mean(E_ref - E_pred)\n\n                if save_progr_callback is not None:\n                    save_progr_callback(unconv_model)\n\n            num_iters += 1\n\n            n_train = task['idxs_train'].shape[0]\n            if (\n                len(steps_hist) == CG_STEPS_HIST_LEN\n                and eff <= EFF_RESTART_THRESH\n                and n_inducing_pts < n_train\n            ):\n                alpha_t = xk\n                raise CGRestartException\n\n        num_restarts = 0\n        while True:\n            try:\n                alphas, info = sp.sparse.linalg.cg(\n                    -K_op,\n                    y,\n                    x0=alpha_t,\n                    M=P_op,\n                    rtol=tol,  # norm(residual) <= max(rtol*norm(b), atol)\n                    atol=0,\n                    maxiter=3\n                    * n_atoms\n                    * n_train\n                    * 10,  # allow 10x as many iterations as theoretically needed (at perfect precision)\n                    callback=_cg_status,\n                )\n                alphas = -alphas\n\n            except CGRestartException:\n\n                num_restarts += 1\n                steps_hist.clear()\n\n                if num_restarts == MAX_NUM_RESTARTS:\n                    info = 1  # convergence to tolerance not achieved\n                    alphas = alpha_t\n                    break\n                else:\n                    num_restarts_left = MAX_NUM_RESTARTS - num_restarts - 1\n                    self.log.debug(\n                        'Restarts left before giving up: {}{}.'.format(\n                            num_restarts_left,\n   
                         ' (final trial)' if num_restarts_left == 0 else '',\n                        )\n                    )\n\n                # TODO: keep using same number of points\n\n                n_inducing_pts = min(\n                    int(np.ceil(1.2 * n_inducing_pts)), n_train\n                )  # increase in increments (ignoring memory limits...)\n\n                subtask_callback = (\n                    partial(\n                        ui.sec_callback,\n                        main_callback=partial(\n                            self.callback, disp_str=callback_disp_str\n                        ),\n                    )\n                    if self.callback is not None\n                    else None\n                )\n\n                dim_m = n_inducing_pts * dim_i\n                inducing_pts_idxs = self.inducing_pts_from_lev_scores(lev_scores, dim_m)\n\n                del P_op\n                P_op, lev_scores = self._init_precon_operator(\n                    task,\n                    R_desc,\n                    R_d_desc,\n                    tril_perms_lin,\n                    inducing_pts_idxs,\n                    callback=subtask_callback,\n                )\n\n            else:\n                break\n\n        is_conv = info == 0\n\n        if self.callback is not None:\n\n            is_conv_warn_str = '' if is_conv else ' (NOT CONVERGED)'\n            self.callback(\n                DONE,\n                disp_str='Training on {:,} points{}'.format(n_train, is_conv_warn_str),\n                sec_disp_str=(\n                    '{:d} iter @ {} iter/s'.format(\n                        num_iters,\n                        '{:.2f}'.format(num_iters / avg_tt) if avg_tt > 0 else '--',\n                    )\n                ),\n                done_with_warning=not is_conv,\n            )\n\n        train_rmse = resid / np.sqrt(len(y))\n\n        return alphas, tol, num_iters, resid, train_rmse, inducing_pts_idxs, is_conv\n\n    
@staticmethod\n    def max_n_inducing_pts(n_train, n_atoms, max_memory_bytes):\n\n        SQUARE_FACT = 5\n        LINEAR_FACT = 4\n\n        to_bytes = 8\n        to_dof = (3 * n_atoms) ** 2 * to_bytes\n\n        sq_factor = LINEAR_FACT * n_train * to_dof\n        ny_factor = SQUARE_FACT * to_dof\n\n        n_inducing_pts = (\n            np.sqrt(sq_factor**2 + 4.0 * ny_factor * max_memory_bytes) - sq_factor\n        ) / (2 * ny_factor)\n        n_inducing_pts = int(n_inducing_pts)\n\n        return min(n_inducing_pts, n_train)\n\n    @staticmethod\n    def est_memory_requirement(n_train, n_inducing_pts, n_atoms):\n\n        SQUARE_FACT = 5\n        LINEAR_FACT = 4\n\n        # est_bytes = n_train * n_inducing_pts * (3 * n_atoms) ** 2 * 8  # P_op\n        # est_bytes += 2 * (n_inducing_pts * 3 * n_atoms) ** 2 * 8  # P_op [cho_factor]\n        # est_bytes += (n_train * 3 * n_atoms) * 8  # lev_scores\n        # est_bytes += (n_train * 3 * n_atoms) * 8  # alpha\n\n        est_bytes = LINEAR_FACT * n_train * n_inducing_pts * (3 * n_atoms) ** 2 * 8\n\n        est_bytes += (\n            SQUARE_FACT * n_inducing_pts * n_inducing_pts * (3 * n_atoms) ** 2 * 8\n        )\n\n        # est_bytes += (n_train * 3 * n_atoms) * 8  # lev_scores\n        # est_bytes += (n_train * 3 * n_atoms) * 8  # alpha\n\n        return est_bytes\n"
  },
  {
    "path": "sgdml/torchtools.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2019-2023 Stefan Chmiela, Jan Hermann\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nimport os\nimport sys\nimport logging\nfrom functools import partial\nimport numpy as np\nimport torch\nimport torch.nn as nn\nfrom torch.utils.data import DataLoader\n\ntry:\n    _torch_mps_is_available = torch.backends.mps.is_available()\nexcept AttributeError:\n    _torch_mps_is_available = False\n_torch_mps_is_available = False\n\ntry:\n    _torch_cuda_is_available = torch.cuda.is_available()\nexcept AttributeError:\n    _torch_cuda_is_available = False\n\n\nfrom .utils.desc import Desc\nfrom .utils import ui\n\n_dtype = torch.float64\n\n\ndef _next_batch_size(n_total, batch_size):\n\n    batch_size += 1\n    while n_total % batch_size != 0:\n        batch_size += 1\n\n    return batch_size\n\n\nclass GDMLTorchAssemble(nn.Module):\n    \"\"\"\n    PyTorch version of the kernel assembly 
routines in :class:`~train.GDMLTrain`.\n    Derives from :class:`torch.nn.Module`. Contains no trainable parameters.\n    \"\"\"\n\n    def __init__(\n        self,\n        J,\n        tril_perms_lin,\n        sig,\n        use_E_cstr,\n        R_desc_torch,\n        R_d_desc_torch,\n        out,\n        callback=None,\n    ):\n\n        global _n_batches, _n_perm_batches\n\n        super(GDMLTorchAssemble, self).__init__()\n\n        self._log = logging.getLogger(__name__)\n\n        self.callback = callback\n\n        self.n_train, self.dim_d = R_d_desc_torch.shape[:2]\n        self.n_atoms = int((1 + np.sqrt(8 * self.dim_d + 1)) / 2)\n        self.dim_i = 3 * self.n_atoms\n\n        self.sig = float(sig)\n        self.tril_perms_lin = tril_perms_lin\n        self.n_perms = len(self.tril_perms_lin) // self.dim_d\n\n        self.use_E_cstr = use_E_cstr\n\n        self.R_desc_torch = nn.Parameter(R_desc_torch.type(_dtype), requires_grad=False)\n        self.R_d_desc_torch = nn.Parameter(\n            R_d_desc_torch.type(_dtype), requires_grad=False\n        )\n\n        self._desc = Desc(self.n_atoms)\n\n        self.J = J\n        _n_batches = 1\n        _n_perm_batches = 1\n\n        self.out = out\n\n    def _forward(\n        self,\n        j,\n    ):\n\n        global _n_batches, _n_perm_batches\n\n        if type(j) is tuple:  # selective/\"fancy\" indexing\n            (\n                K_j,\n                j,\n                keep_idxs_3n,\n            ) = j  # (block index in final K, block index global, indices of partials within block)\n            blk_j_len = len(keep_idxs_3n)\n            blk_j = slice(K_j, K_j + blk_j_len)\n\n        else:  # sequential indexing\n            blk_j_len = self.dim_i\n            K_j = (\n                j * self.dim_i\n                if j < self.n_train\n                else self.n_train * self.dim_i + (j % self.n_train)\n            )\n            blk_j = (\n                slice(K_j, K_j + self.dim_i)\n         
       if j < self.n_train\n                else slice(K_j, K_j + 1)\n            )\n            keep_idxs_3n = slice(None)  # same as [:]\n\n        q = np.sqrt(5) / self.sig\n\n        if (\n            j < self.n_train\n        ):  # This column only contains second and first derivative constraints.\n\n            # Create a decompressed 'rj_d_desc'.\n            rj_d_desc_decomp_torch = self._desc.d_desc_from_comp(\n                self.R_d_desc_torch[j % self.n_train, :, :]\n            )[0][:, keep_idxs_3n]\n\n            n_perms_done = 0\n            for perm_batch in np.array_split(\n                np.arange(self.n_perms), min(_n_perm_batches, self.n_perms)\n            ):\n\n                tril_perms_lin_batch = (\n                    self.tril_perms_lin.reshape(-1, self.n_perms)[:, perm_batch]\n                    - n_perms_done * self.dim_d\n                ).ravel()  # index shift\n\n                n_perms_batch = len(perm_batch)\n                n_perms_done += n_perms_batch\n\n                # Create a permuted 'rj_desc'.\n                rj_desc_perms_torch = torch.reshape(\n                    torch.tile(self.R_desc_torch[j, :], (n_perms_batch,))[\n                        tril_perms_lin_batch\n                    ],\n                    (-1, n_perms_batch),\n                ).T\n\n                # Create a permuted 'rj_d_desc'.\n                rj_d_desc_perms_torch = torch.reshape(\n                    torch.tile(rj_d_desc_decomp_torch.T, (n_perms_batch,))[\n                        :, tril_perms_lin_batch\n                    ],\n                    (-1, self.dim_d, n_perms_batch),\n                )\n\n                for i_batch in np.array_split(np.arange(self.n_train), _n_batches):\n\n                    x_diffs = q * (\n                        self.R_desc_torch[i_batch, None, :]\n                        - rj_desc_perms_torch[None, :, :]\n                    )  # N, n_perms, d\n\n                    x_dists = x_diffs.norm(dim=-1)  # 
N, n_perms\n\n                    exp_xs = torch.exp(-x_dists) * (q**2) / 3  # N, n_perms\n                    exp_xs_1_x_dists = exp_xs * (1 + x_dists)  # N, n_perms*N_train\n\n                    del x_dists  # E_cstr\n\n                    diff_ab_outer_perms_torch = torch.einsum(\n                        '...ki,...kj->...ij',  # (slow)\n                        x_diffs * exp_xs[:, :, None],  # N, n_perms, d\n                        torch.einsum(\n                            '...ki,jik -> ...kj',\n                            x_diffs,\n                            rj_d_desc_perms_torch,\n                        ),  # N, n_perms, a*3\n                    )  # N, n_perms, a*3\n                    del exp_xs\n\n                    if not self.use_E_cstr:\n                        del x_diffs\n\n                    diff_ab_outer_perms_torch -= torch.einsum(\n                        'ikj,...j->...ki',\n                        rj_d_desc_perms_torch,\n                        exp_xs_1_x_dists,\n                    )\n\n                    if not self.use_E_cstr:\n                        del exp_xs_1_x_dists\n\n                    R_d_desc_decomp_torch = self._desc.d_desc_from_comp(\n                        self.R_d_desc_torch[i_batch, :, :]\n                    )\n\n                    k = torch.einsum(\n                        '...ij,...ik->...kj',\n                        diff_ab_outer_perms_torch,  # N, d, 3*a\n                        R_d_desc_decomp_torch,\n                    )\n                    del diff_ab_outer_perms_torch\n                    del R_d_desc_decomp_torch\n\n                    blk_i = slice(\n                        i_batch[0] * self.dim_i, (i_batch[-1] + 1) * self.dim_i\n                    )\n\n                    k_np = k.cpu().numpy().reshape(-1, blk_j_len)\n                    if (\n                        n_perms_done == n_perms_batch\n                    ):  # first permutation batch iteration\n                        self.out[blk_i, blk_j] = 
k_np\n                    else:\n                        self.out[blk_i, blk_j] = self.out[blk_i, blk_j] + k_np\n                    del k\n\n                    # First derivative constraints\n                    if self.use_E_cstr:\n\n                        K_fe = (x_diffs / q) * exp_xs_1_x_dists[:, :, None]\n                        del x_diffs\n                        del exp_xs_1_x_dists\n\n                        K_fe = -torch.einsum(\n                            '...ik,jki -> ...j', K_fe, rj_d_desc_perms_torch\n                        )\n\n                        E_off_i = self.n_train * self.dim_i\n                        i_batch_off = i_batch + E_off_i\n                        self.out[\n                            i_batch_off[0] : (i_batch_off[-1] + 1), blk_j\n                        ] = K_fe.cpu().numpy()\n\n                del rj_desc_perms_torch\n                del rj_d_desc_perms_torch\n\n        else:\n\n            if self.use_E_cstr:\n\n                n_perms_done = 0\n                for perm_batch in np.array_split(\n                    np.arange(self.n_perms), min(_n_perm_batches, self.n_perms)\n                ):\n\n                    tril_perms_lin_batch = (\n                        self.tril_perms_lin.reshape(-1, self.n_perms)[:, perm_batch]\n                        - n_perms_done * self.dim_d\n                    ).ravel()  # index shift\n\n                    n_perms_batch = len(perm_batch)\n                    n_perms_done += n_perms_batch\n\n                    for i_batch in np.array_split(np.arange(self.n_train), _n_batches):\n\n                        ri_desc_perms_torch = torch.reshape(\n                            torch.tile(\n                                self.R_desc_torch[i_batch, :], (1, n_perms_batch)\n                            )[:, tril_perms_lin_batch],\n                            (len(i_batch), -1, n_perms_batch),\n                        )\n\n                        # Create decompressed a 'ri_d_desc'.\n               
         ri_d_desc_decomp_torch = self._desc.d_desc_from_comp(\n                            self.R_d_desc_torch[i_batch, :, :]\n                        )\n\n                        ri_d_desc_perms_torch = torch.reshape(\n                            torch.tile(ri_d_desc_decomp_torch, (1, n_perms_batch, 1))[\n                                :, tril_perms_lin_batch, :\n                            ],\n                            (len(i_batch), self.dim_d, n_perms_batch, -1),\n                        )\n                        # del ri_d_desc_decomp_torch\n\n                        x_diffs = q * (\n                            self.R_desc_torch[j % self.n_train, None, :, None]\n                            - ri_desc_perms_torch\n                        )\n\n                        x_dists = x_diffs.norm(dim=1)\n\n                        exp_xs = torch.exp(-x_dists) * (q**2) / 3\n                        exp_xs_1_x_dists = exp_xs * (1 + x_dists)\n\n                        K_fe = x_diffs / q * exp_xs_1_x_dists[:, None, :]\n                        K_fe = -torch.einsum(\n                            '...ik,...ikj -> ...j', K_fe, ri_d_desc_perms_torch\n                        ).ravel()\n                        k_fe = K_fe.cpu().numpy()\n\n                        k_ee = -torch.einsum(\n                            '...i,...i -> ...',\n                            1 + x_dists * (1 + x_dists / 3),\n                            torch.exp(-x_dists),\n                        )\n                        k_ee = k_ee.cpu().numpy()\n\n                        E_off_i = (\n                            self.n_train * self.dim_i\n                        )  # Account for 'alloc_extra_rows'!.\n                        blk_i_full = slice(\n                            i_batch[0] * self.dim_i, (i_batch[-1] + 1) * self.dim_i\n                        )\n                        if (\n                            n_perms_done == n_perms_batch\n                        ):  # first permutation batch iteration\n 
                           self.out[blk_i_full, K_j] = k_fe\n                            self.out[E_off_i + i_batch, K_j] = k_ee\n                        else:\n                            self.out[blk_i_full, K_j] = self.out[blk_i_full, K_j] + k_fe\n                            self.out[E_off_i + i_batch, K_j] = (\n                                self.out[E_off_i + i_batch, K_j] + k_ee\n                            )\n\n        return blk_j.stop - blk_j.start\n\n    def forward(self, J_indx):\n\n        global _n_batches, _n_perm_batches\n\n        for i in J_indx:\n            while True:\n                try:\n                    done = self._forward(self.J[i])\n                except RuntimeError as e:\n                    if 'out of memory' in str(e):\n                        if _torch_cuda_is_available:\n                            torch.cuda.empty_cache()\n\n                        if _n_batches < self.n_train:\n                            _n_batches = _next_batch_size(self.n_train, _n_batches)\n\n                            self._log.debug(\n                                'Assembling each kernel column in {} batches, i.e. {} points/batch ({} points in total).'.format(\n                                    _n_batches,\n                                    self.n_train // _n_batches,\n                                    self.n_train,\n                                )\n                            )\n\n                        elif _n_perm_batches < self.n_perms:\n                            _n_perm_batches = _next_batch_size(\n                                self.n_perms, _n_perm_batches\n                            )\n\n                            self._log.debug(\n                                'Generating permutations in {} batches, i.e. 
{} permutations/batch ({} permutations in total).'.format(\n                                    _n_perm_batches,\n                                    self.n_perms // _n_perm_batches,\n                                    self.n_perms,\n                                )\n                            )\n\n                        else:\n                            self._log.critical(\n                                'Could not allocate enough memory to assemble kernel matrix, even block-by-block and/or handling perms in batches.'\n                            )\n                            print()\n                            os._exit(1)\n                    else:\n                        raise e\n                else:\n                    if self.callback is not None:\n                        self.callback(done)\n\n                    break\n\n\nclass GDMLTorchPredict(nn.Module):\n    \"\"\"\n    PyTorch version of :class:`~predict.GDMLPredict`. Derives from\n    :class:`torch.nn.Module`. Contains no trainable parameters.\n    \"\"\"\n\n    def __init__(\n        self,\n        model,\n        lat_and_inv=None,\n        batch_size=None,\n        n_perm_batches=1,\n        max_memory=None,\n        max_processes=None,\n        log_level=None,\n    ):\n        \"\"\"\n        Parameters\n        ----------\n        model : Mapping\n            Obtained from :meth:`~train.GDMLTrain.train`.\n        lat_and_inv : tuple of :obj:`numpy.ndarray`\n            Tuple of 3 x 3 matrix containing lattice vectors as columns and its inverse.\n        batch_size : int, optional\n            Maximum batch size of geometries for prediction. 
Calculated from\n            :paramref:`max_memory` if not given.\n        n_perm_batches : int, optional\n            Number of batches into which the processing of all symmetries for each\n            point is divided. With a single batch, all permuted variants are precomputed\n            and cached up front (needs more memory, but is faster).\n        max_memory : float, optional\n            (unit GB) Maximum allowed CPU memory for prediction (GPU memory always unlimited)\n        \"\"\"\n\n        global _batch_size, _n_perm_batches\n\n        super(GDMLTorchPredict, self).__init__()\n\n        self._log = logging.getLogger(__name__)\n        if log_level is not None:\n            self._log.setLevel(log_level)\n\n        model = dict(model)\n\n        self._lat_and_inv = (\n            None\n            if lat_and_inv is None\n            else (\n                torch.tensor(lat_and_inv[0], dtype=_dtype),\n                torch.tensor(lat_and_inv[1], dtype=_dtype),\n            )\n        )\n\n        self.dim_d, self.n_train = model['R_desc'].shape[:2]\n        self.dim_i = 3 * int((1 + np.sqrt(8 * self.dim_d + 1)) / 2)\n        self.n_perms, self.n_atoms = model['perms'].shape\n\n        # Check for duplicates in permutation list.\n        if model['perms'].shape[0] != np.unique(model['perms'], axis=0).shape[0]:\n            self._log.warning('Model contains duplicate permutations')\n\n        # Find index of identity permutation.\n        self.idx_id_perm = np.where(\n            (model['perms'] == np.arange(self.n_atoms)).all(axis=1)\n        )[0]\n\n        # No identity permutation found.\n        if len(self.idx_id_perm) == 0:\n            self._log.critical('Identity permutation is missing!')\n            print()\n            os._exit(1)\n\n        # Identity permutation not at index zero.\n        if len(self.idx_id_perm) > 0 and self.idx_id_perm[0] != 0:\n            self._log.debug(\n                'Identity is not at first position in permutation list (found at index {})'.format(\n                    self.idx_id_perm[0]\n  
              )\n            )\n\n        self.idx_id_perm = self.idx_id_perm[0]\n\n        self._sig = int(model['sig'])\n        self._c = float(model['c'])\n        self._std = float(model.get('std', 1))\n\n        self.tril_indices = np.tril_indices(self.n_atoms, k=-1)\n\n        if _torch_cuda_is_available:  # Ignore limits and take whatever the GPU has.\n            max_memory = (\n                min(\n                    [\n                        torch.cuda.get_device_properties(i).total_memory\n                        for i in range(torch.cuda.device_count())\n                    ]\n                )\n                // 2**30\n            )  # bytes to GB\n        else:  # TODO: what about MPS?\n            default_cpu_max_mem = 32\n            if max_memory is None:\n                self._log.warning(\n                    'PyTorch CPU memory budget is limited to {} by default, which may impact performance.\\n'.format(\n                        ui.gen_memory_str(2**30 * default_cpu_max_mem)\n                    )\n                    + 'If necessary, adjust memory limit with option \\'-m\\'.'\n                )\n            max_memory = (\n                max_memory or default_cpu_max_mem\n            )  # 32 GB as default (hardcoded for now...)\n        max_memory = int(2**30 * max_memory)  # GB to bytes\n\n        min_const_mem, min_per_sample_mem = self.est_mem_requirement(return_min=True)\n\n        log_type = (\n            self._log.warning\n            if min_const_mem + min_per_sample_mem >= max_memory\n            else self._log.info\n        )\n        log_type(\n            '{} memory report: max./avail. {}, min. req. 
(const./per-sample) ~{}/~{}'.format(\n                'GPU'\n                if (_torch_cuda_is_available or _torch_mps_is_available)\n                else 'CPU',\n                ui.gen_memory_str(max_memory),\n                ui.gen_memory_str(min_const_mem),\n                ui.gen_memory_str(min_per_sample_mem),\n            )\n        )\n\n        self.max_processes = max_processes\n\n        self.R_d_desc = None\n        self._xs_train = nn.Parameter(\n            torch.tensor(model['R_desc'], dtype=_dtype).t(), requires_grad=False\n        )\n        self._Jx_alphas = nn.Parameter(\n            torch.tensor(np.array(model['R_d_desc_alpha']), dtype=_dtype),\n            requires_grad=False,\n        )\n\n        self._alphas_E = None\n        if 'alphas_E' in model:\n            # 'torch.from_numpy' takes no 'dtype' argument, so cast afterwards.\n            self._alphas_E = nn.Parameter(\n                torch.from_numpy(model['alphas_E']).type(_dtype), requires_grad=False\n            )\n\n        self.perm_idxs = (\n            torch.tensor(model['tril_perms_lin'], dtype=torch.long)\n            .view(-1, self.n_perms)\n            .t()\n        )\n\n        i, j = self.tril_indices\n        self.register_buffer(\n            'agg_mat', torch.zeros((self.n_atoms, self.dim_d), dtype=torch.int8)\n        )\n        self.agg_mat[i, range(self.dim_d)] = -1\n        self.agg_mat[j, range(self.dim_d)] = 1\n\n        # Try to cache all permuted variants of 'self._xs_train' and 'self._Jx_alphas'\n        try:\n            self.set_n_perm_batches(n_perm_batches)\n        except RuntimeError as e:\n            if 'out of memory' in str(e):\n                if _torch_cuda_is_available:\n                    torch.cuda.empty_cache()\n\n                if n_perm_batches == 1:\n                    self.set_n_perm_batches(\n                        2\n                    )  # Set to 2 perm batches, because that's the first batch size (and fastest) that is not cached.\n                    pass\n                else:\n                    
self._log.critical(\n                        'Could not allocate enough memory to store model parameters on GPU. There is no hope!'\n                    )\n                    print()\n                    os._exit(1)\n            else:\n                raise e\n\n        const_mem, per_sample_mem = self.est_mem_requirement(return_min=False)\n        _batch_size = (\n            max((max_memory - const_mem) // per_sample_mem, 1)\n            if batch_size is None\n            else batch_size\n        )\n        max_batch_size = (\n            self.n_train // torch.cuda.device_count()\n            if _torch_cuda_is_available\n            else self.n_train\n        )\n        _batch_size = min(_batch_size, max_batch_size)\n\n        self._log.debug(\n            'Setting batch size to {}/{} points.'.format(_batch_size, self.n_train)\n        )\n\n        self.desc = Desc(self.n_atoms, max_processes=max_processes)\n\n    def get_n_perm_batches(self):\n\n        global _n_perm_batches\n        return _n_perm_batches\n\n    def set_n_perm_batches(self, n_perm_batches):\n\n        global _n_perm_batches\n\n        self._log.debug(\n            'Setting permutation batch size to {}/{}{}.'.format(\n                self.n_perms // n_perm_batches,\n                self.n_perms,\n                ' (no caching)' if n_perm_batches > 1 else '',\n            )\n        )\n\n        _n_perm_batches = n_perm_batches\n        if n_perm_batches == 1 and self.n_perms > 1:\n            self.cache_perms()\n        else:\n            self.uncache_perms()\n\n    def apply_perms_to_obj(self, xs, perm_idxs=None):\n\n        n_perms = 1 if perm_idxs is None else perm_idxs.numel() // self.dim_d\n        perm_idxs = (\n            slice(None) if perm_idxs is None else perm_idxs\n        )  # slice(None) same as [:]\n\n        # might run out of memory here, which will be handled by the caller\n        try:\n            return xs.repeat(1, n_perms)[:, perm_idxs].reshape(-1, self.dim_d)\n        
except:\n            raise\n\n    def remove_perms_from_obj(self, xs):\n\n        return xs.reshape(self.n_train, -1, self.dim_d)[:, self.idx_id_perm, :].reshape(\n            -1, self.dim_d\n        )\n\n    def uncache_perms(self):\n\n        xs_train_n_perms = self._xs_train.numel() // (self.n_train * self.dim_d)\n        if xs_train_n_perms != 1:  # Uncached already?\n            self._xs_train = nn.Parameter(\n                self.remove_perms_from_obj(self._xs_train), requires_grad=False\n            )\n\n        Jx_alphas_n_perms = self._Jx_alphas.numel() // (self.n_train * self.dim_d)\n        if Jx_alphas_n_perms != 1:  # Uncached already?\n            self._Jx_alphas = nn.Parameter(\n                self.remove_perms_from_obj(self._Jx_alphas), requires_grad=False\n            )\n\n    def cache_perms(self):\n\n        xs_train_n_perms = self._xs_train.numel() // (self.n_train * self.dim_d)\n        if xs_train_n_perms == 1:  # Cached already?\n            self._xs_train = nn.Parameter(\n                self.apply_perms_to_obj(self._xs_train, perm_idxs=self.perm_idxs),\n                requires_grad=False,\n            )\n\n        Jx_alphas_n_perms = self._Jx_alphas.numel() // (self.n_train * self.dim_d)\n        if Jx_alphas_n_perms == 1:  # Cached already?\n            self._Jx_alphas = nn.Parameter(\n                self.apply_perms_to_obj(self._Jx_alphas, perm_idxs=self.perm_idxs),\n                requires_grad=False,\n            )\n\n    def est_mem_requirement(self, return_min=False):\n        \"\"\"\n        Calculate an estimate for the maximum/minimum memory needed to generate\n        a prediction for a single geometry.\n\n        Parameters\n        ----------\n        return_min : boolean, optional\n            Return a minimum estimate instead.\n\n        Returns\n        -------\n        const_mem : int\n            Constant memory overhead (bytes) (allocated upon instantiation of the class)\n        per_sample_mem : int\n            
Memory requirement for a single prediction (bytes)\n        \"\"\"\n\n        n_perms_mem = 1 if return_min else self.n_perms\n\n        # Constant memory requirement (bytes)\n        const_mem = self.n_train * self.n_atoms * 3  # Rs (all)\n        const_mem += n_perms_mem * self.dim_d  # perm_idxs\n        const_mem += (\n            n_perms_mem * self.n_train * self.dim_d * 2\n        )  # _xs_train and _Jx_alphas\n        const_mem += self.n_atoms * self.dim_d  # agg_mat\n        const_mem *= 8\n        const_mem = int(const_mem)\n\n        # Peak memory requirement (bytes)\n        per_sample_mem = 2 * self.n_atoms * 3  # Rs (batch), # Fs (batch)\n        per_sample_mem += self.n_atoms  # Es (batch)\n        per_sample_mem += self.n_atoms**2 * 3  # diffs\n        per_sample_mem += self.dim_d  # xs\n        per_sample_mem += self.dim_d * n_perms_mem * self.n_train  # x_diffs\n        per_sample_mem += (\n            4 * n_perms_mem * self.n_train\n        )  # x_dists, exp_xs, dot_x_diff_Jx_alphas, exp_xs_1_x_dists\n        per_sample_mem *= 8\n        per_sample_mem = int(\n            2 * per_sample_mem\n        )  # HACK!!! Assume double that is needed. 
Seems to work better, maybe because of fragmentation issues?\n\n        # <class 'torch.Tensor'> torch.Size([21, 118, 3]) # Fs\n        # <class 'torch.Tensor'> torch.Size([21]) # Es\n        # <class 'torch.Tensor'> torch.Size([21, 118, 3]) # Rs (batch)\n        # <class 'torch.Tensor'> torch.Size([21, 118, 118, 3]) # diffs\n        # <class 'torch.Tensor'> torch.Size([21, 6903]) # xs\n        # <class 'torch.Tensor'> torch.Size([21, 5760, 6903])\n        # <class 'torch.Tensor'> torch.Size([21, 5760]) # x_dists\n        # <class 'torch.Tensor'> torch.Size([21, 5760]) # exp_xs\n        # <class 'torch.Tensor'> torch.Size([21, 5760]) # dot_x_diff_Jx_alphas\n        # <class 'torch.Tensor'> torch.Size([21, 5760]) # exp_xs_1_x_dists\n        # <class 'torch.Tensor'> torch.Size([96, 6903]) # perm_idxs\n        # <class 'torch.nn.parameter.Parameter'> torch.Size([5760, 6903]) # _xs_train\n        # <class 'torch.nn.parameter.Parameter'> torch.Size([5760, 6903]) # _Jx_alphas\n        # <class 'torch.Tensor'> torch.Size([60, 118, 3]) # Rs (all)\n\n        return const_mem, per_sample_mem\n\n    def set_R_d_desc(self, R_d_desc):\n        \"\"\"\n        Set reference to training descriptor Jacobians. 
They are needed when the\n        alpha coefficients are updated during iterative model training.\n\n        This routine will try to move them to the GPU memory, if enough is available.\n\n        Parameters\n        ----------\n        R_d_desc : :obj:`numpy.ndarray`\n            Array containing the Jacobian of the descriptor for\n            each training point.\n        \"\"\"\n\n        self.R_d_desc = torch.from_numpy(R_d_desc).type(_dtype)\n\n        # Try moving to GPU memory.\n        if _torch_cuda_is_available or _torch_mps_is_available:\n            try:\n                R_d_desc = self.R_d_desc.to(self._xs_train.device)\n            except RuntimeError as e:\n                if 'out of memory' in str(e):\n\n                    if _torch_cuda_is_available:\n                        torch.cuda.empty_cache()\n\n                    self._log.debug('Failed to cache \\'R_d_desc\\' on GPU.')\n                else:\n                    raise e\n            else:\n                self.R_d_desc = R_d_desc\n\n    def set_alphas(self, alphas, alphas_E=None):\n        \"\"\"\n        Reconfigure the current model with a new set of regression parameters.\n\n        This routine is used during iterative model training.\n\n        Parameters\n        ----------\n                alphas : :obj:`numpy.ndarray`\n                    1D array containing the new model parameters.\n                alphas_E : :obj:`numpy.ndarray`, optional\n                    1D array containing the additional new model parameters, if\n                    energy constraints are used in the kernel (`use_E_cstr=True`)\n        \"\"\"\n\n        global _n_perm_batches\n\n        if self.R_d_desc is None:\n            self._log.critical(\n                'The function \\'set_alphas()\\' requires \\'R_d_desc\\' to be set beforehand!'\n            )\n            print()\n            os._exit(1)\n\n        if alphas_E is not None:\n            self._alphas_E = nn.Parameter(\n                
torch.from_numpy(alphas_E).to(self._xs_train.device).type(_dtype),\n                requires_grad=False,\n            )\n\n        del self._Jx_alphas\n        while True:\n            try:\n\n                alphas_torch = (\n                    torch.from_numpy(alphas).type(_dtype).to(self.R_d_desc.device)\n                )  # Send to whatever device 'R_d_desc' is on, first.\n                xs = self.desc.d_desc_dot_vec(\n                    self.R_d_desc, alphas_torch.reshape(-1, self.dim_i)\n                )\n                del alphas_torch\n\n                if (_torch_cuda_is_available and not xs.is_cuda) or (\n                    _torch_mps_is_available and not xs.is_mps\n                ):\n                    xs = xs.to(\n                        self._xs_train.device\n                    )  # Only now send it to the GPU ('_xs_train' will be for sure, if GPUs are available)\n\n            except RuntimeError as e:\n                if 'out of memory' in str(e):\n\n                    if _torch_cuda_is_available or _torch_mps_is_available:\n\n                        if _torch_cuda_is_available:\n                            torch.cuda.empty_cache()\n\n                        self.R_d_desc = self.R_d_desc.cpu()\n\n                        self._log.debug(\n                            'Failed to \\'set_alphas()\\': \\'R_d_desc\\' was moved back from GPU to CPU'\n                        )\n\n                        pass\n\n                    else:\n\n                        self._log.critical(\n                            'Not enough memory to cache \\'R_d_desc\\'! 
There is nothing we can do...'\n                        )\n                        print()\n                        os._exit(1)\n\n                else:\n                    raise e\n            else:\n                break\n\n        try:\n\n            perm_idxs = self.perm_idxs if _n_perm_batches == 1 else None\n            self._Jx_alphas = nn.Parameter(\n                self.apply_perms_to_obj(xs, perm_idxs=perm_idxs), requires_grad=False\n            )\n\n        except RuntimeError as e:\n            if 'out of memory' in str(e):\n                if torch.cuda.is_available():\n                    torch.cuda.empty_cache()\n\n                if _n_perm_batches < self.n_perms:\n\n                    # Report the batch count that is about to be set (see increment below).\n                    self._log.debug(\n                        'Setting permutation batch size to {}/{}{}.'.format(\n                            self.n_perms // (_n_perm_batches + 1),\n                            self.n_perms,\n                            ' (no caching)' if _n_perm_batches + 1 > 1 else '',\n                        )\n                    )\n\n                    _n_perm_batches += 1  # Do NOT change me to use 'self.set_n_perm_batches(_n_perm_batches + 1)'!\n                    self._xs_train = nn.Parameter(\n                        self.remove_perms_from_obj(self._xs_train), requires_grad=False\n                    )  # Remove any permutations from 'self._xs_train'.\n                    self._Jx_alphas = nn.Parameter(\n                        self.apply_perms_to_obj(xs, perm_idxs=None), requires_grad=False\n                    )  # Set 'self._Jx_alphas' without applying permutations.\n\n                else:\n                    self._log.critical(\n                        'Could not allocate enough memory to set new alphas in model.'\n                    )\n                    print()\n                    os._exit(1)\n            else:\n                raise e\n\n    def _forward(self, Rs_or_train_idxs, return_E=True):\n\n        global _n_perm_batches\n\n        q = np.sqrt(5) / 
self._sig\n        i, j = self.tril_indices\n\n        is_train_pred = Rs_or_train_idxs.dim() == 1\n        if not is_train_pred:  # Rs\n\n            Rs = Rs_or_train_idxs.type(_dtype)\n            diffs = Rs[:, :, None, :] - Rs[:, None, :, :]  # N, a, a, 3\n            diffs = diffs[:, i, j, :]  # N, d, 3\n\n            if self._lat_and_inv is not None:\n\n                diffs_shape = diffs.shape\n                # diffs = self.desc.pbc_diff(diffs.reshape(-1, 3), self._lat_and_inv).reshape(\n                #    diffs_shape\n                # )\n\n                lat, lat_inv = self._lat_and_inv\n                if lat.device != Rs.device:\n                    lat = lat.to(Rs.device)\n                    lat_inv = lat_inv.to(Rs.device)\n\n                diffs = diffs.reshape(-1, 3)\n\n                c = lat_inv.mm(diffs.t())\n                diffs -= lat.mm(c.round()).t()\n\n                diffs = diffs.reshape(diffs_shape)\n\n            xs = 1 / diffs.norm(dim=-1)  # N, d\n\n            diffs *= xs[:, :, None] ** 3\n            Jxs = diffs\n            del diffs\n\n        else:  # xs_train\n\n            train_idxs = Rs_or_train_idxs\n            \n            # Get index of identity permutation, depending on caching configuration.\n            xs_train_n_perms = self._xs_train.numel() // (self.n_train * self.dim_d)\n            idx_id_perm = 0 if xs_train_n_perms == 1 else self.idx_id_perm\n\n            xs = self._xs_train.reshape(self.n_train, -1, self.dim_d)[\n                train_idxs, idx_id_perm, :\n            ]  # ignore permutations\n\n            train_idxs = train_idxs.to(self.R_d_desc.device) # 'train_idxs' should be on the same device with 'R_d_desc'\n\n            Jxs = self.R_d_desc[train_idxs, :, :].to(\n                xs.device\n            )  # 'R_d_desc' might be living on the CPU...\n\n        # current:\n        # diffs: N, a, a, 3\n        # xs: # N, d\n\n        Fs_x = torch.zeros(xs.shape, device=xs.device, dtype=xs.dtype)\n      
  Es = (\n            torch.zeros((xs.shape[0],), device=xs.device, dtype=xs.dtype)\n            if return_E\n            else None\n        )\n\n        n_perms_done = 0\n        for perm_batch in np.array_split(np.arange(self.n_perms), _n_perm_batches):\n\n            if _n_perm_batches == 1:\n                xs_train_perm_split = self._xs_train\n                Jx_alphas_perm_split = self._Jx_alphas\n            else:\n                perm_idxs_batch = (\n                    self.perm_idxs[perm_batch, :] - n_perms_done * self.dim_d\n                )  # index shift\n                xs_train_perm_split = self.apply_perms_to_obj(\n                    self._xs_train, perm_idxs=perm_idxs_batch\n                )\n                Jx_alphas_perm_split = self.apply_perms_to_obj(\n                    self._Jx_alphas, perm_idxs=perm_idxs_batch\n                )\n\n            n_perms_done += len(perm_batch)\n\n            x_diffs = q * (\n                xs[:, None, :] - xs_train_perm_split\n            )  # N, n_perms*N_train, d\n            x_dists = x_diffs.norm(dim=-1)  # N, n_perms*N\n\n            exp_xs = torch.exp(-x_dists) * (q**2) / 3  # N, n_perms\n            exp_xs_1_x_dists = exp_xs * (1 + x_dists)  # N, n_perms*N_train\n\n            if self._alphas_E is None:\n                del x_dists\n\n            dot_x_diff_Jx_alphas = torch.einsum(\n                'ij...,j...->ij', x_diffs, Jx_alphas_perm_split\n            )  # N, n_perms*N_train\n\n            # Fs_x = ((exp_xs * dot_x_diff_Jx_alphas)[..., None] * x_diffs).sum(dim=1)\n            Fs_x += torch.einsum(  # NOTE ! 
Fs_x = Fs_x + torch.einsum(\n                '...j,...j,...jk', exp_xs, dot_x_diff_Jx_alphas, x_diffs\n            )  # N, d\n            del exp_xs\n\n            if self._alphas_E is None:\n                del x_diffs\n\n            # current:\n            # diffs: N, a, a, 3\n            # xs: # N, d\n            # x_diffs: # N, n_perms*N_train, d\n            # x_dists: # N, n_perms*N_train\n            # exp_xs: # N, n_perms*N_train\n            # dot_x_diff_Jx_alphas: N, n_perms*N_train\n            # exp_xs_1_x_dists: N, n_perms*N_train\n            # Fs_x: N, d\n\n            Fs_x -= exp_xs_1_x_dists.mm(Jx_alphas_perm_split)  # N, d\n\n            if return_E:\n                Es += (\n                    torch.einsum('...j,...j', exp_xs_1_x_dists, dot_x_diff_Jx_alphas)\n                    / q\n                )\n\n            del dot_x_diff_Jx_alphas\n\n            if self._alphas_E is None:\n                del exp_xs_1_x_dists\n\n            # Note: Energies are automatically predicted with a flipped sign here (because -E are trained, instead of E)\n            if self._alphas_E is not None:\n\n                K_fe = (x_diffs / q) * exp_xs_1_x_dists[:, :, None]\n                del exp_xs_1_x_dists\n                del x_diffs\n\n                K_fe = K_fe.reshape(-1, self.n_train, len(perm_batch), self.dim_d)\n                Fs_x += torch.einsum('j,...jkl->...l', self._alphas_E, K_fe)\n                del K_fe\n\n                K_ee = (1 + x_dists * (1 + x_dists / 3)) * torch.exp(-x_dists)\n                del x_dists\n\n                K_ee = K_ee.reshape(-1, self.n_train, len(perm_batch))\n                Es += torch.einsum('j,...jk->...', self._alphas_E, K_ee)\n                del K_ee\n\n        # current:\n        # diffs: N, a, a, 3\n        # xs: # N, d\n        # x_dists: # N, n_perms*N\n        # dot_x_diff_Jx_alphas: N, n_perms*N\n        # exp_xs_1_x_dists: N, n_perms*N\n        # Fs_x: N, d\n\n        Fs = 
torch.einsum('ji,...ik,...i->...jk', self.agg_mat.double(), Jxs, Fs_x)\n\n        if not is_train_pred:  # TODO: set std to zero in training mode?\n            Fs *= self._std\n\n        if return_E:\n            Es *= self._std\n            Es += self._c\n\n        return Es, Fs\n\n    def forward(self, Rs_or_train_idxs, return_E=True):\n        \"\"\"\n        Predict energy and forces for a batch of geometries.\n\n        Parameters\n        ----------\n        Rs_or_train_idxs : :obj:`torch.Tensor`\n            (dims M x N x 3) Cartesian coordinates of M molecules composed of N atoms or\n            (dims N) index list of training points to evaluate. Note that `self.R_d_desc`\n            needs to be set for the latter to work.\n        return_E : boolean, optional\n            If false (default: true), only the forces are returned.\n\n        Returns\n        -------\n        E : :obj:`torch.Tensor`\n            (dims M) Molecular energies (unless `return_E == False`)\n        F : :obj:`torch.Tensor`\n            (dims M x N x 3) Nuclear gradients of the energy\n        \"\"\"\n\n        global _batch_size, _n_perm_batches\n\n        # if Rs_or_train_idxs.dim() == 1:\n        #    # contains index list. 
return predictions for these training points\n        #    dtype = self.R_d_desc.dtype\n        # elif Rs_or_train_idxs.dim() == 3:\n        # this is real data\n\n        #    assert Rs_or_train_idxs.shape[1:] == (self.n_atoms, 3)\n        #    Rs_or_train_idxs = Rs_or_train_idxs.double()\n        #    dtype = Rs_or_train_idxs.dtype\n\n        # else:\n        #    # unknown input\n        #    self._log.critical('Invalid input for \\'Rs_or_train_idxs\\'.')\n        #    print()\n        #    os._exit(1)\n\n        while True:\n            try:\n                Es, Fs = zip(\n                    *map(\n                        partial(self._forward, return_E=return_E),\n                        DataLoader(Rs_or_train_idxs, batch_size=_batch_size),\n                    )\n                )\n            except RuntimeError as e:\n                if 'out of memory' in str(e):\n                    if torch.cuda.is_available():\n                        torch.cuda.empty_cache()\n\n                    if _batch_size > 1:\n\n                        self._log.debug(\n                            'Setting batch size to {}/{} points.'.format(\n                                _batch_size, self.n_train\n                            )\n                        )\n                        _batch_size -= 1\n\n                    elif _n_perm_batches < self.n_perms:\n                        n_perm_batches = _next_batch_size(self.n_perms, _n_perm_batches)\n                        self.set_n_perm_batches(n_perm_batches)\n\n                    else:\n                        self._log.critical(\n                            'Could not allocate enough (GPU) memory to evaluate model, despite reducing batch size.'\n                        )\n                        print()\n                        os._exit(1)\n                else:\n                    raise e\n            else:\n                break\n\n        ret = (torch.cat(Fs),)\n        if return_E:\n            ret = (torch.cat(Es),) + 
ret\n\n        return ret\n"
  },
  {
    "path": "sgdml/train.py",
    "content": "\"\"\"\nThis module contains all routines for training GDML and sGDML models.\n\"\"\"\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport sys\nimport os\nimport logging\nimport psutil\n\nimport multiprocessing as mp\n\nPool = mp.get_context('fork').Pool\n\nimport timeit\nfrom functools import partial\n\nimport numpy as np\n\ntry:\n    import torch\nexcept ImportError:\n    _has_torch = False\nelse:\n    _has_torch = True\n\ntry:\n    _torch_mps_is_available = torch.backends.mps.is_available()\nexcept AttributeError:\n    _torch_mps_is_available = False\n_torch_mps_is_available = False\n\ntry:\n    _torch_cuda_is_available = torch.cuda.is_available()\nexcept AttributeError:\n    _torch_cuda_is_available = False\n\nfrom . 
import __version__, DONE, NOT_DONE\nfrom .solvers.analytic import Analytic\n\n# TODO: remove exception handling once iterative solver ships\ntry:\n    from .solvers.iterative import Iterative\nexcept ImportError:\n    pass\n\nfrom .predict import GDMLPredict\nfrom .utils.desc import Desc\nfrom .utils import io, perm, ui\n\n\ndef _share_array(arr_np, typecode_or_type):\n    \"\"\"\n    Return a ctypes array allocated from shared memory with data from a\n    NumPy array.\n\n    Parameters\n    ----------\n        arr_np : :obj:`numpy.ndarray`\n            NumPy array.\n        typecode_or_type : char or :obj:`ctype`\n            Either a ctypes type or a one character typecode of the\n            kind used by the Python array module.\n\n    Returns\n    -------\n        array of :obj:`ctype`\n    \"\"\"\n\n    arr = mp.RawArray(typecode_or_type, arr_np.ravel())\n    return arr, arr_np.shape\n\n\ndef _assemble_kernel_mat_wkr(\n    j, tril_perms_lin, sig, use_E_cstr=False, exploit_sym=False, cols_m_limit=None\n):\n    r\"\"\"\n    Compute one row and column of the force field kernel matrix.\n\n    The Hessian of the Matern kernel is used with n = 2 (twice\n    differentiable). Each row and column consists of matrix-valued\n    blocks, which encode the interaction of one training point with all\n    others. 
The result is stored in shared memory (a global variable).\n\n    Parameters\n    ----------\n        j : int\n            Index of training point.\n        tril_perms_lin : :obj:`numpy.ndarray`\n            1D array (int) containing all recovered permutations\n            expanded as one large permutation to be applied to a tiled\n            copy of the object to be permuted.\n        sig : int\n            Hyper-parameter :math:`\\sigma`.\n        use_E_cstr : bool, optional\n            True: include energy constraints in the kernel,\n            False: default (s)GDML kernel.\n        exploit_sym : boolean, optional\n            Do not create symmetric entries of the kernel matrix twice\n            (this only works for specific inputs for `cols_m_limit`)\n        cols_m_limit : int, optional\n            Limit the number of columns (include training points 1-`M`).\n            Note that each training point consists of multiple columns.\n\n    Returns\n    -------\n        int\n            Number of kernel matrix blocks created, divided by 2\n            (symmetric blocks are always created together).\n    \"\"\"\n\n    global glob\n\n    R_desc = np.frombuffer(glob['R_desc']).reshape(glob['R_desc_shape'])\n    R_d_desc = np.frombuffer(glob['R_d_desc']).reshape(glob['R_d_desc_shape'])\n    K = np.frombuffer(glob['K']).reshape(glob['K_shape'])\n\n    desc_func = glob['desc_func']\n\n    n_train, dim_d = R_d_desc.shape[:2]\n    n_atoms = int((1 + np.sqrt(8 * dim_d + 1)) / 2)\n    dim_i = 3 * n_atoms\n    n_perms = int(len(tril_perms_lin) / dim_d)\n\n    if type(j) is tuple:  # Selective/\"fancy\" indexing\n        (\n            K_j,\n            j,\n            keep_idxs_3n,\n        ) = j  # (block index in final K, block index global, indices of partials within block)\n        blk_j = slice(K_j, K_j + len(keep_idxs_3n))\n\n    else:  # Sequential indexing\n        K_j = j * dim_i if j < n_train else n_train * dim_i + (j % n_train)\n        blk_j = 
slice(K_j, K_j + dim_i) if j < n_train else slice(K_j, K_j + 1)\n        keep_idxs_3n = slice(None)  # same as [:]\n\n    # Note: The modulo-operator wraps around the index pointer on the training points when\n    # energy constraints are used in the kernel. In that case each point is accessed twice.\n\n    # Create permuted variants of 'rj_desc' and 'rj_d_desc'.\n    rj_desc_perms = np.reshape(\n        np.tile(R_desc[j % n_train, :], n_perms)[tril_perms_lin],\n        (n_perms, -1),\n        order='F',\n    )\n\n    rj_d_desc = desc_func.d_desc_from_comp(R_d_desc[j % n_train, :, :])[0][\n        :, keep_idxs_3n\n    ]  # convert descriptor back to full representation\n\n    rj_d_desc_perms = np.reshape(\n        np.tile(rj_d_desc.T, n_perms)[:, tril_perms_lin], (-1, dim_d, n_perms)\n    )\n\n    mat52_base_div = 3 * sig**4\n    sqrt5 = np.sqrt(5.0)\n    sig_pow2 = sig**2\n\n    dim_i_keep = rj_d_desc.shape[1]\n    diff_ab_outer_perms = np.empty((dim_d, dim_i_keep))\n    diff_ab_perms = np.empty((n_perms, dim_d))\n    ri_d_desc = np.zeros((1, dim_d, dim_i))  # must be zeros!\n    k = np.empty((dim_i, dim_i_keep))\n\n    if (\n        j < n_train\n    ):  # This column only contains second and first derivative constraints.\n\n        # for i in range(j if exploit_sym else 0, n_train):\n        for i in range(0, n_train):\n\n            blk_i = slice(i * dim_i, (i + 1) * dim_i)\n\n            # diff_ab_perms = R_desc[i, :] - rj_desc_perms\n            np.subtract(R_desc[i, :], rj_desc_perms, out=diff_ab_perms)\n\n            norm_ab_perms = sqrt5 * np.linalg.norm(diff_ab_perms, axis=1)\n            mat52_base_perms = np.exp(-norm_ab_perms / sig) / mat52_base_div * 5\n\n            # diff_ab_outer_perms = 5 * np.einsum(\n            #    'ki,kj->ij',\n            #    diff_ab_perms * mat52_base_perms[:, None],\n            #    np.einsum('ik,jki -> ij', diff_ab_perms, rj_d_desc_perms)\n            # )\n            np.einsum(\n                'ki,kj->ij',\n         
       diff_ab_perms * mat52_base_perms[:, None] * 5,\n                np.einsum('ki,jik -> kj', diff_ab_perms, rj_d_desc_perms),\n                out=diff_ab_outer_perms,\n            )\n\n            diff_ab_outer_perms -= np.einsum(\n                'ikj,j->ki',\n                rj_d_desc_perms,\n                (sig_pow2 + sig * norm_ab_perms) * mat52_base_perms,\n            )\n\n            # ri_d_desc = desc_func.d_desc_from_comp(R_d_desc[i, :, :])[0]\n            desc_func.d_desc_from_comp(R_d_desc[i, :, :], out=ri_d_desc)\n\n            # K[blk_i, blk_j] = ri_d_desc[0].T.dot(diff_ab_outer_perms)\n            np.dot(ri_d_desc[0].T, diff_ab_outer_perms, out=k)\n            K[blk_i, blk_j] = k\n\n            if exploit_sym and (\n                cols_m_limit is None or i < cols_m_limit\n            ):  # this will never be called with 'keep_idxs_3n' set to anything else than [:]\n                K[blk_j, blk_i] = K[blk_i, blk_j].T\n\n            # First derivative constraints\n            if use_E_cstr:\n\n                K_fe = (\n                    5\n                    * diff_ab_perms\n                    / (3 * sig**3)\n                    * (norm_ab_perms[:, None] + sig)\n                    * np.exp(-norm_ab_perms / sig)[:, None]\n                )\n\n                K_fe = -np.einsum('ik,jki -> j', K_fe, rj_d_desc_perms)\n\n                E_off_i = n_train * dim_i  # , K.shape[1] - n_train\n                K[E_off_i + i, blk_j] = K_fe\n\n    else:\n\n        if use_E_cstr:\n\n            # rj_d_desc = desc_func.d_desc_from_comp(R_d_desc[j % n_train, :, :])[0][\n            #    :, :\n            # ]  # convert descriptor back to full representation\n\n            # rj_d_desc_perms = np.reshape(\n            #    np.tile(rj_d_desc.T, n_perms)[:, tril_perms_lin], (-1, dim_d, n_perms)\n            # )\n\n            E_off_i = n_train * dim_i  # Account for 'alloc_extra_rows'!.\n            # blk_j_full = slice((j % n_train) * dim_i, ((j % n_train) + 1) 
* dim_i)\n            # for i in range((j % n_train) if exploit_sym else 0, n_train):\n            for i in range(0, n_train):\n\n                ri_desc_perms = np.reshape(\n                    np.tile(R_desc[i, :], n_perms)[tril_perms_lin],\n                    (n_perms, -1),\n                    order='F',\n                )\n\n                ri_d_desc = desc_func.d_desc_from_comp(R_d_desc[i, :, :])[\n                    0\n                ]  # convert descriptor back to full representation\n                ri_d_desc_perms = np.reshape(\n                    np.tile(ri_d_desc.T, n_perms)[:, tril_perms_lin],\n                    (-1, dim_d, n_perms),\n                )\n\n                diff_ab_perms = R_desc[j % n_train, :] - ri_desc_perms\n\n                norm_ab_perms = sqrt5 * np.linalg.norm(diff_ab_perms, axis=1)\n\n                K_fe = (\n                    5\n                    * diff_ab_perms\n                    / (3 * sig**3)\n                    * (norm_ab_perms[:, None] + sig)\n                    * np.exp(-norm_ab_perms / sig)[:, None]\n                )\n\n                K_fe = -np.einsum('ik,jki -> j', K_fe, ri_d_desc_perms)\n\n                blk_i_full = slice(i * dim_i, (i + 1) * dim_i)\n                K[blk_i_full, K_j] = K_fe  # vertical\n\n                K[E_off_i + i, K_j] = -(\n                    1 + (norm_ab_perms / sig) * (1 + norm_ab_perms / (3 * sig))\n                ).dot(np.exp(-norm_ab_perms / sig))\n\n    return blk_j.stop - blk_j.start\n\n\nclass GDMLTrain(object):\n    def __init__(self, max_memory=None, max_processes=None, use_torch=False):\n        \"\"\"\n        Train sGDML force fields.\n\n        This class is used to train models using different closed-form\n        and numerical solvers. 
GPU support is provided\n        through PyTorch (requires optional `torch` dependency to be\n        installed) for some solvers.\n\n        Parameters\n        ----------\n                max_memory : int, optional\n                        Limit the max. memory usage [GB]. This is only a\n                        soft limit that cannot always be enforced.\n                max_processes : int, optional\n                        Limit the max. number of processes. Otherwise\n                        all CPU cores are used. This parameter has no\n                        effect if `use_torch=True`.\n                use_torch : boolean, optional\n                        Use PyTorch to calculate predictions (if\n                        supported by solver)\n\n        Raises\n        ------\n            Exception\n                If multiple instances of this class are created.\n            ImportError\n                If the optional PyTorch dependency is missing, but PyTorch features are used.\n        \"\"\"\n\n        global glob\n        if 'glob' not in globals():  # Don't allow more than one instance of this class.\n            glob = {}\n        else:\n            raise Exception(\n                'You can not create multiple instances of this class. Please reuse your first one.'\n            )\n\n        self.log = logging.getLogger(__name__)\n\n        total_memory = psutil.virtual_memory().total // 2**30  # bytes to GB\n        self._max_memory = (\n            min(max_memory, total_memory) if max_memory is not None else total_memory\n        )\n\n        total_cpus = mp.cpu_count()\n        self._max_processes = (\n            min(max_processes, total_cpus) if max_processes is not None else total_cpus\n        )\n\n        self._use_torch = use_torch\n\n        if use_torch and not _has_torch:\n            raise ImportError(\n                'Optional PyTorch dependency not found! 
Please run \\'pip install sgdml[torch]\\' to install it or disable the PyTorch option.'\n            )\n\n    def __del__(self):\n\n        global glob\n\n        if 'glob' in globals():\n            del glob\n\n    def create_task(\n        self,\n        train_dataset,\n        n_train,\n        valid_dataset,\n        n_valid,\n        sig,\n        lam=1e-10,\n        perms=None,\n        use_sym=True,\n        use_E=True,\n        use_E_cstr=False,\n        callback=None,  # TODO: document me\n    ):\n        \"\"\"\n        Create a data structure of custom type `task`.\n\n        These data structures serve as recipes for model creation,\n        summarizing the configuration of one particular training run.\n        Training and validation points are sampled from the provided\n        datasets, without replacement. If the same dataset is given for\n        training and validation, the subsets are drawn without overlap.\n\n        Each task also contains a choice for the hyper-parameters of the\n        training process and the MD5 fingerprints of the used datasets.\n\n        Parameters\n        ----------\n            train_dataset : :obj:`dict`\n                Data structure of custom type :obj:`dataset` containing\n                train dataset.\n            n_train : int\n                Number of training points to sample.\n            valid_dataset : :obj:`dict`\n                Data structure of custom type :obj:`dataset` containing\n                validation dataset.\n            n_valid : int\n                Number of validation points to sample.\n            sig : int\n                Hyper-parameter (kernel length scale).\n            lam : float, optional\n                Hyper-parameter lambda (regularization strength).\n            perms : :obj:`numpy.ndarray`, optional\n                A 2D array of size P x N containing P possible permutations\n                of the N atoms in the system. 
This argument takes priority over the\n                permutations provided in the training dataset. No automatic discovery\n                is run when this argument is provided.\n            use_sym : bool, optional\n                True: include symmetries (sGDML), False: GDML.\n            use_E : bool, optional\n                True: reconstruct force field with corresponding potential energy surface,\n                False: ignore energy during training, even if energy labels are available\n                       in the dataset. The trained model will still be able to predict\n                       energies up to an unknown integration constant. Note that the\n                       accuracy of the energy predictions will be untested.\n            use_E_cstr : bool, optional\n                True: include energy constraints in the kernel,\n                False: default (s)GDML.\n            callback : callable, optional\n                Progress callback function that takes three\n                arguments:\n                    current : int\n                        Current progress.\n                    total : int\n                        Task size.\n                    done_str : :obj:`str`, optional\n                        Once complete, this string is shown.\n\n        Returns\n        -------\n            dict\n                Data structure of custom type :obj:`task`.\n\n        Raises\n        ------\n            ValueError\n                If a reconstruction of the potential energy surface is requested,\n                but the energy labels are missing in the dataset.\n        \"\"\"\n\n        if use_E and 'E' not in train_dataset:\n            raise ValueError(\n                'No energy labels found in dataset!\\n'\n                + 'By default, force fields are always reconstructed including the\\n'\n                + 'corresponding potential energy surface (this can be turned off).\\n'\n                + 'However, the energy labels are missing in the 
provided dataset.\\n'\n            )\n\n        use_E_cstr = use_E and use_E_cstr\n\n        n_atoms = train_dataset['R'].shape[1]\n\n        if callback is not None:\n            callback = partial(callback, disp_str='Hashing dataset(s)')\n            callback(NOT_DONE)\n\n        md5_train = io.dataset_md5(train_dataset)\n        md5_valid = io.dataset_md5(valid_dataset)\n\n        if callback is not None:\n            callback(DONE)\n\n        if callback is not None:\n            callback = partial(\n                callback, disp_str='Sampling training and validation subsets'\n            )\n            callback(NOT_DONE)\n\n        if 'E' in train_dataset:\n            idxs_train = self.draw_strat_sample(train_dataset['E'], n_train)\n        else:\n            idxs_train = np.random.choice(\n                np.arange(train_dataset['F'].shape[0]),\n                n_train,\n                replace=False,\n            )\n\n        excl_idxs = (\n            idxs_train if md5_train == md5_valid else np.array([], dtype=np.uint)\n        )\n\n        if 'E' in valid_dataset:\n            idxs_valid = self.draw_strat_sample(\n                valid_dataset['E'],\n                n_valid,\n                excl_idxs=excl_idxs,\n            )\n        else:\n            idxs_valid_cands = np.setdiff1d(\n                np.arange(valid_dataset['F'].shape[0]), excl_idxs, assume_unique=True\n            )\n            idxs_valid = np.random.choice(idxs_valid_cands, n_valid, replace=False)\n\n        if callback is not None:\n            callback(DONE)\n\n        R_train = train_dataset['R'][idxs_train, :, :]\n        task = {\n            'type': 't',\n            'code_version': __version__,\n            'dataset_name': train_dataset['name'].astype(str),\n            'dataset_theory': train_dataset['theory'].astype(str),\n            'z': train_dataset['z'],\n            'R_train': R_train,\n            'F_train': train_dataset['F'][idxs_train, :, :],\n            
'idxs_train': idxs_train,\n            'md5_train': md5_train,\n            'idxs_valid': idxs_valid,\n            'md5_valid': md5_valid,\n            'sig': sig,\n            'lam': lam,\n            'use_E': use_E,\n            'use_E_cstr': use_E_cstr,\n            'use_sym': use_sym,\n        }\n\n        if use_E:\n            task['E_train'] = train_dataset['E'][idxs_train]\n\n        lat_and_inv = None\n        if 'lattice' in train_dataset:\n            task['lattice'] = train_dataset['lattice']\n\n            try:\n                lat_and_inv = (task['lattice'], np.linalg.inv(task['lattice']))\n            except np.linalg.LinAlgError:\n                raise ValueError(  # TODO: Document me\n                    'Provided dataset contains invalid lattice vectors (not invertible). Note: Only rank 3 lattice vector matrices are supported.'\n                )\n\n        if 'r_unit' in train_dataset and 'e_unit' in train_dataset:\n            task['r_unit'] = train_dataset['r_unit']\n            task['e_unit'] = train_dataset['e_unit']\n\n        if use_sym:\n\n            # No permutations provided externally.\n            if perms is None:\n\n                if (\n                    'perms' in train_dataset\n                ):  # take perms from training dataset, if available\n\n                    n_perms = train_dataset['perms'].shape[0]\n                    self.log.info(\n                        'Using {:d} permutations included in dataset.'.format(n_perms)\n                    )\n\n                    task['perms'] = train_dataset['perms']\n\n                else:  # find perms from scratch\n\n                    n_train = R_train.shape[0]\n                    R_train_sync_mat = R_train\n                    if n_train > 1000:\n                        R_train_sync_mat = R_train[\n                            np.random.choice(n_train, 1000, replace=False), :, :\n                        ]\n                        self.log.info(\n                            
'Symmetry search has been restricted to a random subset of 1000/{:d} training points for faster convergence.'.format(\n                                n_train\n                            )\n                        )\n\n                    # TOOD: PBCs disabled when matching (for now).\n                    # task['perms'] = perm.find_perms(\n                    #    R_train_sync_mat, train_dataset['z'], lat_and_inv=lat_and_inv, max_processes=self._max_processes,\n                    # )\n                    task['perms'] = perm.find_perms(\n                        R_train_sync_mat,\n                        train_dataset['z'],\n                        # lat_and_inv=None,\n                        lat_and_inv=lat_and_inv,\n                        callback=callback,\n                        max_processes=self._max_processes,\n                    )\n\n                    # NEW\n\n                    USE_EXTRA_PERMS = False\n\n                    if USE_EXTRA_PERMS:\n                        task['perms'] = perm.find_extra_perms(\n                            R_train_sync_mat,\n                            train_dataset['z'],\n                            # lat_and_inv=None,\n                            lat_and_inv=lat_and_inv,\n                            callback=callback,\n                            max_processes=self._max_processes,\n                        )\n\n                    # NEW\n\n                    # NEW\n\n                    USE_FRAG_PERMS = False\n\n                    if USE_FRAG_PERMS:\n                        frag_perms = perm.find_frag_perms(\n                            R_train_sync_mat,\n                            train_dataset['z'],\n                            lat_and_inv=lat_and_inv,\n                            max_processes=self._max_processes,\n                        )\n                        task['perms'] = np.vstack((task['perms'], frag_perms))\n                        task['perms'] = np.unique(task['perms'], axis=0)\n\n                   
     print(\n                            '| Keeping '\n                            + str(task['perms'].shape[0])\n                            + ' unique permutations.'\n                        )\n\n                    # NEW\n\n            else:  # use provided perms\n\n                n_atoms = len(task['z'])\n                n_perms, perms_len = perms.shape\n\n                if perms_len != n_atoms:\n                    raise ValueError(  # TODO: Document me\n                        'Provided permutations do not match the number of atoms in dataset.'\n                    )\n                else:\n\n                    self.log.info(\n                        'Using {:d} externally provided permutations.'.format(n_perms)\n                    )\n\n                    task['perms'] = perms\n\n        else:\n            task['perms'] = np.arange(train_dataset['R'].shape[1])[\n                None, :\n            ]  # no symmetries\n\n        return task\n\n    def create_task_from_model(self, model, dataset):\n        \"\"\"\n        Create a data structure of custom type `task` from an\n        existing structure of custom type `model`. This method is used to\n        resume training of unconverged models.\n\n        Any hyperparameter (including all symmetry permutations) in the\n        provided model file is reused without further optimization. 
The\n        current linear coefficients are used as the starting point for the\n        iterative training procedure.\n\n        Parameters\n        ----------\n            model : :obj:`dict`\n                Data structure of custom type :obj:`model` based on which\n                to create the training task.\n            dataset : :obj:`dict`\n                Data structure of custom type :obj:`dataset` containing\n                the original dataset from which the provided model emerged.\n\n        Returns\n        -------\n            dict\n                Data structure of custom type :obj:`task`.\n        \"\"\"\n\n        idxs_train = model['idxs_train']\n        R_train = dataset['R'][idxs_train, :, :]\n        F_train = dataset['F'][idxs_train, :, :]\n\n        use_E = 'e_err' in model\n        use_E_cstr = 'alphas_E' in model\n        use_sym = model['perms'].shape[0] > 1\n\n        task = {\n            'type': 't',\n            'code_version': __version__,\n            'dataset_name': model['dataset_name'],\n            'dataset_theory': model['dataset_theory'],\n            'z': model['z'],\n            'R_train': R_train,\n            'F_train': F_train,\n            'idxs_train': idxs_train,\n            'md5_train': model['md5_train'],\n            'idxs_valid': model['idxs_valid'],\n            'md5_valid': model['md5_valid'],\n            'sig': model['sig'],\n            'lam': model['lam'],\n            'use_E': model['use_E'],\n            'use_E_cstr': use_E_cstr,\n            'use_sym': use_sym,\n            'perms': model['perms'],\n        }\n\n        if use_E:\n            task['E_train'] = dataset['E'][idxs_train]\n\n        if 'lattice' in model:\n            task['lattice'] = model['lattice']\n\n        if 'r_unit' in model and 'e_unit' in model:\n            task['r_unit'] = model['r_unit']\n            task['e_unit'] = model['e_unit']\n\n        if 'alphas_F' in model:\n            task['alphas0_F'] = model['alphas_F']\n\n        if 
'alphas_E' in model:\n            task['alphas0_E'] = model['alphas_E']\n\n        if 'solver_iters' in model:\n            task['solver_iters'] = model['solver_iters']\n\n        if 'inducing_pts_idxs' in model:\n            task['inducing_pts_idxs'] = model['inducing_pts_idxs']\n\n        return task\n\n    def create_model(\n        self,\n        task,\n        solver,\n        R_desc,\n        R_d_desc,\n        tril_perms_lin,\n        std,\n        alphas_F,\n        alphas_E=None,\n    ):\n        \"\"\"\n        Create a data structure of custom type `model`.\n\n        These data structures contain the trained model and everything\n        that is needed to generate predictions for new inputs.\n\n        Each model also contains the MD5 fingerprints of the used datasets.\n\n        Parameters\n        ----------\n            task : :obj:`dict`\n                Data structure of custom type :obj:`task` from which\n                the model emerged.\n            solver : :obj:`str`\n                Identifier string for the solver that has been used to\n                train this model.\n            R_desc : :obj:`numpy.ndarray`\n                    A 2D array of size M x D containing the\n                    descriptors of dimension D for M\n                    molecules.\n            R_d_desc : :obj:`numpy.ndarray`\n                    A 3D array of size M x D x 3N containing the\n                    descriptor Jacobians for M molecules. 
The descriptor\n                    has dimension D with 3N partial derivatives with\n                    respect to the 3N Cartesian coordinates of the atoms.\n            tril_perms_lin : :obj:`numpy.ndarray`\n                1D array containing all recovered permutations\n                expanded as one large permutation to be applied to a\n                tiled copy of the object to be permuted.\n            std : float\n                Standard deviation of the training labels.\n            alphas_F : :obj:`numpy.ndarray`\n                    A 1D array of size 3NM containing the linear\n                    coefficients that correspond to the force constraints.\n            alphas_E : :obj:`numpy.ndarray`, optional\n                    A 1D array of size M containing the linear\n                    coefficients that correspond to the energy constraints.\n\n        Returns\n        -------\n            dict\n                Data structure of custom type :obj:`model`.\n        \"\"\"\n\n        n_train, dim_d = R_d_desc.shape[:2]\n        n_atoms = int((1 + np.sqrt(8 * dim_d + 1)) / 2)\n\n        desc = Desc(\n            n_atoms,\n            max_processes=self._max_processes,\n        )\n\n        dim_i = desc.dim_i\n        R_d_desc_alpha = desc.d_desc_dot_vec(R_d_desc, alphas_F.reshape(-1, dim_i))\n\n        model = {\n            'type': 'm',\n            'code_version': __version__,\n            'dataset_name': task['dataset_name'],\n            'dataset_theory': task['dataset_theory'],\n            'solver_name': solver,\n            'z': task['z'],\n            'idxs_train': task['idxs_train'],\n            'md5_train': task['md5_train'],\n            'idxs_valid': task['idxs_valid'],\n            'md5_valid': task['md5_valid'],\n            'n_test': 0,\n            'md5_test': None,\n            'f_err': {'mae': np.nan, 'rmse': np.nan},\n            'R_desc': R_desc.T,\n            'R_d_desc_alpha': R_d_desc_alpha,\n            'c': 0.0,\n        
    'std': std,\n            'sig': task['sig'],\n            'lam': task['lam'],\n            'alphas_F': alphas_F,\n            'perms': task['perms'],\n            'tril_perms_lin': tril_perms_lin,\n            'use_E': task['use_E'],\n        }\n\n        if task['use_E']:\n            model['e_err'] = {'mae': np.nan, 'rmse': np.nan}\n\n            if task['use_E_cstr']:\n                model['alphas_E'] = alphas_E\n\n        if 'lattice' in task:\n            model['lattice'] = task['lattice']\n\n        if 'r_unit' in task and 'e_unit' in task:\n            model['r_unit'] = task['r_unit']\n            model['e_unit'] = task['e_unit']\n\n        return model\n\n    # from memory_profiler import profile\n    # @profile\n    def train(  # noqa: C901\n        self,\n        task,\n        save_progr_callback=None,  # TODO: document me\n        callback=None,\n    ):\n        \"\"\"\n        Train a model based on a training task.\n\n        Parameters\n        ----------\n            task : :obj:`dict`\n                Data structure of custom type :obj:`task`.\n            desc_callback : callable, optional\n                Descriptor and descriptor Jacobian generation status.\n                    current : int\n                        Current progress (number of completed descriptors).\n                    total : int\n                        Task size (total number of descriptors to create).\n                    done_str : :obj:`str`, optional\n                        Once complete, this string contains the\n                        time it took to complete this task (seconds).\n            ker_progr_callback : callable, optional\n                Kernel assembly progress function that takes three\n                arguments:\n                    current : int\n                        Current progress (number of completed entries).\n                    total : int\n                        Task size (total number of entries to create).\n                    done_str 
: :obj:`str`, optional\n                        Once complete, this string contains the\n                        time it took to assemble the kernel (seconds).\n            solve_callback : callable, optional\n                Linear system solver status.\n                    done : bool\n                        False when solver starts, True when it finishes.\n                    done_str : :obj:`str`, optional\n                        Once done, this string contains the runtime\n                        of the solver (seconds).\n\n        Returns\n        -------\n            :obj:`dict`\n                Data structure of custom type :obj:`model`.\n\n        Raises\n        ------\n            ValueError\n                If the provided dataset contains invalid lattice\n                vectors.\n        \"\"\"\n\n        task = dict(task)  # make mutable\n\n        n_train, n_atoms = task['R_train'].shape[:2]\n\n        desc = Desc(\n            n_atoms,\n            max_processes=self._max_processes,\n        )\n\n        n_perms = task['perms'].shape[0]\n        tril_perms = np.array([Desc.perm(p) for p in task['perms']])\n\n        dim_i = 3 * n_atoms\n        dim_d = desc.dim\n\n        perm_offsets = np.arange(n_perms)[:, None] * dim_d\n        tril_perms_lin = (tril_perms + perm_offsets).flatten('F')\n\n        # TODO: check if all atoms are in span of lattice vectors, otherwise suggest that\n        # rows and columns might have been switched.\n        lat_and_inv = None\n        if 'lattice' in task:\n            try:\n                lat_and_inv = (task['lattice'], np.linalg.inv(task['lattice']))\n            except np.linalg.LinAlgError:\n                raise ValueError(  # TODO: Document me\n                    'Provided dataset contains invalid lattice vectors (not invertible). 
Note: Only rank 3 lattice vector matrices are supported.'\n                )\n\n            # # TODO: check if all atoms are within unit cell\n            # for r in task['R_train']:\n            #    r_lat = lat_and_inv[1].dot(r.T)\n            #    if not (r_lat >= 0).all():\n            #         raise ValueError( # TODO: Document me\n            #            'Some atoms appear outside of the unit cell! Please check lattice vectors in dataset file.'\n            #         )\n            #        #pass\n\n        R = task['R_train'].reshape(n_train, -1)\n        R_desc, R_d_desc = desc.from_R(\n            R,\n            lat_and_inv=lat_and_inv,\n            callback=partial(\n                callback, disp_str='Generating descriptors and their Jacobians'\n            )\n            if callback is not None\n            else None,\n        )\n\n        # Generate label vector.\n        E_train_mean = None\n        y = task['F_train'].ravel().copy()\n        if task['use_E'] and task['use_E_cstr']:\n            E_train = task['E_train'].ravel().copy()\n            E_train_mean = np.mean(E_train)\n\n            y = np.hstack((y, -E_train + E_train_mean))\n\n        y_std = np.std(y)\n        y /= y_std\n\n        max_memory_bytes = self._max_memory * 1024**3\n\n        # Memory cost of analytic solver\n        est_bytes_analytic = Analytic.est_memory_requirement(n_train, n_atoms)\n\n        # Memory overhead (solver independent)\n        est_bytes_overhead = y.nbytes\n        est_bytes_overhead += R.nbytes\n        est_bytes_overhead += R_desc.nbytes\n        est_bytes_overhead += R_d_desc.nbytes\n\n        solver_keys = {}\n\n        use_analytic_solver = (\n            est_bytes_analytic + est_bytes_overhead\n        ) < max_memory_bytes\n\n        # Fall back to analytic solver, if iterative solver file is missing.\n        base_path = os.path.dirname(os.path.abspath(__file__))\n        iter_solver_path = os.path.join(base_path, 'solvers/iterative.py')\n        
if not os.path.exists(iter_solver_path):\n            self.log.debug('Iterative solver not installed.')\n            use_analytic_solver = True\n\n        # use_analytic_solver = True  # remove me!\n\n        if use_analytic_solver:\n\n            self.log.info(\n                'Using analytic solver (expected memory use: ~{})'.format(\n                    ui.gen_memory_str(est_bytes_analytic + est_bytes_overhead)\n                )\n            )\n\n            analytic = Analytic(self, desc, callback=callback)\n            alphas = analytic.solve(task, R_desc, R_d_desc, tril_perms_lin, y)\n\n        else:\n\n            max_n_inducing_pts = Iterative.max_n_inducing_pts(\n                n_train, n_atoms, max_memory_bytes\n            )\n            est_bytes_iterative = Iterative.est_memory_requirement(\n                n_train, max_n_inducing_pts, n_atoms\n            )\n\n            self.log.info(\n                'Using iterative solver (expected memory use: ~{})'.format(\n                    ui.gen_memory_str(est_bytes_iterative + est_bytes_overhead)\n                )\n            )\n\n            alphas_F = task['alphas0_F'] if 'alphas0_F' in task else None\n            alphas_E = task['alphas0_E'] if 'alphas0_E' in task else None\n\n            iterative = Iterative(\n                self,\n                desc,\n                self._max_memory,\n                self._max_processes,\n                self._use_torch,\n                callback=callback,\n            )\n            (\n                alphas,\n                solver_keys['solver_tol'],\n                solver_keys[\n                    'solver_iters'\n                ],  # number of iterations performed (cg solver)\n                solver_keys['solver_resid'],  # residual of solution\n                train_rmse,\n                solver_keys['inducing_pts_idxs'],\n                is_conv,\n            ) = iterative.solve(\n                task,\n                R_desc,\n                
R_d_desc,\n                tril_perms_lin,\n                y,\n                y_std,\n                save_progr_callback=save_progr_callback,\n            )\n\n            solver_keys['norm_y_train'] = np.linalg.norm(y)\n\n            if not is_conv:\n                self.log.warning(\n                    'Iterative solver did not converge!\\n'\n                    + 'The optimization problem underlying this force field reconstruction task seems to be highly ill-conditioned.\\n\\n'\n                    + ui.color_str('Troubleshooting tips:\\n', bold=True)\n                    + ui.wrap_indent_str(\n                        '(1) ',\n                        'Are the provided geometries highly correlated (i.e. very similar to each other)?',\n                    )\n                    + '\\n'\n                    + ui.wrap_indent_str(\n                        '(2) ', 'Try a larger length scale (sigma) parameter.'\n                    )\n                    + '\\n\\n'\n                    + ui.color_str('Note:', bold=True)\n                    + ' We will continue with this unconverged model, but its accuracy will likely be very bad.'\n                )\n\n        alphas_E = None\n        alphas_F = alphas\n        if task['use_E_cstr']:\n            alphas_E = alphas[-n_train:]\n            alphas_F = alphas[:-n_train]\n\n        model = self.create_model(\n            task,\n            'analytic' if use_analytic_solver else 'cg',\n            R_desc,\n            R_d_desc,\n            tril_perms_lin,\n            y_std,\n            alphas_F,\n            alphas_E=alphas_E,\n        )\n        model.update(solver_keys)\n\n        # Recover integration constant.\n        # Note: if energy constraints are included in the kernel (via 'use_E_cstr'), do not\n        # compute the integration constant, but simply set it to the mean of the training energies\n        # (which was subtracted from the labels before training).\n        if model['use_E']:\n            c = (\n 
               self._recov_int_const(model, task, R_desc=R_desc, R_d_desc=R_d_desc)\n                if E_train_mean is None\n                else E_train_mean\n            )\n            # if c is None:\n            #    # Something does not seem right. Turn off energy predictions for this model, only output force predictions.\n            #    model['use_E'] = False\n            # else:\n            #    model['c'] = c\n\n            model['c'] = c\n\n        return model\n\n    def _recov_int_const(\n        self, model, task, R_desc=None, R_d_desc=None\n    ):  # TODO: document e_err_inconsist return\n        \"\"\"\n        Estimate the integration constant for a force field model.\n\n        The offset between the energies predicted for the original training\n        data and the true energy labels is computed in the least-squares sense.\n        Furthermore, common issues with user-provided datasets are diagnosed\n        here.\n\n        Parameters\n        ----------\n            model : :obj:`dict`\n                Data structure of custom type :obj:`model`.\n            task : :obj:`dict`\n                Data structure of custom type :obj:`task`.\n            R_desc : :obj:`numpy.ndarray`, optional\n                    A 2D array of size M x D containing the\n                    descriptors of dimension D for M\n                    molecules.\n            R_d_desc : :obj:`numpy.ndarray`, optional\n                    A 3D array of size M x D x 3N containing the\n                    descriptor Jacobians for M molecules. 
The descriptor\n                    has dimension D with 3N partial derivatives with\n                    respect to the 3N Cartesian coordinates of each atom.\n\n        Returns\n        -------\n            float\n                Estimate for the integration constant.\n\n        Raises\n        ------\n            ValueError\n                If the sign of the force labels in the dataset from\n                which the model emerged is switched (e.g. gradients\n                instead of forces).\n            ValueError\n                If inconsistent/corrupted energy labels are detected\n                in the provided dataset.\n            ValueError\n                If potentially inconsistent scales in energy vs.\n                force labels are detected in the provided dataset.\n        \"\"\"\n\n        gdml_predict = GDMLPredict(\n            model,\n            max_memory=self._max_memory,\n            max_processes=self._max_processes,\n            use_torch=self._use_torch,\n            log_level=logging.CRITICAL,\n        )\n\n        gdml_predict.set_R_desc(R_desc)\n        gdml_predict.set_R_d_desc(R_d_desc)\n\n        E_pred, _ = gdml_predict.predict()\n        E_ref = np.squeeze(task['E_train'])\n\n        e_fact = np.linalg.lstsq(\n            np.column_stack((E_pred, np.ones(E_ref.shape))), E_ref, rcond=-1\n        )[0][0]\n        corrcoef = np.corrcoef(E_ref, E_pred)[0, 1]\n\n        # import matplotlib.pyplot as plt\n        # sidx = np.argsort(E_ref)\n        # plt.plot(E_ref[sidx])\n        # c = np.sum(E_ref - E_pred) / E_ref.shape[0]\n        # plt.plot(E_pred[sidx]+c)\n        # plt.show()\n        # sys.exit()\n\n        # import matplotlib.pyplot as plt\n        # sidx = np.argsort(F_ref)\n        # plt.plot(F_ref[sidx])\n        # c = np.sum(F_ref - F_pred) / F_ref.shape[0]\n        # plt.plot(F_pred[sidx],'--')\n        # plt.show()\n        # sys.exit()\n\n        if np.sign(e_fact) == -1:\n            self.log.warning(\n           
     'It looks like the provided dataset may contain gradients instead of force labels (flipped sign).\\n\\n'\n                + ui.color_str('Troubleshooting tips:\\n', bold=True)\n                + ui.wrap_indent_str(\n                    '(1) ',\n                    'Verify the sign of your force labels.',\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                    '(2) ',\n                    'This issue might simply be a symptom of too little training data, in which case your labels are correct.',\n                )\n            )\n\n        if corrcoef < 0.95:\n            self.log.warning(\n                'Potentially inconsistent energy labels detected!\\n'\n                + 'The predicted energies for the training data are only weakly correlated with the reference labels (correlation coefficient {:.2f}). Note that correlation is independent of scale, which indicates that the issue is most likely not just a unit conversion error.\\n\\n'.format(\n                    corrcoef\n                )\n                + ui.color_str('Troubleshooting tips:\\n', bold=True)\n                + ui.wrap_indent_str(\n                    '(1) ',\n                    'Verify the correct correspondence between geometries and labels in the provided dataset.',\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                    '(2) ',\n                    'This issue might simply be a symptom of too little training data, in which case your labels are correct.',\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                    '(3) ', 'Verify the consistency between energy and force labels.'\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                    '    - ', 'Correspondence between force and energy labels correct?'\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                 
   '    - ',\n                    'Accuracy of forces (convergence of your ab-initio calculations)?',\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                    '    - ',\n                    'Was the same level of theory used to compute forces and energies?',\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                    '(4) ',\n                    'Is the training data spread too broadly (i.e. weakly sampled transitions between example clusters)?',\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                    '(5) ', 'Are there duplicate geometries in the training data?'\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                    '(6) ', 'Are there any corrupted data points (e.g. parsing errors)?'\n                )\n            )\n\n        if np.abs(e_fact - 1) > 1e-1:\n            self.log.warning(\n                'Potentially inconsistent scales in energy vs. force labels detected!\\n'\n                + 'The integrated force predictions differ from the reference energy labels by a factor of ~{:.2f} (for the training data), meaning that this model will likely fail to predict energies accurately in real-world use.\\n\\n'.format(\n                    e_fact\n                )\n                + ui.color_str('Troubleshooting tips:\\n', bold=True)\n                + ui.wrap_indent_str(\n                    '(1) ', 'Verify consistency of units in energy and force labels.'\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                    '(2) ',\n                    'This issue might simply be a symptom of too little training data, in which case your labels are correct.',\n                )\n                + '\\n'\n                + ui.wrap_indent_str(\n                    '(3) ',\n                    'Is the training data spread too broadly (i.e. 
weakly sampled transitions between example clusters)?',\n                )\n            )\n\n        # Least squares estimate for integration constant.\n        return np.sum(E_ref - E_pred) / E_ref.shape[0]\n\n    def _assemble_kernel_mat(\n        self,\n        R_desc,\n        R_d_desc,\n        tril_perms_lin,\n        sig,\n        desc,  # TODO: document me\n        use_E_cstr=False,\n        col_idxs=np.s_[:],  # TODO: document me\n        alloc_extra_rows=0,  # TODO: document me\n        callback=None,\n    ):\n        r\"\"\"\n        Compute force field kernel matrix.\n\n        The Hessian of the Matern kernel is used with n = 2 (twice\n        differentiable). Each row and column consists of matrix-valued blocks,\n        which encode the interaction of one training point with all others. The\n        result is stored in shared memory (a global variable).\n\n        Parameters\n        ----------\n            R_desc : :obj:`numpy.ndarray`\n                Array containing the descriptor for each training point.\n            R_d_desc : :obj:`numpy.ndarray`\n                Array containing the gradient of the descriptor for\n                each training point.\n            tril_perms_lin : :obj:`numpy.ndarray`\n                1D array containing all recovered permutations\n                expanded as one large permutation to be applied to a\n                tiled copy of the object to be permuted.\n            sig : int\n                Hyper-parameter :math:`\\sigma`(kernel length scale).\n            use_E_cstr : bool, optional\n                True: include energy constraints in the kernel,\n                False: default (s)GDML kernel.\n            callback : callable, optional\n                Kernel assembly progress function that takes three\n                arguments:\n                    current : int\n                        Current progress (number of completed entries).\n                    total : int\n                        Task size 
(total number of entries to create).\n                    sec_disp_str : :obj:`str`, optional\n                        Once complete, this string contains the\n                        time it took to assemble the kernel (seconds).\n            cols_m_limit : int, optional (DEPRECATED)\n                Only generate the columns up to index 'cols_m_limit'. This creates\n                an M*3N x cols_m_limit*3N kernel matrix, instead of M*3N x M*3N.\n            cols_3n_keep_idxs : :obj:`numpy.ndarray`, optional\n                Only generate columns with the given indices in the 3N x 3N\n                kernel function. The resulting kernel matrix will have dimension\n                M*3N x M*len(cols_3n_keep_idxs).\n\n        Returns\n        -------\n            :obj:`numpy.ndarray`\n                Force field kernel matrix.\n        \"\"\"\n\n        global glob\n\n        # Note: Index arrays must be sorted in ascending order.\n        # if not isinstance(col_idxs, slice):\n        #    assert np.array_equal(col_idxs, np.sort(col_idxs))\n\n        n_train, dim_d = R_d_desc.shape[:2]\n        dim_i = 3 * int((1 + np.sqrt(8 * dim_d + 1)) / 2)\n\n        # Determine size of kernel matrix.\n        K_n_rows = n_train * dim_i\n\n        # Account for additional rows (and columns) due to energy constraints in the kernel matrix.\n        if use_E_cstr:\n            K_n_rows += n_train\n\n        if isinstance(col_idxs, slice):  # indexed by slice\n            K_n_cols = len(range(*col_idxs.indices(K_n_rows)))\n        else:  # indexed by list\n\n            # TODO: throw exception with description\n            assert len(col_idxs) == len(set(col_idxs))  # assume no duplicate indices\n\n            # TODO: throw exception with description\n            # Note: Index arrays must be sorted in ascending order.\n            assert np.array_equal(col_idxs, np.sort(col_idxs))\n\n            K_n_cols = len(col_idxs)\n\n        # Make sure no 
indices are outside of the valid range.\n        if K_n_cols > K_n_rows:\n            raise ValueError('Columns indexed beyond range.')\n\n        exploit_sym = False\n        cols_m_limit = None\n\n        # Check if range is a subset of training points (as opposed to a subset of partials of multiple points).\n        is_M_subset = (\n            isinstance(col_idxs, slice)\n            and (col_idxs.start is None or col_idxs.start % dim_i == 0)\n            and (col_idxs.stop is None or col_idxs.stop % dim_i == 0)\n            and col_idxs.step is None\n        )\n        if is_M_subset:\n            M_slice_start = (\n                None if col_idxs.start is None else int(col_idxs.start / dim_i)\n            )\n            M_slice_stop = None if col_idxs.stop is None else int(col_idxs.stop / dim_i)\n            M_slice = slice(M_slice_start, M_slice_stop)\n\n            J = range(*M_slice.indices(n_train + (n_train if use_E_cstr else 0)))\n\n            if M_slice_start is None:\n                exploit_sym = True\n                cols_m_limit = M_slice_stop\n\n        else:\n\n            if isinstance(col_idxs, slice):\n                # random = list(range(*col_idxs.indices(n_train * dim_i)))\n                # Expand the slice into an index array (the boolean masking below requires an ndarray, not a list).\n                col_idxs = np.arange(*col_idxs.indices(K_n_rows))\n\n            # Separate column indices of force-force and force-energy constraints.\n            cond = col_idxs >= (n_train * dim_i)\n            ff_col_idxs, fe_col_idxs = col_idxs[~cond], col_idxs[cond]\n\n            # M - number of training points\n            # N - number of atoms\n\n            n_idxs = np.concatenate(\n                [np.mod(ff_col_idxs, dim_i), np.zeros(fe_col_idxs.shape, dtype=int)]\n            )  # Column indices that go beyond force-force correlations need a different treatment.\n\n            m_idxs = np.concatenate([np.array(ff_col_idxs) // dim_i, fe_col_idxs])\n            m_idxs_uniq = np.unique(m_idxs)  # which points to include?\n\n            m_n_idxs = [\n                
list(n_idxs[np.where(m_idxs == m_idx)]) for m_idx in m_idxs_uniq\n            ]\n            m_n_idxs_lens = [len(m_n_idx) for m_n_idx in m_n_idxs]\n\n            m_n_idxs_lens.insert(0, 0)\n            blk_start_idxs = list(\n                np.cumsum(m_n_idxs_lens[:-1])\n            )  # index within K at which each block starts\n\n            # tuples: (block index in final K, block index global, indices of partials within block)\n            J = list(zip(blk_start_idxs, m_idxs_uniq, m_n_idxs))\n\n        if callback is not None:\n            callback(0, 100)  # 0%\n\n        if self._use_torch:\n            if not _has_torch:\n                raise ImportError(\n                    'Optional PyTorch dependency not found! Please run \\'pip install sgdml[torch]\\' to install it or disable the PyTorch option.'\n                )\n\n            K = np.empty((K_n_rows + alloc_extra_rows, K_n_cols))\n\n            # Materialize J (it may be a range object) so that it can be indexed as a list.\n            if not isinstance(J, list):\n                J = list(J)\n\n            global torch_assemble_done\n            torch_assemble_todo, torch_assemble_done = K_n_cols, 0\n\n            def progress_callback(done):\n\n                global torch_assemble_done\n                torch_assemble_done += done\n\n                if callback is not None:\n                    callback(\n                        torch_assemble_done,\n                        torch_assemble_todo,\n                        newline_when_done=False,\n                    )\n\n            start = timeit.default_timer()\n\n            if _torch_cuda_is_available:\n                torch_device = 'cuda'\n            elif _torch_mps_is_available:\n                torch_device = 'mps'\n            else:\n                torch_device = 'cpu'\n\n            R_desc_torch = torch.from_numpy(R_desc).to(torch_device)  # N, d\n            R_d_desc_torch = torch.from_numpy(R_d_desc).to(torch_device)\n\n            from .torchtools import GDMLTorchAssemble\n\n            torch_assemble = GDMLTorchAssemble(\n            
    J,\n                tril_perms_lin,\n                sig,\n                use_E_cstr,\n                R_desc_torch,\n                R_d_desc_torch,\n                out=K[:K_n_rows, :],\n                callback=progress_callback,\n            )\n\n            # Enable data parallelism\n            n_gpu = torch.cuda.device_count()\n            if n_gpu > 1:\n                torch_assemble = torch.nn.DataParallel(torch_assemble)\n            torch_assemble.to(torch_device)\n\n            torch_assemble.forward(torch.arange(len(J)))\n            del torch_assemble\n\n            del R_desc_torch\n            del R_d_desc_torch\n\n            stop = timeit.default_timer()\n\n            if callback is not None:\n                dur_s = stop - start\n                sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else ''\n                callback(DONE, sec_disp_str=sec_disp_str)\n\n            return K\n\n        K = mp.RawArray('d', (K_n_rows + alloc_extra_rows) * K_n_cols)\n        glob['K'], glob['K_shape'] = K, (K_n_rows + alloc_extra_rows, K_n_cols)\n        glob['R_desc'], glob['R_desc_shape'] = _share_array(R_desc, 'd')\n        glob['R_d_desc'], glob['R_d_desc_shape'] = _share_array(R_d_desc, 'd')\n\n        glob['desc_func'] = desc\n\n        start = timeit.default_timer()\n\n        pool = None\n        map_func = map\n        if self._max_processes != 1 and mp.cpu_count() > 1:\n            pool = Pool(\n                (self._max_processes or mp.cpu_count()) - 1\n            )  # exclude main process\n            map_func = pool.imap_unordered\n\n        todo, done = K_n_cols, 0\n        for done_wkr in map_func(\n            partial(\n                _assemble_kernel_mat_wkr,\n                tril_perms_lin=tril_perms_lin,\n                sig=sig,\n                use_E_cstr=use_E_cstr,\n                exploit_sym=exploit_sym,\n                cols_m_limit=cols_m_limit,\n            ),\n            J,\n        ):\n            done += 
done_wkr\n\n            if callback is not None:\n                callback(done, todo, newline_when_done=False)\n\n        if pool is not None:\n            pool.close()\n            pool.join()  # Wait for the worker processes to terminate (to measure total runtime correctly).\n            pool = None\n\n        stop = timeit.default_timer()\n\n        if callback is not None:\n            dur_s = stop - start\n            sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else ''\n            callback(DONE, sec_disp_str=sec_disp_str)\n\n        # Release some memory.\n        glob.pop('K', None)\n        glob.pop('R_desc', None)\n        glob.pop('R_d_desc', None)\n\n        return np.frombuffer(K).reshape((K_n_rows + alloc_extra_rows), K_n_cols)\n\n    def draw_strat_sample(self, T, n, excl_idxs=None):\n        \"\"\"\n        Draw a sample from a dataset that preserves its original distribution.\n\n        The distribution is estimated from a histogram where the bin size is\n        determined using the Freedman-Diaconis rule. This rule is designed to\n        minimize the difference between the area under the empirical\n        probability distribution and the area under the theoretical\n        probability distribution. A reduced histogram is then constructed by\n        sampling uniformly in each bin. 
It is intended to populate all bins\n        with at least one sample in the reduced histogram, even for small\n        training sizes.\n\n        Parameters\n        ----------\n            T : :obj:`numpy.ndarray`\n                Dataset to sample from.\n            n : int\n                Number of examples.\n            excl_idxs : :obj:`numpy.ndarray`, optional\n                Array of indices to exclude from sample.\n\n        Returns\n        -------\n            :obj:`numpy.ndarray`\n                Array of indices that form the sample.\n        \"\"\"\n\n        if excl_idxs is None or len(excl_idxs) == 0:\n            excl_idxs = None\n\n        if n == 0:\n            return np.array([], dtype=np.uint)\n\n        if T.size == n:  # TODO: this only works if excl_idxs=None\n            assert excl_idxs is None\n            return np.arange(n)\n\n        if n == 1:\n            idxs_all_non_excl = np.setdiff1d(\n                np.arange(T.size), excl_idxs, assume_unique=True\n            )\n            return np.array([np.random.choice(idxs_all_non_excl)])\n\n        # Freedman-Diaconis rule\n        h = 2 * np.subtract(*np.percentile(T, [75, 25])) / np.cbrt(n)\n        n_bins = int(np.ceil((np.max(T) - np.min(T)) / h)) if h > 0 else 1\n        n_bins = min(\n            n_bins, int(n / 2)\n        )  # Limit number of bins to half of requested subset size.\n\n        bins = np.linspace(np.min(T), np.max(T), n_bins, endpoint=False)\n        idxs = np.digitize(T, bins)\n\n        # Exclude restricted indices.\n        if excl_idxs is not None and excl_idxs.size > 0:\n            idxs[excl_idxs] = n_bins + 1  # Impossible bin.\n\n        uniq_all, cnts_all = np.unique(idxs, return_counts=True)\n\n        # Remove restricted bin.\n        if excl_idxs is not None and excl_idxs.size > 0:\n            excl_bin_idx = np.where(uniq_all == n_bins + 1)\n            cnts_all = np.delete(cnts_all, excl_bin_idx)\n            uniq_all = np.delete(uniq_all, 
excl_bin_idx)\n\n        # Compute reduced bin counts.\n        reduced_cnts = np.ceil(cnts_all / np.sum(cnts_all, dtype=float) * n).astype(int)\n        reduced_cnts = np.minimum(\n            reduced_cnts, cnts_all\n        )  # limit reduced_cnts to what is available in cnts_all\n\n        # Reduce/increase bin counts to desired total number of points.\n        reduced_cnts_delta = n - np.sum(reduced_cnts)\n\n        while np.abs(reduced_cnts_delta) > 0:\n\n            # How many members can we remove from an arbitrary bucket, without any bucket with more than one member going to zero?\n            max_bin_reduction = np.min(reduced_cnts[np.where(reduced_cnts > 1)]) - 1\n\n            # Generate additional bin members to fill up/drain bucket counts of subset. This array contains (repeated) bucket IDs.\n            outstanding = np.random.choice(\n                uniq_all,\n                min(max_bin_reduction, np.abs(reduced_cnts_delta)),\n                p=(reduced_cnts - 1) / np.sum(reduced_cnts - 1, dtype=float),\n                replace=True,\n            )\n            uniq_outstanding, cnts_outstanding = np.unique(\n                outstanding, return_counts=True\n            )  # Aggregate bucket IDs.\n\n            outstanding_bucket_idx = np.where(\n                np.in1d(uniq_all, uniq_outstanding, assume_unique=True)\n            )[\n                0\n            ]  # Bucket IDs to Idxs.\n            reduced_cnts[outstanding_bucket_idx] += (\n                np.sign(reduced_cnts_delta) * cnts_outstanding\n            )\n            reduced_cnts_delta = n - np.sum(reduced_cnts)\n\n        # Draw examples for each bin.\n        idxs_train = np.empty((0,), dtype=int)\n        for uniq_idx, bin_cnt in zip(uniq_all, reduced_cnts):\n            idx_in_bin_all = np.where(idxs.ravel() == uniq_idx)[0]\n            idxs_train = np.append(\n                idxs_train, np.random.choice(idx_in_bin_all, bin_cnt, replace=False)\n            )\n\n        return 
idxs_train\n"
  },
  {
    "path": "sgdml/utils/__init__.py",
    "content": ""
  },
  {
    "path": "sgdml/utils/desc.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela, Luis Galvez\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nimport numpy as np\nimport scipy as sp\nfrom scipy import spatial\n\nimport multiprocessing as mp\n\nPool = mp.get_context('fork').Pool\n\nfrom functools import partial\nimport timeit\n\ntry:\n    import torch\nexcept ImportError:\n    _has_torch = False\nelse:\n    _has_torch = True\n\n\ndef _pbc_diff(diffs, lat_and_inv, use_torch=False):\n    \"\"\"\n    Clamp differences of vectors to super cell.\n\n    Parameters\n    ----------\n        diffs : :obj:`numpy.ndarray`\n            N x 3 matrix of N pairwise differences between vectors `u - v`\n        lat_and_inv : tuple of :obj:`numpy.ndarray`\n            Tuple of 3 x 3 matrix containing lattice vectors as columns and its inverse.\n        use_torch : boolean, optional\n            Enable, if the inputs are PyTorch objects.\n\n    Returns\n    
-------\n        :obj:`numpy.ndarray`\n            N x 3 matrix clamped differences\n    \"\"\"\n\n    lat, lat_inv = lat_and_inv\n\n    if use_torch and not _has_torch:\n        raise ImportError(\n            'Optional PyTorch dependency not found! Please run \\'pip install sgdml[torch]\\' to install it or disable the PyTorch option.'\n        )\n\n    if use_torch:\n        c = lat_inv.mm(diffs.t())\n        diffs -= lat.mm(c.round()).t()\n    else:\n        c = lat_inv.dot(diffs.T)\n        diffs -= lat.dot(np.around(c)).T\n\n    return diffs\n\n\ndef _pdist(r, lat_and_inv=None):\n    \"\"\"\n    Compute pairwise Euclidean distance matrix between all atoms.\n\n    Parameters\n    ----------\n        r : :obj:`numpy.ndarray`\n            Array of size 3N containing the Cartesian coordinates of\n            each atom.\n        lat_and_inv : tuple of :obj:`numpy.ndarray`, optional\n            Tuple of 3x3 matrix containing lattice vectors as columns and its inverse.\n\n    Returns\n    -------\n        :obj:`numpy.ndarray`\n            Array of size N(N-1)/2 containing the upper triangle of the pairwise\n            distance matrix between atoms.\n    \"\"\"\n\n    r = r.reshape(-1, 3)\n    n_atoms = r.shape[0]\n\n    if lat_and_inv is None:\n        pdist = sp.spatial.distance.pdist(r, 'euclidean')\n    else:\n        pdist = sp.spatial.distance.pdist(\n            r, lambda u, v: np.linalg.norm(_pbc_diff(u - v, lat_and_inv))\n        )\n\n    tril_idxs = np.tril_indices(n_atoms, k=-1)\n    return sp.spatial.distance.squareform(pdist, checks=False)[tril_idxs]\n\n\ndef _squareform(vec_or_mat):\n\n    # vector to matrix representation\n    if vec_or_mat.ndim == 1:\n\n        n_tril = vec_or_mat.size\n        n = int((1 + np.sqrt(8 * n_tril + 1)) / 2)\n\n        i, j = np.tril_indices(n, k=-1)\n\n        mat = np.zeros((n, n))\n        mat[i, j] = vec_or_mat\n        mat[j, i] = vec_or_mat\n\n        return mat\n\n    else:  # matrix to vector\n\n        assert 
vec_or_mat.shape[0] == vec_or_mat.shape[1]  # matrix is square\n\n        n = vec_or_mat.shape[0]\n        i, j = np.tril_indices(n, k=-1)\n\n        return vec_or_mat[i, j]\n\n\ndef _r_to_desc(r, pdist):\n    \"\"\"\n    Generate descriptor for a set of atom positions in Cartesian\n    coordinates.\n\n    Parameters\n    ----------\n        r : :obj:`numpy.ndarray`\n            Array of size 3N containing the Cartesian coordinates of\n            each atom.\n        pdist : :obj:`numpy.ndarray`\n            Array of size N x N containing the Euclidean distance\n            (2-norm) for each pair of atoms.\n\n    Returns\n    -------\n        :obj:`numpy.ndarray`\n            Descriptor representation as 1D array of size N(N-1)/2\n    \"\"\"\n\n    # Add singleton dimension if input is (,3N).\n    if r.ndim == 1:\n        r = r[None, :]\n\n    return 1.0 / pdist\n\n\ndef _r_to_d_desc(r, pdist, lat_and_inv=None):\n    \"\"\"\n    Generate descriptor Jacobian for a set of atom positions in\n    Cartesian coordinates.\n\n    This method can apply the minimum-image convention as periodic\n    boundary condition for distances between atoms, given the lattice vectors.\n\n    Parameters\n    ----------\n        r : :obj:`numpy.ndarray`\n            Array of size 3N containing the Cartesian coordinates of\n            each atom.\n        pdist : :obj:`numpy.ndarray`\n            Array of size N x N containing the Euclidean distance\n            (2-norm) for each pair of atoms.\n        lat_and_inv : tuple of :obj:`numpy.ndarray`, optional\n            Tuple of 3 x 3 matrix containing lattice vectors as columns and its inverse.\n\n    Returns\n    -------\n        :obj:`numpy.ndarray`\n            Array of size N(N-1)/2 x 3N containing all partial\n            derivatives of the descriptor.\n    \"\"\"\n\n    r = r.reshape(-1, 3)\n    pdiff = r[:, None] - r[None, :]  # pairwise differences ri - rj\n\n    n_atoms = r.shape[0]\n    i, j = np.tril_indices(n_atoms, k=-1)\n\n    
pdiff = pdiff[i, j, :]  # lower triangular\n\n    if lat_and_inv is not None:\n        pdiff = _pbc_diff(pdiff, lat_and_inv)\n\n    d_desc_elem = pdiff / (pdist**3)[:, None]\n\n    return d_desc_elem\n\n\ndef _from_r(r, lat_and_inv=None):\n    \"\"\"\n    Generate descriptor and its Jacobian for one molecular geometry\n    in Cartesian coordinates.\n\n    Parameters\n    ----------\n        r : :obj:`numpy.ndarray`\n            Array of size 3N containing the Cartesian coordinates of\n            each atom.\n        lat_and_inv : tuple of :obj:`numpy.ndarray`, optional\n            Tuple of 3 x 3 matrix containing lattice vectors as columns and its inverse.\n\n    Returns\n    -------\n        :obj:`numpy.ndarray`\n            Descriptor representation as 1D array of size N(N-1)/2\n        :obj:`numpy.ndarray`\n            Array of size N(N-1)/2 x 3N containing all partial\n            derivatives of the descriptor.\n    \"\"\"\n\n    # Add singleton dimension if input is (,3N).\n    if r.ndim == 1:\n        r = r[None, :]\n\n    pd = _pdist(r, lat_and_inv)\n\n    r_desc = _r_to_desc(r, pd)\n    r_d_desc = _r_to_d_desc(r, pd, lat_and_inv)\n\n    return r_desc, r_d_desc\n\n\nclass Desc(object):\n    # def __init__(self, n_atoms, interact_cut_off=None, max_processes=None):\n    def __init__(self, n_atoms, max_processes=None):\n        \"\"\"\n        Generate descriptors and their Jacobians for molecular geometries,\n        including support for periodic boundary conditions.\n\n        Parameters\n        ----------\n                n_atoms : int\n                        Number of atoms in the represented system.\n                max_processes : int, optional\n                        Limit the max. number of processes. 
Otherwise\n                        all CPU cores are used.\n        \"\"\"\n\n        self.n_atoms = n_atoms\n        self.dim_i = 3 * n_atoms\n\n        # Size of the resulting descriptor vector.\n        self.dim = (n_atoms * (n_atoms - 1)) // 2\n\n        self.tril_indices = np.tril_indices(n_atoms, k=-1)\n\n        # Precompute indices for nonzero entries in descriptor derivatives.\n        self.d_desc_mask = np.zeros((n_atoms, n_atoms - 1), dtype=int)\n        for a in range(n_atoms):  # for each partial derivative\n            rows, cols = self.tril_indices\n            self.d_desc_mask[a, :] = np.concatenate(\n                [np.where(rows == a)[0], np.where(cols == a)[0]]\n            )\n\n        self.dim_range = np.arange(self.dim)  # [0, 1, ..., dim-1]\n\n        # Precompute indices for nonzero entries in descriptor derivatives.\n\n        self.M = np.arange(1, n_atoms)  # indexes matrix row-wise, skipping diagonal\n        for a in range(1, n_atoms):\n            self.M = np.concatenate((self.M, np.delete(np.arange(n_atoms), a)))\n\n        self.A = np.repeat(\n            np.arange(n_atoms), n_atoms - 1\n        )  # [0, 0, ..., 1, 1, ..., 2, 2, ...]\n\n        self.max_processes = max_processes\n\n    def from_R(self, R, lat_and_inv=None, max_processes=None, callback=None):\n        \"\"\"\n        Generate descriptor and its Jacobian for multiple molecular geometries\n        in Cartesian coordinates.\n\n        Parameters\n        ----------\n            R : :obj:`numpy.ndarray`\n                Array of size M x 3N containing the Cartesian coordinates of\n                each atom.\n            lat_and_inv : tuple of :obj:`numpy.ndarray`, optional\n                Tuple of 3 x 3 matrix containing lattice vectors as columns and its inverse.\n            max_processes : int, optional\n                Limit the max. number of processes. Otherwise\n                all CPU cores are used. 
This parameter overrides the global setting as\n                set during initialization.\n            callback : callable, optional\n                Descriptor and descriptor Jacobian generation status.\n                    current : int\n                        Current progress (number of completed descriptors).\n                    total : int\n                        Task size (total number of descriptors to create).\n                    sec_disp_str : :obj:`str`, optional\n                        Once complete, this string contains the\n                        time it took to complete this task (seconds).\n\n        Returns\n        -------\n            :obj:`numpy.ndarray`\n                Array of size M x N(N-1)/2 containing the descriptor representation\n                for each geometry.\n            :obj:`numpy.ndarray`\n                Array of size M x N(N-1)/2 x 3 containing the nonzero partial\n                derivatives of the descriptor for each geometry (compressed\n                representation).\n        \"\"\"\n\n        # Add singleton dimension if input is (,3N).\n        if R.ndim == 1:\n            R = R[None, :]\n\n        M = R.shape[0]\n        if M == 1:\n            return _from_r(R, lat_and_inv)\n\n        R_desc = np.empty([M, self.dim])\n        R_d_desc = np.empty([M, self.dim, 3])\n\n        # Generate descriptors and their Jacobians\n        start = timeit.default_timer()\n\n        pool = None\n        map_func = map\n        max_processes = max_processes or self.max_processes\n        if max_processes != 1 and mp.cpu_count() > 1:\n            pool = Pool((max_processes or mp.cpu_count()) - 1)  # exclude main process\n            map_func = pool.imap\n\n        for i, r_desc_r_d_desc in enumerate(\n            map_func(partial(_from_r, lat_and_inv=lat_and_inv), R)\n        ):\n            R_desc[i, :], R_d_desc[i, :, :] = r_desc_r_d_desc\n\n            if callback is not None and i < M - 1:\n                callback(i, M - 1)\n\n        if pool is not None:\n            
pool.close()\n            pool.join()  # Wait for the worker processes to terminate (to measure total runtime correctly).\n            pool = None\n\n        stop = timeit.default_timer()\n\n        if callback is not None:\n            dur_s = stop - start\n            sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else ''\n            callback(M, M, sec_disp_str=sec_disp_str)\n\n        return R_desc, R_d_desc\n\n    # Multiplies descriptor Jacobian(s) with 3N-vector(s) from the right side\n    def d_desc_dot_vec(self, R_d_desc, vecs, overwrite_vecs=False):\n\n        if R_d_desc.ndim == 2:\n            R_d_desc = R_d_desc[None, ...]\n\n        if vecs.ndim == 1:\n            vecs = vecs[None, ...]\n\n        i, j = self.tril_indices\n\n        vecs = vecs.reshape(vecs.shape[0], -1, 3)\n\n        einsum = np.einsum\n        if _has_torch and torch.is_tensor(R_d_desc):\n            assert torch.is_tensor(vecs)\n            einsum = torch.einsum\n\n        return einsum('...ij,...ij->...i', R_d_desc, vecs[:, j, :] - vecs[:, i, :])\n\n    # Multiplies descriptor Jacobian(s) with N(N-1)/2-vector(s) from the left side\n    def vec_dot_d_desc(self, R_d_desc, vecs, out=None):\n\n        if R_d_desc.ndim == 2:\n            R_d_desc = R_d_desc[None, ...]\n\n        if vecs.ndim == 1:\n            vecs = vecs[None, ...]\n\n        assert (\n            R_d_desc.shape[0] == 1\n            or vecs.shape[0] == 1\n            or R_d_desc.shape[0] == vecs.shape[0]\n        )  # either multiple descriptors or multiple vectors at once, not both (or the same number of both, in which case each descriptor is paired with its own vector)\n\n        n = np.max((R_d_desc.shape[0], vecs.shape[0]))\n        i, j = self.tril_indices\n\n        out = np.zeros((n, self.n_atoms, self.n_atoms, 3))\n        out[:, i, j, :] = R_d_desc * vecs[..., None]\n        out[:, j, i, :] = -out[:, i, j, :]\n        return out.sum(axis=1).reshape(n, -1)\n\n        # if out is None or out.shape != (n, self.n_atoms*3):\n   
     #    out = np.zeros((n, self.n_atoms*3))\n\n        # R_d_desc_full = np.zeros((self.n_atoms, self.n_atoms, 3))\n        # for a in range(n):\n\n        #   R_d_desc_full[i, j, :] = R_d_desc * vecs[a, :, None]\n        #    R_d_desc_full[j, i, :] = -R_d_desc_full[i, j, :]\n        #    out[a,:] = R_d_desc_full.sum(axis=0).ravel()\n\n        # return out\n\n    def d_desc_from_comp(self, R_d_desc, out=None):\n        \"\"\"\n        Convert a compressed representation of a descriptor Jacobian back\n        to its full representation.\n\n        The compressed representation omits all zeros and scales with N\n        instead of N(N-1)/2.\n\n        Parameters\n        ----------\n            R_d_desc : :obj:`numpy.ndarray` or :obj:`torch.tensor`\n                Array of size M x N x N x 3 containing the compressed\n                descriptor Jacobian.\n            out : :obj:`numpy.ndarray` or :obj:`torch.tensor`, optional\n                Output argument. This must have the exact kind that would\n                be returned if it was not used.\n\n        Note\n        ----\n                If used, the output argument must be initialized with zeros!\n\n        Returns\n        -------\n            :obj:`numpy.ndarray` or :obj:`torch.tensor`\n                Array of size M x N(N-1)/2 x 3N containing the full\n                representation.\n        \"\"\"\n\n        if R_d_desc.ndim == 2:\n            R_d_desc = R_d_desc[None, ...]\n\n        n = R_d_desc.shape[0]\n        i, j = self.tril_indices\n\n        if out is None:\n            if _has_torch and torch.is_tensor(R_d_desc):\n                device = R_d_desc.device\n                dtype = R_d_desc.dtype\n                out = torch.zeros((n, self.dim, self.n_atoms, 3), device=device).to(\n                    dtype\n                )\n            else:\n                out = np.zeros((n, self.dim, self.n_atoms, 3))\n        else:\n            out = out.reshape(n, self.dim, self.n_atoms, 3)\n\n        
out[:, self.dim_range, j, :] = R_d_desc\n        out[:, self.dim_range, i, :] = -R_d_desc\n\n        return out.reshape(-1, self.dim, self.dim_i)\n\n    def d_desc_to_comp(self, R_d_desc):\n        \"\"\"\n        Convert a descriptor Jacobian to a compressed representation.\n\n        The compressed representation omits all zeros and scales with N\n        instead of N(N-1)/2.\n\n        Parameters\n        ----------\n            R_d_desc : :obj:`numpy.ndarray`\n                Array of size M x N(N-1)/2 x 3N containing the descriptor\n                Jacobian.\n\n        Returns\n        -------\n            :obj:`numpy.ndarray`\n                Array of size M x N x N x 3 containing the compressed\n                representation.\n        \"\"\"\n\n        # Add singleton dimension for single inputs.\n        if R_d_desc.ndim == 2:\n            R_d_desc = R_d_desc[None, ...]\n\n        n = R_d_desc.shape[0]\n        n_atoms = int(R_d_desc.shape[2] / 3)\n\n        R_d_desc = R_d_desc.reshape(n, -1, n_atoms, 3)\n\n        ret = np.zeros((n, n_atoms, n_atoms, 3))\n        ret[:, self.M, self.A, :] = R_d_desc[:, self.d_desc_mask.ravel(), self.A, :]\n\n        # Take the lower triangle (indices from np.tril_indices).\n        i, j = self.tril_indices\n        return ret[:, i, j, :]\n\n    @staticmethod\n    def perm(perm):\n        \"\"\"\n        Convert atom permutation to descriptor permutation.\n\n        A permutation of N atoms is converted to a permutation that acts on\n        the corresponding descriptor representation. 
Applying the converted\n        permutation to a descriptor is equivalent to permuting the atoms\n        first and then generating the descriptor.\n\n        Parameters\n        ----------\n            perm : :obj:`numpy.ndarray`\n                Array of size N containing the atom permutation.\n\n        Returns\n        -------\n            :obj:`numpy.ndarray`\n                Array of size N(N-1)/2 containing the corresponding\n                descriptor permutation.\n        \"\"\"\n\n        n = len(perm)\n\n        rest = np.zeros((n, n))\n        rest[np.tril_indices(n, -1)] = list(range((n**2 - n) // 2))\n        rest = rest + rest.T\n        rest = rest[perm, :]\n        rest = rest[:, perm]\n\n        return rest[np.tril_indices(n, -1)].astype(int)\n"
  },
  {
    "path": "sgdml/utils/io.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2021 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nimport argparse\nimport hashlib\nimport os\nimport re\nimport sys\n\nimport numpy as np\n\nfrom . 
import ui\n\n_z_str_to_z_dict = {\n    'H': 1,\n    'He': 2,\n    'Li': 3,\n    'Be': 4,\n    'B': 5,\n    'C': 6,\n    'N': 7,\n    'O': 8,\n    'F': 9,\n    'Ne': 10,\n    'Na': 11,\n    'Mg': 12,\n    'Al': 13,\n    'Si': 14,\n    'P': 15,\n    'S': 16,\n    'Cl': 17,\n    'Ar': 18,\n    'K': 19,\n    'Ca': 20,\n    'Sc': 21,\n    'Ti': 22,\n    'V': 23,\n    'Cr': 24,\n    'Mn': 25,\n    'Fe': 26,\n    'Co': 27,\n    'Ni': 28,\n    'Cu': 29,\n    'Zn': 30,\n    'Ga': 31,\n    'Ge': 32,\n    'As': 33,\n    'Se': 34,\n    'Br': 35,\n    'Kr': 36,\n    'Rb': 37,\n    'Sr': 38,\n    'Y': 39,\n    'Zr': 40,\n    'Nb': 41,\n    'Mo': 42,\n    'Tc': 43,\n    'Ru': 44,\n    'Rh': 45,\n    'Pd': 46,\n    'Ag': 47,\n    'Cd': 48,\n    'In': 49,\n    'Sn': 50,\n    'Sb': 51,\n    'Te': 52,\n    'I': 53,\n    'Xe': 54,\n    'Cs': 55,\n    'Ba': 56,\n    'La': 57,\n    'Ce': 58,\n    'Pr': 59,\n    'Nd': 60,\n    'Pm': 61,\n    'Sm': 62,\n    'Eu': 63,\n    'Gd': 64,\n    'Tb': 65,\n    'Dy': 66,\n    'Ho': 67,\n    'Er': 68,\n    'Tm': 69,\n    'Yb': 70,\n    'Lu': 71,\n    'Hf': 72,\n    'Ta': 73,\n    'W': 74,\n    'Re': 75,\n    'Os': 76,\n    'Ir': 77,\n    'Pt': 78,\n    'Au': 79,\n    'Hg': 80,\n    'Tl': 81,\n    'Pb': 82,\n    'Bi': 83,\n    'Po': 84,\n    'At': 85,\n    'Rn': 86,\n    'Fr': 87,\n    'Ra': 88,\n    'Ac': 89,\n    'Th': 90,\n    'Pa': 91,\n    'U': 92,\n    'Np': 93,\n    'Pu': 94,\n    'Am': 95,\n    'Cm': 96,\n    'Bk': 97,\n    'Cf': 98,\n    'Es': 99,\n    'Fm': 100,\n    'Md': 101,\n    'No': 102,\n    'Lr': 103,\n    'Rf': 104,\n    'Db': 105,\n    'Sg': 106,\n    'Bh': 107,\n    'Hs': 108,\n    'Mt': 109,\n    'Ds': 110,\n    'Rg': 111,\n    'Cn': 112,\n    'Uuq': 114,\n    'Uuh': 116,\n}\n_z_to_z_str_dict = {v: k for k, v in _z_str_to_z_dict.items()}\n\n\ndef z_str_to_z(z_str):\n    return np.array([_z_str_to_z_dict[x] for x in z_str])\n\n\ndef z_to_z_str(z):\n    return [_z_to_z_str_dict[int(x)] for x in z]\n\n\ndef train_dir_name(dataset, 
n_train, use_sym, use_E, use_E_cstr):\n\n    theory_level_str = re.sub(r'[^\\w\\-_\\.]', '.', str(dataset['theory']))\n    theory_level_str = re.sub(r'\\.\\.', '.', theory_level_str)\n\n    sym_str = '-sym' if use_sym else ''\n    # cprsn_str = '-cprsn' if use_cprsn else ''\n    noE_str = '-noE' if not use_E else ''\n    Ecstr_str = '-Ecstr' if use_E_cstr else ''\n\n    return 'sgdml_cv_%s-%s-train%d%s%s%s' % (\n        dataset['name'].astype(str),\n        theory_level_str,\n        n_train,\n        sym_str,\n        # cprsn_str,\n        noE_str,\n        Ecstr_str,\n    )\n\n\ndef task_file_name(task):\n\n    n_train = task['idxs_train'].shape[0]\n    n_perms = task['perms'].shape[0]\n    sig = np.squeeze(task['sig'])\n\n    return 'task-train%d-sym%d-sig%04d.npz' % (n_train, n_perms, sig)\n\n\ndef model_file_name(task_or_model, is_extended=False):\n\n    n_train = task_or_model['idxs_train'].shape[0]\n    n_perms = task_or_model['perms'].shape[0]\n    sig = np.squeeze(task_or_model['sig'])\n\n    if is_extended:\n        dataset = np.squeeze(task_or_model['dataset_name'])\n        theory_level_str = re.sub(\n            r'[^\\w\\-_\\.]', '.', str(np.squeeze(task_or_model['dataset_theory']))\n        )\n        theory_level_str = re.sub(r'\\.\\.', '.', theory_level_str)\n        return '%s-%s-train%d-sym%d.npz' % (dataset, theory_level_str, n_train, n_perms)\n    return 'model-train%d-sym%d-sig%04d.npz' % (n_train, n_perms, sig)\n\n\ndef dataset_md5(dataset):\n\n    md5_hash = hashlib.md5()\n\n    keys = ['z', 'R']\n    if 'E' in dataset:\n        keys.append('E')\n    keys.append('F')\n\n    # only include new extra keys in fingerprint for 'modern' dataset files\n    # 'code_version' was included from 0.4.0.dev1\n    # opt_keys = ['lattice', 'e_unit', 'E_min', 'E_max', 'E_mean', 'E_var', 'f_unit', 'F_min', 'F_max', 'F_mean', 'F_var']\n    # for k in opt_keys:\n    #    if k in dataset:\n    #        keys.append(k)\n\n    for k in keys:\n        d = 
dataset[k]\n        if type(d) is np.ndarray:\n            d = d.ravel()\n        md5_hash.update(hashlib.md5(d).digest())\n\n    return md5_hash.hexdigest().encode('utf-8')\n\n\n# ## FILES\n\n# Read geometry file (xyz format).\n# R: (n_geo,3*n_atoms)\n# z: (n_atoms,)\ndef read_xyz(file_path):\n\n    with open(file_path, 'r') as f:\n        n_atoms = None\n\n        R, z = [], []\n        for i, line in enumerate(f):\n            line = line.strip()\n            if not n_atoms:\n                n_atoms = int(line)\n\n            cols = line.split()\n            file_i, line_i = divmod(i, n_atoms + 2)\n            if line_i >= 2:\n                R.append(list(map(float, cols[1:4])))\n                if file_i == 0:  # first molecule\n                    z.append(_z_str_to_z_dict[cols[0]])\n\n        R = np.array(R).reshape(-1, 3 * n_atoms)\n        z = np.array(z)\n\n    return R, z\n\n\n# Write geometry file (xyz format).\ndef write_geometry(filename, r, z, comment_str=''):\n\n    r = np.squeeze(r)\n    try:\n        with open(filename, 'w') as f:\n            f.write(str(len(r)) + '\\n' + comment_str)\n            for i, atom in enumerate(r):\n                f.write('\\n' + _z_to_z_str_dict[z[i]] + '\\t')\n                f.write('\\t'.join(str(x) for x in atom))\n    except IOError:\n        sys.exit(\"ERROR: Writing xyz file failed.\")\n\n\n# Generate geometry string (extended xyz format).\ndef generate_xyz_str(r, z, e=None, f=None, lattice=None):\n\n    comment_str = ''\n    if lattice is not None:\n        comment_str += 'Lattice=\\\"{}\\\" '.format(\n            ' '.join(['{:.12g}'.format(l) for l in lattice.T.ravel()])\n        )\n    if e is not None:\n        comment_str += 'Energy={:.12g} '.format(e)\n    comment_str += 'Properties=species:S:1:pos:R:3'\n    if f is not None:\n        comment_str += ':forces:R:3'\n\n    species_str = '\\n'.join([_z_to_z_str_dict[z_i] for z_i in z])\n\n    r_f_str = ui.gen_mat_str(r)[0]\n    if f is not None:\n     
   r_f_str = ui.merge_col_str(r_f_str, ui.gen_mat_str(f)[0])\n\n    xyz_str = str(len(r)) + '\\n' + comment_str + '\\n'\n    xyz_str += ui.merge_col_str(species_str, r_f_str)\n\n    return xyz_str\n\n\ndef lattice_vec_to_par(lat):\n\n    lat = lat.T\n    lengths = [np.linalg.norm(v) for v in lat]\n\n    angles = []\n    for i in range(3):\n        j = i - 1\n        k = i - 2\n\n        ll = lengths[j] * lengths[k]\n        if ll > 1e-16:\n            x = np.dot(lat[j], lat[k]) / ll\n            angle = 180.0 / np.pi * np.arccos(x)\n        else:\n            angle = 90.0\n        angles.append(angle)\n\n    return lengths, angles\n\n\n### FILE HANDLING\n\n\ndef is_file_type(arg, type):\n    \"\"\"\n    Validate file path and check if the file is of the specified type.\n\n    Parameters\n    ----------\n        arg : :obj:`str`\n            File path.\n        type : {'dataset', 'task', 'model'}\n            Possible file types.\n\n    Returns\n    -------\n        (:obj:`str`, :obj:`dict`)\n            Tuple of file path (as provided) and data stored in the\n            file. 
The returned instance of NpzFile class must be\n            closed to avoid leaking file descriptors.\n\n    Raises\n    ------\n        ArgumentTypeError\n            If the provided file path does not lead to a NpzFile.\n        ArgumentTypeError\n            If the file is not readable.\n        ArgumentTypeError\n            If the file is of wrong type.\n        ArgumentTypeError\n            If path/fingerprint is provided, but the path is not valid.\n        ArgumentTypeError\n            If fingerprint could not be resolved.\n        ArgumentTypeError\n            If multiple files with the same fingerprint exist.\n\n    \"\"\"\n\n    # Replace MD5 dataset fingerprint with file name, if necessary.\n    if type == 'dataset' and not arg.endswith('.npz') and not os.path.isdir(arg):\n        dir = '.'\n        if re.search(r'^[a-f0-9]{32}$', arg):  # arg looks similar to MD5 hash string\n            md5_str = arg\n        else:  # is it a path with a MD5 hash at the end?\n            md5_str = os.path.basename(os.path.normpath(arg))\n            dir = os.path.dirname(os.path.normpath(arg))\n\n            if dir == '':  # it is only a filename after all, hence not the right type\n                raise argparse.ArgumentTypeError('{0} is not a .npz file'.format(arg))\n\n            if re.search(r'^[a-f0-9]{32}$', md5_str) and not os.path.isdir(\n                dir\n            ):  # path has MD5 hash string at the end, but directory is not valid\n                raise argparse.ArgumentTypeError('{0} is not a directory'.format(dir))\n\n        file_names = filter_file_type(dir, type, md5_match=md5_str)\n\n        if not len(file_names):\n            raise argparse.ArgumentTypeError(\n                \"No {0} files with fingerprint '{1}' found in '{2}'\".format(\n                    type, md5_str, dir\n                )\n            )\n        elif len(file_names) > 1:\n            error_str = (\n                \"Multiple {0} files with fingerprint '{1}' found in 
'{2}'\".format(\n                    type, md5_str, dir\n                )\n            )\n            for file_name in file_names:\n                error_str += '\\n       {0}'.format(file_name)\n\n            raise argparse.ArgumentTypeError(error_str)\n        else:\n            arg = os.path.join(dir, file_names[0])\n\n    if not arg.endswith('.npz'):\n        raise argparse.ArgumentTypeError('{0} is not a .npz file'.format(arg))\n\n    try:\n        file = np.load(arg, allow_pickle=True)\n    except Exception:\n        raise argparse.ArgumentTypeError('{0} is not readable'.format(arg))\n\n    if 'type' not in file or file['type'].astype(str) != type[0]:\n        raise argparse.ArgumentTypeError('{0} is not a {1} file'.format(arg, type))\n\n    return arg, file\n\n\ndef filter_file_type(dir, type, md5_match=None):\n    \"\"\"\n    Filter all files in a directory that match a given type and (optionally)\n    a given fingerprint.\n\n    Parameters\n    ----------\n        dir : :obj:`str`\n            Directory path.\n        type : {'dataset', 'task', 'model'}\n            Possible file types.\n        md5_match : :obj:`str`, optional\n            Fingerprint string.\n\n    Returns\n    -------\n        :obj:`list` of :obj:`str`\n            List of file names that match the specified type and fingerprint\n            (if provided).\n\n    Raises\n    ------\n        ArgumentTypeError\n            If the directory contains unreadable .npz files.\n\n    \"\"\"\n\n    file_names = []\n    for file_name in sorted(os.listdir(dir)):\n        if file_name.endswith('.npz'):\n            file_path = os.path.join(dir, file_name)\n            try:\n                file = np.load(file_path, allow_pickle=True)\n            except Exception:\n                raise argparse.ArgumentTypeError(\n                    '{0} contains unreadable .npz files'.format(dir)\n                )\n\n            if 'type' in file and file['type'].astype(str) == type[0]:\n\n                if md5_match 
is None:\n                    file_names.append(file_name)\n                elif 'md5' in file and file['md5'] == md5_match:\n                    file_names.append(file_name)\n\n            file.close()\n\n    return file_names\n\n\ndef is_valid_file_type(arg_in):\n    \"\"\"\n    Check if file is either a valid dataset, task or model file.\n\n    Parameters\n    ----------\n        arg_in : :obj:`str`\n            File path.\n\n    Returns\n    -------\n        (:obj:`str`, :obj:`dict`)\n            Tuple of file path (as provided) and data stored in the\n            file. The returned instance of NpzFile class must be\n            closed to avoid leaking file descriptors.\n\n    Raises\n    ------\n        ArgumentTypeError\n            If the provided file path does not point to a supported\n            file type.\n\n    \"\"\"\n\n    arg, file = None, None\n    try:\n        arg, file = is_file_type(arg_in, 'dataset')\n    except argparse.ArgumentTypeError:\n        pass\n\n    if file is None:\n        try:\n            arg, file = is_file_type(arg_in, 'task')\n        except argparse.ArgumentTypeError:\n            pass\n\n    if file is None:\n        try:\n            arg, file = is_file_type(arg_in, 'model')\n        except argparse.ArgumentTypeError:\n            pass\n\n    if file is None:\n        raise argparse.ArgumentTypeError(\n            '{0} is neither a dataset, task, nor model file'.format(arg_in)\n        )\n\n    return arg, file\n\n\ndef is_dir_with_file_type(arg, type, or_file=False):\n    \"\"\"\n    Validate directory path and check if it contains files of the specified type.\n\n    Note\n    ----\n        If a file path is provided, this function acts as if it were a directory with\n        just one file.\n\n    Parameters\n    ----------\n        arg : :obj:`str`\n            File path.\n        type : {'dataset', 'task', 'model'}\n            Possible file types.\n        or_file : bool\n            If `arg` contains a file path, act like 
it's a directory\n            with just a single file inside.\n\n    Returns\n    -------\n        (:obj:`str`, :obj:`list` of :obj:`str`)\n            Tuple of directory path (as provided) and a list of\n            contained file names of the specified type.\n\n    Raises\n    ------\n        ArgumentTypeError\n            If the provided directory path does not lead to a directory.\n        ArgumentTypeError\n            If directory contains unreadable files.\n        ArgumentTypeError\n            If directory contains no files of the specified type.\n    \"\"\"\n\n    if or_file and os.path.isfile(arg):  # arg: file path\n        _, file = is_file_type(\n            arg, type\n        )  # raises exception if there is a problem with the file\n        file.close()\n        file_name = os.path.basename(arg)\n        file_dir = os.path.dirname(arg)\n        return file_dir, [file_name]\n    else:  # arg: dir\n\n        if not os.path.isdir(arg):\n            raise argparse.ArgumentTypeError('{0} is not a directory'.format(arg))\n\n        file_names = filter_file_type(arg, type)\n\n        # if not len(file_names):\n        #    raise argparse.ArgumentTypeError(\n        #        '{0} contains no {1} files'.format(arg, type)\n        #    )\n\n        return arg, file_names\n\n\ndef is_task_dir_resumeable(\n    train_dir, train_dataset, test_dataset, n_train, n_test, sigs, gdml\n):\n    r\"\"\"\n    Check if a directory contains `task` and/or `model` files that\n    match the configuration of a training process specified in the\n    remaining arguments.\n\n    Check if the training and test datasets in each task match\n    `train_dataset` and `test_dataset`, if the number of training and\n    test points matches and if the choices for the kernel\n    hyper-parameter :math:`\\sigma` are contained in the list. Check\n    also, if the existing tasks/models contain symmetries and if\n    that's consistent with the flag `gdml`. 
This function is useful\n    for determining if a training process can be resumed using the\n    existing files or not.\n\n    Parameters\n    ----------\n        train_dir : :obj:`str`\n            Path to training directory.\n        train_dataset : :obj:`dataset`\n            Dataset from which training points are sampled.\n        test_dataset : :obj:`test_dataset`\n            Dataset from which test points are sampled (may be the\n            same as `train_dataset`).\n        n_train : int\n            Number of training points to sample.\n        n_test : int\n            Number of test points to sample.\n        sigs : :obj:`list` of int\n            List of :math:`\\sigma` kernel hyper-parameter choices\n            (usually: the hyper-parameter search grid)\n        gdml : bool\n            If `True`, don't include any symmetries in model (GDML),\n            otherwise do (sGDML).\n\n    Returns\n    -------\n        bool\n            False, if any of the files in the directory do not match\n            the training configuration.\n    \"\"\"\n\n    for file_name in sorted(os.listdir(train_dir)):\n        if file_name.endswith('.npz'):\n            file_path = os.path.join(train_dir, file_name)\n            file = np.load(file_path, allow_pickle=True)\n\n            if 'type' not in file:\n                continue\n            elif file['type'] == 't' or file['type'] == 'm':\n\n                if (\n                    file['md5_train'] != train_dataset['md5']\n                    or file['md5_valid'] != test_dataset['md5']\n                    or len(file['idxs_train']) != n_train\n                    or len(file['idxs_valid']) != n_test\n                    or gdml\n                    and file['perms'].shape[0] > 1\n                    or file['sig'] not in sigs\n                ):\n                    return False\n\n    return True\n\n\n### ARGUMENT VALIDATION\n\n\ndef is_strict_pos_int(arg):\n    \"\"\"\n    Validate strictly positive integer 
input.\n\n    Parameters\n    ----------\n        arg : :obj:`str`\n            Integer as string.\n\n    Returns\n    -------\n        int\n            Parsed integer.\n\n    Raises\n    ------\n        ArgumentTypeError\n            If integer is not > 0.\n    \"\"\"\n    x = int(arg)\n    if x <= 0:\n        raise argparse.ArgumentTypeError('must be strictly positive')\n    return x\n\n\ndef parse_list_or_range(arg):\n    \"\"\"\n    Parse a string that represents either an integer or a range in\n    the notation ``<start>:[<step>:]<stop>`` (the stop value is included).\n\n    Parameters\n    ----------\n        arg : :obj:`str`\n            Integer or range string.\n\n    Returns\n    -------\n        int or :obj:`list` of int\n\n    Raises\n    ------\n        ArgumentTypeError\n            If input can neither be interpreted as an integer nor a valid range.\n    \"\"\"\n\n    if re.match(r'^\\d+:\\d+:\\d+$', arg) or re.match(r'^\\d+:\\d+$', arg):\n        rng_params = list(map(int, arg.split(':')))\n\n        step = 1\n        if len(rng_params) == 2:  # start, stop\n            start, stop = rng_params\n        else:  # start, step, stop\n            start, step, stop = rng_params\n\n        rng = list(range(start, stop + 1, step))  # include last stop-element in range\n        if len(rng) == 0:\n            raise argparse.ArgumentTypeError('{0} is an empty range'.format(arg))\n\n        return rng\n    elif re.match(r'^\\d+$', arg):\n        return int(arg)\n\n    raise argparse.ArgumentTypeError(\n        '{0} is neither an integer nor a valid range in the form <start>:[<step>:]<stop>'.format(\n            arg\n        )\n    )\n"
  },
  {
    "path": "sgdml/utils/perm.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2021 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\n\nimport multiprocessing as mp\n\nPool = mp.get_context('fork').Pool\n\nimport sys\nimport timeit\nfrom functools import partial\n\nimport numpy as np\nimport scipy.optimize\nimport scipy.spatial.distance\nfrom scipy.sparse import csr_matrix\nfrom scipy.sparse.csgraph import minimum_spanning_tree\n\nfrom .. import DONE, NOT_DONE\nfrom .desc import Desc\nfrom . 
import ui\n\nglob = {}\n\n\ndef share_array(arr_np, typecode):\n    arr = mp.RawArray(typecode, arr_np.ravel())\n    return arr, arr_np.shape\n\n\ndef _bipartite_match_wkr(i, n_train, same_z_cost):\n\n    global glob\n\n    adj_set = np.frombuffer(glob['adj_set']).reshape(glob['adj_set_shape'])\n    v_set = np.frombuffer(glob['v_set']).reshape(glob['v_set_shape'])\n    match_cost = np.frombuffer(glob['match_cost']).reshape(glob['match_cost_shape'])\n\n    adj_i = scipy.spatial.distance.squareform(adj_set[i, :])\n    v_i = v_set[i, :, :]\n\n    match_perms = {}\n    for j in range(i + 1, n_train):\n\n        adj_j = scipy.spatial.distance.squareform(adj_set[j, :])\n        v_j = v_set[j, :, :]\n\n        cost = -np.fabs(v_i).dot(np.fabs(v_j).T)\n        cost += same_z_cost * np.max(np.abs(cost))\n\n        _, perm = scipy.optimize.linear_sum_assignment(cost)\n\n        adj_i_perm = adj_i[:, perm]\n        adj_i_perm = adj_i_perm[perm, :]\n\n        score_before = np.linalg.norm(adj_i - adj_j)\n        score = np.linalg.norm(adj_i_perm - adj_j)\n\n        match_cost[i, j] = score\n        if score >= score_before:\n            match_cost[i, j] = score_before\n        elif not np.isclose(score_before, score):  # otherwise perm is identity\n            match_perms[i, j] = perm\n\n    return match_perms\n\n\ndef bipartite_match(R, z, lat_and_inv=None, max_processes=None, callback=None):\n\n    global glob\n\n    n_train, n_atoms, _ = R.shape\n\n    # penalty matrix for mixing atom species\n    same_z_cost = np.repeat(z[:, None], len(z), axis=1) - z\n    same_z_cost[same_z_cost != 0] = 1\n\n    # NEW\n\n    # penalty matrix for mixing differently bonded atoms\n    # NOTE: needs ASE, expects R to be in angstrom, does not support bond breaking\n\n    # from ase import Atoms\n    # from ase.geometry.analysis import Analysis\n\n    # atoms = Atoms(\n    #     z, positions=R[0]\n    # )  # only use first molecule in dataset to find connected components (fix me later, maybe) # 
*0.529177249\n\n    # bonds = Analysis(atoms).all_bonds[0]\n    # #n_bonds = np.array([len(bonds_i) for bonds_i in bonds])\n\n    # same_bonding_cost = np.zeros((n_atoms, n_atoms))\n    # for i in range(n_atoms):\n    #     bi = bonds[i]\n    #     z_bi = z[bi]\n    #     for j in range(i+1,n_atoms):\n    #         bj = bonds[j]\n    #         z_bj = z[bj]\n\n    #         if set(z_bi) == set(z_bj):\n    #             same_bonding_cost[i,j] = 1\n\n    # same_bonding_cost += same_bonding_cost.T\n\n    # same_bonding_cost[np.diag_indices(n_atoms)] = 1\n    # same_bonding_cost = 1-same_bonding_cost\n\n    # set(a) & set(b)\n\n    # same_bonding_cost = np.repeat(n_bonds[:, None], len(n_bonds), axis=1) - n_bonds\n    # same_bonding_cost[same_bonding_cost != 0] = 1\n\n    # NEW\n\n    match_cost = np.zeros((n_train, n_train))\n\n    desc = Desc(n_atoms, max_processes=max_processes)\n\n    adj_set = np.empty((n_train, desc.dim))\n    v_set = np.empty((n_train, n_atoms, n_atoms))\n    for i in range(n_train):\n        r = np.squeeze(R[i, :, :])\n\n        if lat_and_inv is None:\n            adj = scipy.spatial.distance.pdist(r, 'euclidean')\n\n            # from ase import Atoms\n            # from ase.geometry.analysis import Analysis\n\n            # atoms = Atoms(\n            #     z, positions=r\n            # )  # only use first molecule in dataset to find connected components (fix me later, maybe) # *0.529177249\n\n            # bonds = Analysis(atoms).all_bonds[0]\n\n            # adj = scipy.spatial.distance.squareform(adj)\n\n            # bonded = np.zeros((z.size, z.size))\n\n            # for j, bonded_to in enumerate(bonds):\n            # inv_bonded_to = np.arange(n_atoms)\n            # inv_bonded_to[bonded_to] = 0\n\n            # adj[j, inv_bonded_to] = 0\n\n            #    bonded[j, bonded_to] = 1\n\n            # bonded = bonded + bonded.T\n\n            # print(bonded)\n\n        else:\n\n            from .desc import _pdist, _squareform\n\n          
  adj_tri = _pdist(r, lat_and_inv)\n            adj = _squareform(adj_tri)  # our vectorized format to full matrix\n            adj = scipy.spatial.distance.squareform(\n                adj\n            )  # full matrix to numpy vectorized format\n\n        w, v = np.linalg.eig(scipy.spatial.distance.squareform(adj))\n        v = v[:, w.argsort()[::-1]]\n\n        adj_set[i, :] = adj\n        v_set[i, :, :] = v\n\n    glob['adj_set'], glob['adj_set_shape'] = share_array(adj_set, 'd')\n    glob['v_set'], glob['v_set_shape'] = share_array(v_set, 'd')\n    glob['match_cost'], glob['match_cost_shape'] = share_array(match_cost, 'd')\n\n    if callback is not None:\n        callback = partial(callback, disp_str='Bi-partite matching')\n\n    start = timeit.default_timer()\n\n    pool = None\n    map_func = map\n    if max_processes != 1 and mp.cpu_count() > 1:\n        pool = Pool((max_processes or mp.cpu_count()) - 1)  # exclude main process\n        map_func = pool.imap_unordered\n\n    match_perms_all = {}\n    for i, match_perms in enumerate(\n        map_func(\n            partial(_bipartite_match_wkr, n_train=n_train, same_z_cost=same_z_cost),\n            list(range(n_train)),\n        )\n    ):\n        match_perms_all.update(match_perms)\n\n        if callback is not None:\n            callback(i, n_train)\n\n    if pool is not None:\n        pool.close()\n        pool.join()  # Wait for the worker processes to terminate (to measure total runtime correctly).\n        pool = None\n\n    stop = timeit.default_timer()\n\n    dur_s = stop - start\n    sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else ''\n    if callback is not None:\n        callback(n_train, n_train, sec_disp_str=sec_disp_str)\n\n    match_cost = np.frombuffer(glob['match_cost']).reshape(glob['match_cost_shape'])\n    match_cost = match_cost + match_cost.T\n    match_cost[np.diag_indices_from(match_cost)] = np.inf\n    match_cost = csr_matrix(match_cost)\n\n    return 
match_perms_all, match_cost\n\n\ndef sync_perm_mat(match_perms_all, match_cost, n_atoms, callback=None):\n\n    if callback is not None:\n        callback = partial(\n            callback, disp_str='Multi-partite matching (permutation synchronization)'\n        )\n        callback(NOT_DONE)\n\n    tree = minimum_spanning_tree(match_cost, overwrite=True)\n\n    perms = np.arange(n_atoms, dtype=int)[None, :]\n    rows, cols = tree.nonzero()\n    for com in zip(rows, cols):\n        perm = match_perms_all.get(com)\n        if perm is not None:\n            perms = np.vstack((perms, perm))\n    perms = np.unique(perms, axis=0)\n\n    if callback is not None:\n        callback(DONE)\n\n    return perms\n\n\n# convert permutation to disjoint cycles\ndef to_cycles(perm):\n    pi = {i: perm[i] for i in range(len(perm))}\n    cycles = []\n\n    while pi:\n        elem0 = next(iter(pi))  # arbitrary starting element\n        this_elem = pi[elem0]\n        next_item = pi[this_elem]\n\n        cycle = []\n        while True:\n            cycle.append(this_elem)\n            del pi[this_elem]\n            this_elem = next_item\n            if next_item in pi:\n                next_item = pi[next_item]\n            else:\n                break\n\n        cycles.append(cycle)\n\n    return cycles\n\n\n# find permutation group with largest cardinality\n# note: this is used if transitive closure fails (to salvage at least some permutations)\ndef salvage_subgroup(perms):\n\n    n_perms, n_atoms = perms.shape\n\n    all_long_cycles = []\n    for i in range(n_perms):\n        long_cycles = [cy for cy in to_cycles(list(perms[i, :])) if len(cy) > 1]\n        all_long_cycles += long_cycles\n\n    # print(all_long_cycles)\n    # print('--------------')\n\n    def _cycle_intersects_with_larger_one(cy):\n\n        for ac in all_long_cycles:\n            if len(cy) < len(ac):\n                if not set(cy).isdisjoint(ac):\n                    return True\n\n        return False\n\n    lcms = 
[]\n    keep_idx_many = []\n    for i in range(n_perms):\n\n        # print(to_cycles(list(perms[i, :])))\n\n        # is this permutation valid?\n        # remove permutations that contain cycles that share elements with larger cycles in other perms\n        long_cycles = [cy for cy in to_cycles(list(perms[i, :])) if len(cy) > 1]\n\n        # print('long cycles:')\n        # print(long_cycles)\n\n        ignore_perm = any(list(map(_cycle_intersects_with_larger_one, long_cycles)))\n\n        if not ignore_perm:\n            keep_idx_many.append(i)\n\n        # print(ignore_perm)\n\n        # print()\n\n        # cy_lens = [len(cy) for cy in to_cycles(list(perms[i, :]))]\n        # lcm = np.lcm.reduce(cy_lens)\n        # lcms.append(lcm)\n    # keep_idx = np.argmax(lcms)\n    # perms = np.vstack((np.arange(n_atoms), perms[keep_idx,:]))\n    perms = perms[keep_idx_many, :]\n\n    # print(perms)\n\n    return perms\n\n\ndef complete_sym_group(\n    perms, n_perms_max=None, disp_str='Permutation group completion', callback=None\n):\n\n    if callback is not None:\n        callback = partial(callback, disp_str=disp_str)\n        callback(NOT_DONE)\n\n    perm_added = True\n    while perm_added:\n        perm_added = False\n        n_perms = perms.shape[0]\n        for i in range(n_perms):\n            for j in range(n_perms):\n\n                new_perm = perms[i, perms[j, :]]\n                if not (new_perm == perms).all(axis=1).any():\n                    perm_added = True\n                    perms = np.vstack((perms, new_perm))\n\n                    # Transitive closure is not converging! 
Give up and return None.\n                    if n_perms_max is not None and perms.shape[0] == n_perms_max:\n\n                        if callback is not None:\n                            callback(\n                                DONE,\n                                sec_disp_str='transitive closure has failed',\n                                done_with_warning=True,\n                            )\n                        return None\n\n    if callback is not None:\n        callback(\n            DONE,\n            sec_disp_str='found {:d} symmetries'.format(perms.shape[0]),\n        )\n\n    return perms\n\n\ndef find_perms(R, z, lat_and_inv=None, callback=None, max_processes=None):\n\n    m, n_atoms = R.shape[:2]\n\n    # Find matching for all pairs.\n    match_perms_all, match_cost = bipartite_match(\n        R, z, lat_and_inv, max_processes, callback=callback\n    )\n\n    # Remove inconsistencies.\n    match_perms = sync_perm_mat(match_perms_all, match_cost, n_atoms, callback=callback)\n\n    # Complete symmetric group.\n    # Give up, if transitive closure reaches 100 unique permutations.\n    sym_group_perms = complete_sym_group(\n        match_perms, n_perms_max=100, callback=callback\n    )\n\n    # Limit closure to largest cardinality permutation in the set to get at least some symmetries.\n    if sym_group_perms is None:\n        match_perms_subset = salvage_subgroup(match_perms)\n        sym_group_perms = complete_sym_group(\n            match_perms_subset,\n            n_perms_max=100,\n            disp_str='Closure disaster recovery',\n            callback=callback,\n        )\n\n    return sym_group_perms\n\n\ndef find_extra_perms(R, z, lat_and_inv=None, callback=None, max_processes=None):\n\n    m, n_atoms = R.shape[:2]\n\n    # NEW\n\n    # catcher\n    # p = np.arange(n_atoms)\n    # plane_3idxs = [19,17,47] # left to right\n    # perm = find_perms_via_reflection(R[0], z, np.arange(n_atoms), plane_3idxs, 
lat_and_inv=None, max_processes=None)\n    # perms = np.vstack((p[None,:], perm))\n    # plane_3idxs = [(4,5),(2,1),(34,33)]  # top to bottom\n    # perm = find_perms_via_reflection(R[0], z, np.arange(n_atoms), plane_3idxs, lat_and_inv=None, max_processes=None)\n    # perms = np.vstack((perm[None,:], perms))\n    # sym_group_perms = complete_sym_group(perms, n_perms_max=100, callback=callback)\n\n    # nanotube\n    R = R.copy()\n    frags = find_frags(R[0], z, lat_and_inv=lat_and_inv)\n    print(frags)\n\n    perms = np.arange(n_atoms)[None, :]\n\n    plane_3idxs = [280, 281, 273]  # half outer\n    add_perms = find_perms_via_reflection(\n        R[0], z, frags[1], plane_3idxs, lat_and_inv=None, max_processes=None\n    )\n    perms = np.vstack((perms, add_perms))\n\n    # rotate inner\n    # add_perms = find_perms_via_alignment(R[0], frags[0], [214, 215, 210, 211], [209, 208, 212, 213], z, lat_and_inv=lat_and_inv, max_processes=max_processes)\n    # perms = np.vstack((perms, add_perms))\n    # sym_group_perms = complete_sym_group(perms, callback=callback)\n\n    # rotate outer\n    # add_perms = find_perms_via_alignment(R[0], frags[1], [361, 360, 368, 369], [363, 362, 356, 357], z, lat_and_inv=lat_and_inv, max_processes=max_processes)\n    # perms = np.vstack((perms, add_perms))\n    # sym_group_perms = complete_sym_group(perms, callback=callback)\n\n    perms = np.unique(perms, axis=0)\n    sym_group_perms = complete_sym_group(perms, callback=callback)\n    print(sym_group_perms.shape)\n\n    return sym_group_perms\n\n    # buckycatcher\n    R = R.copy()  # *0.529177\n    frags = find_frags(R[0], z, lat_and_inv=lat_and_inv)\n\n    perms = np.arange(n_atoms)[None, :]\n\n    # syms of catcher\n    plane_3idxs = [54, 47, 17]  # left to right\n    add_perms = find_perms_via_reflection(\n        R[0], z, frags[0], plane_3idxs, lat_and_inv=None, max_processes=None\n    )\n    perms = np.vstack((perms, add_perms))\n\n    plane_3idxs = [(33, 34), (31, 30), (5, 4)]  # top 
to bottom\n    add_perms = find_perms_via_reflection(\n        R[0], z, frags[0], plane_3idxs, lat_and_inv=None, max_processes=None\n    )\n    perms = np.vstack((perms, add_perms))\n\n    # move cells\n    # add_perms = find_perms_via_alignment(R[0], frags[1], [128, 129, 127], [133, 132, 134], z, lat_and_inv=lat_and_inv, max_processes=max_processes)\n    # perms = np.vstack((perms, add_perms))\n    # sym_group_perms = complete_sym_group(perms, callback=callback)\n\n    # print(sym_group_perms.shape)\n\n    # rotate cells\n    add_perms = find_perms_via_alignment(\n        R[0],\n        frags[1],\n        [129, 128, 127],\n        [128, 127, 135],\n        z,\n        lat_and_inv=lat_and_inv,\n        max_processes=max_processes,\n    )\n    perms = np.vstack((perms, add_perms))\n    # print(add_perms.shape)\n    # sym_group_perms = complete_sym_group(perms, callback=callback)\n\n    # rotate cells (triangle)\n    # add_perms = find_perms_via_alignment(R[0], frags[1], [132, 129, 134], [129, 134, 132], z, lat_and_inv=lat_and_inv, max_processes=max_processes)\n    # perms = np.vstack((perms, add_perms))\n    sym_group_perms = complete_sym_group(perms, callback=callback)\n\n    # print(perms.shape)\n    print(sym_group_perms.shape)\n\n    # frag 1: bucky ball\n    # perms = find_perms_in_frag(R, z, frags[1], lat_and_inv=lat_and_inv, max_processes=max_processes)\n    # perms = np.vstack((p[None,:], perms))\n\n    # print('perms')\n    # print(perms.shape)\n\n    # perms = np.unique(perms, axis=0)\n    # perms = complete_sym_group(perms, callback=callback)\n\n    # print('perms')\n    # print(perms.shape)\n    # print(sym_group_perms.shape)\n\n    return sym_group_perms\n\n    # NEW\n\n\ndef find_frags(r, z, lat_and_inv=None):\n\n    from ase import Atoms\n    from ase.geometry.analysis import Analysis\n    from scipy.sparse.csgraph import connected_components\n\n    print('Finding permutable non-bonded fragments... 
(assumes Ang!)')\n\n    lat = None\n    if lat_and_inv:\n        lat = lat_and_inv[0]\n\n    n_atoms = r.shape[0]\n    atoms = Atoms(\n        z, positions=r, cell=lat, pbc=lat is not None\n    )  # only use first molecule in dataset to find connected components (fix me later, maybe) # *0.529177249\n\n    adj = Analysis(atoms).adjacency_matrix[0]\n    _, labels = connected_components(csgraph=adj, directed=False, return_labels=True)\n\n    # frags = []\n    # for label in np.unique(labels):\n    #    frags.append(np.where(labels == label)[0])\n    frags = [np.where(labels == label)[0] for label in np.unique(labels)]\n    n_frags = len(frags)\n\n    if n_frags == n_atoms:\n        print(\n            'Skipping fragment symmetry search (something went wrong, e.g. length unit not in Angstroms, etc.)'\n        )\n        return None\n\n    print('| Found ' + str(n_frags) + ' disconnected fragments.')\n\n    return frags\n\n\ndef find_frag_perms(R, z, lat_and_inv=None, callback=None, max_processes=None):\n\n    from ase import Atoms\n    from ase.geometry.analysis import Analysis\n    from scipy.sparse.csgraph import connected_components\n\n    # TODO: positions must be in Angstrom for this to work!!\n\n    n_train, n_atoms = R.shape[:2]\n    lat, lat_inv = lat_and_inv\n\n    atoms = Atoms(\n        z, positions=R[0], cell=lat, pbc=lat is not None\n    )  # only use first molecule in dataset to find connected components (fix me later, maybe) # *0.529177249\n\n    adj = Analysis(atoms).adjacency_matrix[0]\n    _, labels = connected_components(csgraph=adj, directed=False, return_labels=True)\n\n    # frags = []\n    # for label in np.unique(labels):\n    #    frags.append(np.where(labels == label)[0])\n    frags = [np.where(labels == label)[0] for label in np.unique(labels)]\n    n_frags = len(frags)\n\n    if n_frags == n_atoms:\n        print(\n            'Skipping fragment symmetry search (something went wrong, e.g. 
length unit not in Angstroms, etc.)'\n        )\n        return [range(n_atoms)]\n\n    # print(labels)\n\n    # from . import ui, io\n    # xyz_str = io.generate_xyz_str(R[0][np.where(labels == 0)[0], :]*0.529177249, z[np.where(labels == 0)[0]])\n    # xyz_str = ui.indent_str(xyz_str, 2)\n    # sprint(xyz_str)\n\n    # NEW\n\n    # uniq_labels = np.unique(labels)\n    # R_cg = np.empty((R.shape[0], len(uniq_labels), R.shape[2]))\n    # z_frags = []\n    # z_cg = []\n    # for label in uniq_labels:\n    #     frag_idxs = np.where(labels == label)[0]\n\n    #     R_cg[:,label,:] = np.mean(R[:,frag_idxs,:], axis=1)\n    #     z_frag = np.sort(z[frag_idxs])\n\n    #     z_frag_label = 0\n    #     if len(z_frags) == 0:\n    #         z_frags.append(z_frag)\n    #     else:\n    #         z_frag_label = np.where(np.all(z_frags == z_frag, axis=1))[0]\n\n    #         if len(z_frag_label) == 0: # not found\n    #             z_frag_label = len(z_frags)\n    #             z_frags.append(z_frag)\n    #         else:\n    #             z_frag_label = z_frag_label[0]\n\n    #     z_cg.append(z_frag_label)\n\n    # print(z_cg)\n    # print(R_cg.shape)\n\n    # perms = find_perms(R_cg, np.array(z_cg), lat_and_inv=lat_and_inv, max_processes=max_processes)\n\n    # print('cg perms')\n    # print(perms)\n\n    # NEW\n\n    # print(n_frags)\n\n    print('| Found ' + str(n_frags) + ' disconnected fragments.')\n\n    # ufrags = np.unique([np.sort(z[frag]) for frag in frags])\n    # print(ufrags)\n\n    # sys.exit()\n\n    # n_frags_unique = 0 # number of unique fragments\n\n    # match fragments to find identical ones (allows permutations of fragments)\n    swap_perms = [np.arange(n_atoms)]\n    for f1 in range(n_frags):\n        for f2 in range(f1 + 1, n_frags):\n\n            sort_idx_f1 = np.argsort(z[frags[f1]])\n            sort_idx_f2 = np.argsort(z[frags[f2]])\n            inv_sort_idx_f2 = inv_perm(sort_idx_f2)\n\n            z1 = z[frags[f1]][sort_idx_f1]\n            z2 = 
z[frags[f2]][sort_idx_f2]\n\n            if np.array_equal(z1, z2):  # fragments have the same composition\n\n                for ri in range(\n                    min(10, R.shape[0])\n                ):  # only use the first 10 geometries (at most) in dataset for matching (fix me later)\n\n                    R_match1 = R[ri, frags[f1], :]\n                    R_match2 = R[ri, frags[f2], :]\n\n                    # if np.array_equal(z1, z2):\n\n                    R_pair = np.concatenate(\n                        (R_match1[None, sort_idx_f1, :], R_match2[None, sort_idx_f2, :])\n                    )\n\n                    perms = find_perms(\n                        R_pair, z1, lat_and_inv=lat_and_inv, max_processes=max_processes\n                    )\n\n                    # embed local permutation into global context\n                    for p in perms:\n\n                        match_perm = sort_idx_f1[p][inv_sort_idx_f2]\n\n                        swap_perm = np.arange(n_atoms)\n                        swap_perm[frags[f1]] = frags[f2][match_perm]\n                        swap_perm[frags[f2][match_perm]] = frags[f1]\n                        swap_perms.append(swap_perm)\n\n            # else:\n            #    n_frags_unique += 1\n\n    swap_perms = np.unique(np.array(swap_perms), axis=0)\n\n    # print(swap_perms)\n\n    # print('| Found ' + str(n_frags_unique) + ' (likely to be) *unique* disconnected fragments.')\n\n    # complete symmetric group\n    sym_group_perms = complete_sym_group(swap_perms)\n    print(\n        '| Found '\n        + str(sym_group_perms.shape[0])\n        + ' fragment permutations after closure.'\n    )\n\n    # return sym_group_perms\n\n    # match fragments with themselves (to find symmetries in each fragment)\n\n    def _frag_perm_to_perm(n_atoms, frag_idxs, frag_perms):\n\n        # frag_idxs - indices of the fragment (one fragment!)\n        # frag_perms - N fragment permutations (Nxn_atoms)\n\n        perms = np.arange(n_atoms)[None, :]\n        
for fp in frag_perms:\n\n            p = np.arange(n_atoms)\n            p[frag_idxs] = frag_idxs[fp]\n            perms = np.vstack((p[None, :], perms))\n\n        return perms\n\n    if n_frags > 1:\n        print('| Finding symmetries in individual fragments.')\n        for f in range(n_frags):\n\n            R_frag = R[:, frags[f], :]\n            z_frag = z[frags[f]]\n\n            frag_perms = find_perms(\n                R_frag, z_frag, lat_and_inv=lat_and_inv, max_processes=max_processes\n            )\n\n            perms = _frag_perm_to_perm(n_atoms, frags[f], frag_perms)\n            sym_group_perms = np.vstack((perms, sym_group_perms))\n\n            print('{:d} perms'.format(perms.shape[0]))\n\n        sym_group_perms = np.unique(sym_group_perms, axis=0)\n    sym_group_perms = complete_sym_group(sym_group_perms, callback=callback)\n\n    return sym_group_perms\n\n    # f = 0\n    # perms = find_perms_via_alignment(R[0, :, :], frags[f], [215, 214, 210, 211], [209, 208, 212, 213], z, lat_and_inv=lat_and_inv, max_processes=max_processes)\n    # #perms = find_perms_via_alignment(R[0, :, :], frags[f], [214, 215, 210, 211], [209, 208, 212, 213], z, lat_and_inv=lat_and_inv, max_processes=max_processes)\n    # sym_group_perms = np.vstack((perms[None,:], sym_group_perms))\n    # sym_group_perms = complete_sym_group(sym_group_perms, callback=callback)\n\n    # #print(sym_group_perms.shape)\n\n    # #import sys\n    # #sys.exit()\n\n    # return sym_group_perms\n\n\ndef _frag_perm_to_perm(n_atoms, frag_idxs, frag_perms):\n\n    # frag_idxs - indices of the fragment (one fragment!)\n    # frag_perms - N fragment permutations (Nxn_atoms)\n\n    perms = np.arange(n_atoms)[None, :]\n    for fp in frag_perms:\n\n        p = np.arange(n_atoms)\n        p[frag_idxs] = frag_idxs[fp]\n        perms = np.vstack((p[None, :], perms))\n\n    return perms\n\n\ndef find_perms_in_frag(R, z, frag_idxs, lat_and_inv=None, max_processes=None):\n\n    n_atoms = R.shape[1]\n\n    
R_frag = R[:, frag_idxs, :]\n    z_frag = z[frag_idxs]\n\n    frag_perms = find_perms(\n        R_frag, z_frag, lat_and_inv=lat_and_inv, max_processes=max_processes\n    )\n\n    perms = _frag_perm_to_perm(n_atoms, frag_idxs, frag_perms)\n\n    return perms\n\n\ndef find_perms_via_alignment(\n    pts_full,\n    frag_idxs,\n    align_a_idxs,\n    align_b_idxs,\n    z,\n    lat_and_inv=None,\n    max_processes=None,\n):\n\n    # 1. find rotation that aligns points (Nx3 matrix) in 'align_a_idxs' with points in 'align_b_idxs'\n    # 2. rotate the whole thing\n    # 3. find perms by matching those two structures (match atoms that are closest after transformation)\n\n    # align_a_ctr = np.mean(align_a_pts, axis=0)\n    # align_b_ctr = np.mean(align_b_pts, axis=0)\n\n    # alignment indices are included in fragment\n    assert np.isin(align_a_idxs, frag_idxs).all()\n    assert np.isin(align_b_idxs, frag_idxs).all()\n\n    assert len(align_a_idxs) == len(align_b_idxs)\n\n    # align_a_frag_idxs = np.where(np.in1d(frag_idxs, align_a_idxs))[0]\n    # align_b_frag_idxs = np.where(np.in1d(frag_idxs, align_b_idxs))[0]\n\n    pts = pts_full[frag_idxs, :]\n\n    align_a_pts = pts_full[align_a_idxs, :]\n    align_b_pts = pts_full[align_b_idxs, :]\n\n    ctr = np.mean(pts, axis=0)\n    align_a_pts -= ctr\n    align_b_pts -= ctr\n\n    ab_cov = align_a_pts.T.dot(align_b_pts)\n    u, s, vh = np.linalg.svd(ab_cov)\n    R = u.dot(vh)\n\n    if np.linalg.det(R) < 0:\n        vh[2, :] *= -1  # multiply 3rd column of V by -1\n        R = u.dot(vh)\n\n    pts -= ctr\n    pts_R = pts.copy()\n\n    pts_R = R.dot(pts_R.T).T\n\n    pts += ctr\n    pts_R += ctr\n\n    pts_full_R = pts_full.copy()\n    pts_full_R[frag_idxs, :] = pts_R\n\n    R_pair = np.vstack((pts_full[None, :, :], pts_full_R[None, :, :]))\n\n    # from . 
import io\n\n    # xyz_str = io.generate_xyz_str(pts_full, z)\n    # print(xyz_str)\n\n    # xyz_str = io.generate_xyz_str(pts_full_R, z)\n    # print(xyz_str)\n\n    # z_frag = z[frag_idxs]\n\n    adj = scipy.spatial.distance.cdist(R_pair[0], R_pair[1], 'euclidean')\n    _, perm = scipy.optimize.linear_sum_assignment(adj)\n\n    # score_before = np.linalg.norm(adj)\n\n    # adj_perm = scipy.spatial.distance.cdist(R_pair[0,:], R_pair[0, perm], 'euclidean')\n    # score = np.linalg.norm(adj_perm)\n\n    # print(score_before)\n    # print(score)\n\n    # print('---')\n\n    # print('data \\'model example\\'', '|', end='')\n    # rint('testing', '|', end='')\n    # n_atoms = pts_full.shape[1]\n    # print(n_atoms)\n\n    # for p in pts_full[:,:]:\n    #    print('H {:.5f} {:.5f} {:.5f}'.format(*p), '|', end='')\n\n    # print('end \\'model example\\';show data')\n\n    # draw selection\n    if False:\n\n        print('---')\n\n        from matplotlib import cm\n\n        viridis = cm.get_cmap('prism')\n        colors = viridis(np.linspace(0, 1, len(align_a_idxs)))\n\n        for i, idx in enumerate(align_a_idxs):\n            color_str = (\n                '['\n                + str(int(colors[i, 0] * 255))\n                + ','\n                + str(int(colors[i, 1] * 255))\n                + ','\n                + str(int(colors[i, 2] * 255))\n                + ']'\n            )\n            print('select atomno=' + str(idx + 1) + '; color ' + color_str)\n\n        for i, idx in enumerate(align_b_idxs):\n            color_str = (\n                '['\n                + str(int(colors[i, 0] * 255))\n                + ','\n                + str(int(colors[i, 1] * 255))\n                + ','\n                + str(int(colors[i, 2] * 255))\n                + ']'\n            )\n            print('select atomno=' + str(idx + 1) + '; color ' + color_str)\n        print('---')\n\n    return perm\n\n\ndef find_perms_via_reflection(\n    r, z, frag_idxs, plane_3idxs, 
lat_and_inv=None, max_processes=None\n):\n\n    # plane_3idxs can be tuples of atoms (to take their center) or atom indices\n\n    # pts = pts_full[frag_idxs, :]\n    # pts = r.copy()\n\n    # compute normal of plane defined by atoms in 'plane_idxs'\n\n    is_plane_defined_by_bond_centers = type(plane_3idxs[0]) is tuple\n    if is_plane_defined_by_bond_centers:\n        a = (r[plane_3idxs[0][0], :] + r[plane_3idxs[0][1], :]) / 2\n        b = (r[plane_3idxs[1][0], :] + r[plane_3idxs[1][1], :]) / 2\n        c = (r[plane_3idxs[2][0], :] + r[plane_3idxs[2][1], :]) / 2\n    else:\n        a = r[plane_3idxs[0], :]\n        b = r[plane_3idxs[1], :]\n        c = r[plane_3idxs[2], :]\n\n    ab = b - a\n    ab /= np.linalg.norm(ab)\n\n    ac = c - a\n    ac /= np.linalg.norm(ac)\n\n    normal = np.cross(ab, ac)[:, None]\n\n    # compute reflection matrix\n    reflection = np.eye(3) - 2 * normal.dot(normal.T)\n\n    r_R = r.copy()\n    r_R[frag_idxs, :] = reflection.dot(r[frag_idxs, :].T).T\n\n    # R_pair = np.vstack((r[None,:,:], r_R[None,:,:]))\n\n    adj = scipy.spatial.distance.cdist(r, r_R, 'euclidean')\n    _, perm = scipy.optimize.linear_sum_assignment(adj)\n\n    print_perm_colors(perm, r, plane_3idxs)\n\n    # score_before = np.linalg.norm(adj)\n\n    # adj_perm = scipy.spatial.distance.cdist(R_pair[0,:], R_pair[0, perm], 'euclidean')\n    # score = np.linalg.norm(adj_perm)\n\n    return perm\n\n\ndef print_perm_colors(perm, pts, plane_3idxs=None):\n\n    idx_done = []\n    c = -1\n    for i in range(perm.shape[0]):\n        if i not in idx_done and perm[i] not in idx_done:\n            c += 1\n            idx_done += [i]\n            idx_done += [perm[i]]\n\n    from matplotlib import cm\n\n    viridis = cm.get_cmap('prism')\n    colors = viridis(np.linspace(0, 1, c + 1))\n\n    print('---')\n    print('select all; color [255,255,255]')\n\n    if plane_3idxs is not None:\n\n        def pts_str(x):\n            return '{' + str(x[0]) + ', ' + str(x[1]) + ', ' + 
str(x[2]) + '}'\n\n        is_plane_defined_by_bond_centers = type(plane_3idxs[0]) is tuple\n        if is_plane_defined_by_bond_centers:\n            a = (pts[plane_3idxs[0][0], :] + pts[plane_3idxs[0][1], :]) / 2\n            b = (pts[plane_3idxs[1][0], :] + pts[plane_3idxs[1][1], :]) / 2\n            c = (pts[plane_3idxs[2][0], :] + pts[plane_3idxs[2][1], :]) / 2\n        else:\n            a = pts[plane_3idxs[0], :]\n            b = pts[plane_3idxs[1], :]\n            c = pts[plane_3idxs[2], :]\n\n        print(\n            'draw plane1 300 PLANE '\n            + pts_str(a)\n            + ' '\n            + pts_str(b)\n            + ' '\n            + pts_str(c)\n            + ';color $plane1 green'\n        )\n\n    idx_done = []\n    c = -1\n    for i in range(perm.shape[0]):\n        if i not in idx_done and perm[i] not in idx_done:\n\n            c += 1\n            color_str = (\n                '['\n                + str(int(colors[c, 0] * 255))\n                + ','\n                + str(int(colors[c, 1] * 255))\n                + ','\n                + str(int(colors[c, 2] * 255))\n                + ']'\n            )\n\n            if i != perm[i]:\n                print('select atomno=' + str(i + 1) + '; color ' + color_str)\n                print('select atomno=' + str(perm[i] + 1) + '; color ' + color_str)\n            idx_done += [i]\n            idx_done += [perm[i]]\n\n    print('---')\n\n\ndef inv_perm(perm):\n\n    inv_perm = np.empty(perm.size, perm.dtype)\n    inv_perm[perm] = np.arange(perm.T.size)\n\n    return inv_perm\n"
  },
  {
    "path": "sgdml/utils/ui.py",
    "content": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2021 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person obtaining a copy\n# of this software and associated documentation files (the \"Software\"), to deal\n# in the Software without restriction, including without limitation the rights\n# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n# copies of the Software, and to permit persons to whom the Software is\n# furnished to do so, subject to the following conditions:\n#\n# The above copyright notice and this permission notice shall be included in all\n# copies or substantial portions of the Software.\n#\n# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n# SOFTWARE.\n\nfrom __future__ import print_function\nfrom functools import partial\n\nfrom .. 
import __version__, MAX_PRINT_WIDTH, LOG_LEVELNAME_WIDTH\nimport textwrap\nimport re\nimport sys\n\nif sys.version[0] == '3':\n    raw_input = input\n\nimport numpy as np\n\n\ndef yes_or_no(question):\n    \"\"\"\n    Ask for yes/no user input on a question.\n\n    Any response besides ``y`` yields a negative answer.\n\n    Parameters\n    ----------\n        question : :obj:`str`\n            User question.\n    \"\"\"\n\n    reply = raw_input(question + ' (y/n): ').lower().strip()\n    if not reply or reply[0] != 'y':\n        return False\n    else:\n        return True\n\n\nlast_callback_pct = 0\n\n\ndef callback(\n    current,\n    total=1,\n    disp_str='',\n    sec_disp_str=None,\n    done_with_warning=False,\n    newline_when_done=True,\n):\n    \"\"\"\n    Print progress or toggle bar.\n\n    Example (progress):\n    ``[ 45%] Task description (secondary string)``\n\n    Example (toggle, not done):\n    ``[ .. ] Task description (secondary string)``\n\n    Example (toggle, done):\n    ``[DONE] Task description (secondary string)``\n\n    Parameters\n    ----------\n        current : int\n            How many items already processed?\n        total : int, optional\n            Total number of items? 
If there is only\n            one item, the toggle style is used.\n        disp_str : :obj:`str`, optional\n            Task description.\n        sec_disp_str : :obj:`str`, optional\n            Additional string shown in gray.\n        done_with_warning : bool, optional\n            Indicate that the process did not\n            finish successfully.\n        newline_when_done : bool, optional\n            Finish with a newline character once\n            current=total (default: True)?\n    \"\"\"\n\n    global last_callback_pct\n\n    is_toggle = total == 1\n    is_done = np.isclose(current - total, 0.0)\n\n    bold_color_str = partial(color_str, bold=True)\n\n    if is_toggle:\n\n        if is_done:\n            if done_with_warning:\n                flag_str = bold_color_str('[WARN]', fore_color=YELLOW)\n            else:\n                flag_str = bold_color_str('[DONE]', fore_color=GREEN)\n\n        else:\n            flag_str = bold_color_str('[' + blink_str(' .. ') + ']')\n    else:\n\n        # Only show progress in 10 percent steps when not printing to terminal.\n        pct = int(float(current) * 100 / total)\n        pct = int(np.ceil(pct / 10.0)) * 10 if not sys.stdout.isatty() else pct\n\n        # Do not print, if there is no need to.\n        if not is_done and pct == last_callback_pct:\n            return\n        else:\n            last_callback_pct = pct\n\n        flag_str = bold_color_str(\n            '[{:3d}%]'.format(pct), fore_color=GREEN if is_done else WHITE\n        )\n\n    sys.stdout.write('\\r{} {}'.format(flag_str, disp_str))\n\n    if sec_disp_str is not None:\n        w = MAX_PRINT_WIDTH - LOG_LEVELNAME_WIDTH - len(disp_str) - 1\n        # sys.stdout.write(' \\x1b[90m{0: >{width}}\\x1b[0m'.format(sec_disp_str, width=w))\n        sys.stdout.write(\n            color_str(' {:>{width}}'.format(sec_disp_str, width=w), fore_color=GRAY)\n        )\n\n    if is_done and newline_when_done:\n        sys.stdout.write('\\n')\n\n    
sys.stdout.flush()\n\n\n# use this to integrate a callback for a subtask with an existing callback function\n# 'subtask_callback = partial(ui.sec_callback, main_callback=self.callback)'\ndef sec_callback(\n    current, total=1, disp_str=None, sec_disp_str=None, main_callback=None, **kwargs\n):\n    global last_callback_pct\n\n    assert main_callback is not None\n\n    is_toggle = total == 1\n    is_done = np.isclose(current - total, 0.0)\n\n    sec_disp_str = disp_str\n    if is_toggle:\n        sec_disp_str = '{} | {}'.format(disp_str, 'DONE' if is_done else ' .. ')\n    else:\n\n        # Only show progress in 10 percent steps when not printing to terminal.\n        pct = int(float(current) * 100 / total)\n        pct = int(np.ceil(pct / 10.0)) * 10 if not sys.stdout.isatty() else pct\n\n        # Do not print, if there is no need to.\n        if pct == last_callback_pct:\n            return\n\n        last_callback_pct = pct\n        sec_disp_str = '{} | {:3d}%'.format(disp_str, pct)\n\n    main_callback(0, sec_disp_str=sec_disp_str, **kwargs)\n\n\n# COLORS\n\nBLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, GRAY = list(range(8)) + [60]\nCOLOR_SEQ, RESET_SEQ = '\\033[{:d};{:d};{:d}m', '\\033[0m'\n\nENABLE_COLORED_OUTPUT = (\n    sys.stdout.isatty()\n)  # Running in a real terminal or piped/redirected?\n\n\ndef color_str(str, fore_color=WHITE, back_color=BLACK, bold=False):\n\n    if ENABLE_COLORED_OUTPUT:\n\n        # foreground is set with 30 plus the number of the color, background with 40\n        return (\n            COLOR_SEQ.format(1 if bold else 0, 30 + fore_color, 40 + back_color)\n            + str\n            + RESET_SEQ\n        )\n    else:\n        return str\n\n\ndef blink_str(str):\n\n    return '\\x1b[5m' + str + '\\x1b[0m' if ENABLE_COLORED_OUTPUT else str\n\n\ndef unicode_str(s):\n\n    if sys.version[0] == '3':\n        s = str(s, 'utf-8', 'ignore')\n    else:\n        s = str(s)\n\n    return s.rstrip('\\x00')  # remove 
null-characters\n\n\ndef gen_memory_str(bytes):\n\n    pwr = 1024\n    n = 0\n    pwr_strs = {0: '', 1: 'K', 2: 'M', 3: 'G', 4: 'T'}\n    while bytes > pwr and n < 4:\n        bytes /= pwr\n        n += 1\n\n    return '{:.{num_dec_pts}f} {}B'.format(\n        bytes, pwr_strs[n], num_dec_pts=max(0, n - 2)\n    )  # 1 decimal point for GB, 2 for TB\n\n\ndef gen_lattice_str(lat):\n\n    lat_str, col_widths = gen_mat_str(lat)\n    desc_str = (' '.join([('{:' + str(w) + '}') for w in col_widths])).format(\n        'a', 'b', 'c'\n    ) + '\\n'\n\n    lat_str = indent_str(lat_str, 21)\n\n    return desc_str + lat_str\n\n\ndef str_plen(str):\n    \"\"\"\n    Returns printable length of string. This function can only account for invisible characters due to string styling with ``color_str``.\n\n    Parameters\n    ----------\n        str : :obj:`str`\n            String.\n\n    Returns\n    -------\n        int\n            Printable length of the string.\n\n    \"\"\"\n\n    num_colored_subs = str.count(RESET_SEQ)\n    return len(str) - (\n        14 * num_colored_subs\n    )  # 14: length of invisible characters per colored segment\n\n\ndef wrap_str(str, width=MAX_PRINT_WIDTH - LOG_LEVELNAME_WIDTH):\n    \"\"\"\n    Wrap multiline string after a given number of characters. 
The default maximum line already accounts for the indentation due to the logging level label.\n\n    Parameters\n    ----------\n        str : :obj:`str`\n            Multiline string.\n        width : int, optional\n            Max number of characters in a line.\n\n    Returns\n    -------\n        :obj:`str`\n\n    \"\"\"\n\n    return '\\n'.join(\n        [\n            '\\n'.join(\n                textwrap.wrap(\n                    line,\n                    width + (len(line) - str_plen(line)),\n                    break_long_words=False,\n                    replace_whitespace=False,\n                )\n            )\n            for line in str.splitlines()\n        ]\n    )\n\n\ndef indent_str(str, indent):\n    \"\"\"\n    Indents all lines of a multiline string right by a given number of\n    characters.\n\n    Parameters\n    ----------\n        str : :obj:`str`\n            Multiline string.\n        indent : int\n            Number of characters added in front of each line.\n\n    Returns\n    -------\n        :obj:`str`\n\n    \"\"\"\n\n    return re.sub('^', ' ' * indent, str, flags=re.MULTILINE)\n\n\ndef wrap_indent_str(label, str, width=MAX_PRINT_WIDTH - LOG_LEVELNAME_WIDTH):\n    \"\"\"\n    Wraps and indents a multiline string to arrange it with the provided label in two columns. 
The default maximum line already accounts for the indentation due to the logging level label.\n\n    Example:\n    ``<label><multiline string>``\n\n    Parameters\n    ----------\n        label : :obj:`str`\n            Label.\n        str : :obj:`str`\n            Multiline string.\n        width : int, optional\n            Max number of characters in a line.\n\n    Returns\n    -------\n        :obj:`str`\n\n    \"\"\"\n\n    label_len = str_plen(label)\n\n    str = wrap_str(str, width - label_len)\n    str = indent_str(str, label_len)\n\n    return label + str[label_len:]\n\n\ndef merge_col_str(\n    col_str1, col_str2\n):  # merge two multiline strings that represent columns in a table\n    \"\"\"\n    Merges two multiline strings that represent columns in a table by\n    concatenating each pair of lines.\n\n    Note\n    ----\n        Both strings must have the same number of lines.\n\n    Parameters\n    ----------\n        col_str1 : :obj:`str`\n            First multiline string.\n        col_str2 : :obj:`str`\n            Second multiline string.\n\n    Returns\n    -------\n        :obj:`str`\n\n    \"\"\"\n\n    return '\\n'.join(\n        [\n            ' '.join([c1, c2])\n            for c1, c2 in zip(col_str1.split('\\n'), col_str2.split('\\n'))\n        ]\n    )\n\n\ndef gen_mat_str(mat):\n    \"\"\"\n    Converts a matrix to a multiline string such that the decimal points\n    align in each column. 
Trailing zeros are replaced with spaces.\n\n    Parameters\n    ----------\n        mat : :obj:`numpy.ndarray`\n\n    Returns\n    -------\n        :obj:`str`\n            String representation of matrix.\n\n    \"\"\"\n\n    def _int_len(\n        x,\n    ):  # length of string representation before decimal point (including sign)\n        return len(str(int(abs(x)))) + (0 if x >= 0 else 1)\n\n    def _dec_len(x):  # length of string representation after decimal point\n\n        x_str_split = '{:g}'.format(x).split('.')\n        return len(x_str_split[1]) if len(x_str_split) > 1 else 0\n\n    def _max_int_len_for_col(\n        mat, col\n    ):  # length of string representation before decimal point for each col\n        col_min = np.min(mat[:, col])\n        col_max = np.max(mat[:, col])\n        return max(_int_len(col_min), _int_len(col_max))\n\n    def _max_dec_len_for_col(\n        mat, col\n    ):  # length of string representation after decimal point for each col\n        return max([_dec_len(cell) for cell in mat[:, col]])\n\n    n_cols = mat.shape[1]\n    col_int_widths = [_max_int_len_for_col(mat, i) for i in range(n_cols)]\n    col_dec_widths = [_max_dec_len_for_col(mat, i) for i in range(n_cols)]\n    col_widths = [iw + cd + 1 for iw, cd in zip(col_int_widths, col_dec_widths)]\n\n    mat_str = ''\n    for row in mat:\n        if mat_str != '':\n            mat_str += '\\n'\n        mat_str += ' '.join(\n            ' ' * max(col_int_widths[j] - _int_len(x), 0)\n            + ('{: <' + str(_int_len(x) + col_dec_widths[j] + 1) + 'g}').format(x)\n            for j, x in enumerate(row)\n        )\n\n    return mat_str, col_widths\n\n\ndef gen_range_str(min, max):\n    \"\"\"\n    Generates a string that shows a minimum and maximum value, as well as the range.\n\n    Example:\n    ``<min> |-- <range> --| <max>``\n\n    Parameters\n    ----------\n        min : float\n            Minimum value.\n        max : float\n            Maximum value.\n\n    Returns\n  
  -------\n        :obj:`str`\n\n    \"\"\"\n\n    return '{:<.3f} |-- {:^8.3f} --| {:<9.3f}'.format(min, max - min, max)\n\n\ndef print_step_title(title_str, sec_title_str='', underscore=True):\n\n    if sec_title_str != '':\n        sec_title_str = ' ' + sec_title_str\n\n    underscore_str = '\\n' + '-' * MAX_PRINT_WIDTH if underscore else ''\n\n    print(\n        '\\n'\n        + color_str(\n            ' ' + title_str + ' ', fore_color=BLACK, back_color=WHITE, bold=True\n        )\n        + sec_title_str\n        + underscore_str\n    )\n\n\ndef print_two_column_str(str, sec_str=''):\n\n    sec_str = color_str(\n        '{:>{width}}'.format(sec_str, width=MAX_PRINT_WIDTH - str_plen(str) - 1),\n        fore_color=GRAY,\n    )\n    print('{} {}'.format(str, sec_str))\n\n    # print(\n    #     '{} \\x1b[90m{:>{width}}\\x1b[0m'.format(\n    #         str, sec_str, width=MAX_PRINT_WIDTH - str_plen(str) - 1\n    #     )\n    # )\n\n\ndef print_lattice(lat=None, inset=False):\n\n    from . import io\n\n    lat_str = 'n/a'\n    if lat is not None:\n        lat_str = gen_lattice_str(lat)\n        lengths, angles = io.lattice_vec_to_par(lat)\n\n    if inset:\n        print('    {:<16} {}'.format('Lattice:', lat_str))\n    else:\n        print('  {:<18} {}'.format('Lattice:', lat_str))\n    if lat is not None:\n        print('    {:<16} a = {:g}, b = {:g}, c = {:g}'.format('Lengths:', *lengths))\n        print(\n            '    {:<16} alpha = {:g}, beta = {:g}, gamma = {:g}'.format(\n                'Angles [deg]:', *angles\n            )\n        )\n"
  }
]