Repository: stefanch/sGDML
Branch: master
Commit: a6ae5e86f88c
Files: 28
Total size: 380.5 KB

Directory structure:
gitextract_7dj7aa45/

├── .gitignore
├── LICENSE.txt
├── README.md
├── pyproject.toml
├── scripts/
│   ├── sgdml_dataset_from_aims.py
│   ├── sgdml_dataset_from_extxyz.py
│   ├── sgdml_dataset_from_ipi.py
│   ├── sgdml_dataset_to_extxyz.py
│   ├── sgdml_dataset_via_ase.py
│   └── sgdml_datasets_from_model.py
├── setup.cfg
├── setup.py
└── sgdml/
    ├── __init__.py
    ├── cli.py
    ├── get.py
    ├── intf/
    │   ├── __init__.py
    │   └── ase_calc.py
    ├── predict.py
    ├── solvers/
    │   ├── __init__.py
    │   ├── analytic.py
    │   └── iterative.py
    ├── torchtools.py
    ├── train.py
    └── utils/
        ├── __init__.py
        ├── desc.py
        ├── io.py
        ├── perm.py
        └── ui.py

================================================
FILE CONTENTS
================================================

================================================
FILE: .gitignore
================================================

.DS_Store

# Compiled python modules.
*.pyc

# Setuptools distribution folder.
/dist/

# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info
/*.egg
sgdml/_bmark_cache.npz


================================================
FILE: LICENSE.txt
================================================
MIT License

Copyright (c) 2018-2022 Stefan Chmiela

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

================================================
FILE: README.md
================================================
# Symmetric Gradient Domain Machine Learning (sGDML)

For more details visit: [sgdml.org](http://sgdml.org/)  
Documentation can be found here: [docs.sgdml.org](http://docs.sgdml.org/)

#### Requirements:
- Python 3.7+
- PyTorch (>=1.8)
- NumPy (>=1.19)
- SciPy (>=1.1)

#### Optional:
- ASE (>=3.16.2) (to run atomistic simulations)

## Getting started

### Stable release

Most systems come with the default package manager for Python ``pip`` already preinstalled. Install ``sgdml`` by simply calling:

```
$ pip install sgdml
```

The ``sgdml`` command-line interface and the corresponding Python API can now be used from anywhere on the system.

### Development version

#### (1) Clone the repository

```
$ git clone https://github.com/stefanch/sGDML.git
$ cd sGDML
```

...or update your existing local copy with

```
$ git pull origin master
```

#### (2) Install

```
$ pip install -e .
```

Using the flag ``--user``, you can tell ``pip`` to install the package to the current user's home directory instead of system-wide. This option might require you to update your system's ``PATH`` variable accordingly.


### Optional dependencies

Some functionality of this package relies on third-party libraries that are not installed by default. These optional dependencies (or "package extras") are specified during installation using the "square bracket syntax":

```
$ pip install sgdml[<optional1>]
```

#### Atomic Simulation Environment (ASE)

If you are interested in interfacing with [ASE](https://wiki.fysik.dtu.dk/ase/) to perform atomistic simulations (see [here](http://docs.sgdml.org/applications.html) for examples), use the ``ase`` keyword:

```
$ pip install sgdml[ase]
```

## Reconstruct your first force field

Download one of the example datasets:

```
$ sgdml-get dataset ethanol_dft
```
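Dataset files such as `ethanol_dft.npz` are NumPy `.npz` archives. The conversion scripts in `scripts/` store positions `R` with shape `(n_samples, n_atoms, 3)`, atomic numbers `z`, energies `E`, forces `F` with the same shape as `R`, and metadata such as `name`, `theory` and the unit strings. The sketch below builds a toy archive with that key layout in memory. The values are random placeholders, some scripts store `E` as a column of shape `(n_samples, 1)` instead, and the sketch omits the `md5` checksum that the scripts add via `io.dataset_md5`, so treat it as an illustration of the layout rather than a ready-to-train file.

```python
from io import BytesIO

import numpy as np

rng = np.random.default_rng(0)
n_samples, n_atoms = 4, 3  # toy sizes, not taken from a real dataset

# Key layout modeled on what the conversion scripts in scripts/ write;
# the numerical values are random placeholders.
dataset = {
    'type': 'd',  # marks the file as a dataset
    'name': 'toy',
    'theory': 'unknown',
    'R': rng.random((n_samples, n_atoms, 3)),  # positions
    'z': np.array([8, 1, 1]),  # atomic numbers (O, H, H)
    'E': rng.random(n_samples),  # one energy per sample
    'F': rng.random((n_samples, n_atoms, 3)),  # forces, same shape as R
    'r_unit': 'Ang',
    'e_unit': 'kcal/mol',
}

# Round-trip through an in-memory buffer instead of a file on disk.
buf = BytesIO()
np.savez_compressed(buf, **dataset)
buf.seek(0)
loaded = np.load(buf)

print(sorted(loaded.files))
print(loaded['R'].shape, loaded['F'].shape)
```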

Train a force field model using 200 training, 1000 validation and 5000 test points:

```
$ sgdml all ethanol_dft.npz 200 1000 5000
```
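The three counts passed to `sgdml all` (training, validation and test set sizes) are all drawn from the same dataset file. A minimal sketch of such a split, assuming the subsets must not overlap and are drawn uniformly at random; `split_indices` is a hypothetical helper for illustration, not part of the sGDML API, and sGDML's own sampler may differ (for example, by stratifying on energies).

```python
import random


def split_indices(n_points, n_train, n_valid, n_test, seed=0):
    """Draw disjoint train/validation/test index lists from range(n_points).

    Hypothetical helper for illustration only; not part of sGDML.
    """
    if n_train + n_valid + n_test > n_points:
        raise ValueError('dataset has too few points for the requested split')
    rng = random.Random(seed)
    # A single sample without replacement guarantees that no index is reused.
    idxs = rng.sample(range(n_points), n_train + n_valid + n_test)
    train = idxs[:n_train]
    valid = idxs[n_train:n_train + n_valid]
    test = idxs[n_train + n_valid:]
    return train, valid, test


# 10000 is a stand-in dataset size, not the size of ethanol_dft.npz.
train, valid, test = split_indices(10000, 200, 1000, 5000)
print(len(train), len(valid), len(test))
```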

## Query a force field

```python
import numpy as np
from sgdml.predict import GDMLPredict
from sgdml.utils import io

r, _ = io.read_xyz('geometries/ethanol.xyz')  # 9 atoms
print(r.shape)  # (1,27)

model = np.load('models/ethanol.npz')
gdml = GDMLPredict(model)
e, f = gdml.predict(r)
print(e.shape)  # (1,)
print(f.shape)  # (1,27)
```
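In the example above, each geometry passed to `GDMLPredict.predict` is one flattened row of 3·N Cartesian coordinates (27 for the 9 atoms of ethanol), and the forces come back in the same flat layout, while the dataset files written by the conversion scripts store positions and forces as `(n_samples, n_atoms, 3)`. A minimal NumPy sketch of converting between the two layouts, using placeholder arrays:

```python
import numpy as np

n_atoms = 9  # ethanol, as in the example above

# Positions in the (n_samples, n_atoms, 3) layout used by the dataset files.
R = np.zeros((2, n_atoms, 3))

# Flatten each geometry into one row ordered x1, y1, z1, x2, y2, z2, ...
r_flat = R.reshape(R.shape[0], -1)
print(r_flat.shape)

# Flat force rows can be reshaped back to one (x, y, z) triple per atom.
f_flat = np.arange(2 * 3 * n_atoms, dtype=float).reshape(2, -1)
f_per_atom = f_flat.reshape(-1, n_atoms, 3)
print(f_per_atom.shape)
```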

## Authors

* Stefan Chmiela
* Jan Hermann

We appreciate and welcome contributions and would like to thank the following people for participating in this project:

* Huziel Sauceda
* Igor Poltavsky
* Luis Gálvez
* Danny Panknin
* Grégory Fonseca
* Anton Charkin-Gorbulin

## References

* [1] Chmiela, S., Tkatchenko, A., Sauceda, H. E., Poltavsky, I., Schütt, K. T., Müller, K.-R.,
*Machine Learning of Accurate Energy-conserving Molecular Force Fields.*
Science Advances, 3(5), e1603015 (2017)   
[10.1126/sciadv.1603015](http://dx.doi.org/10.1126/sciadv.1603015)

* [2] Chmiela, S., Sauceda, H. E., Müller, K.-R., Tkatchenko, A.,
*Towards Exact Molecular Dynamics Simulations with Machine-Learned Force Fields.*
Nature Communications, 9(1), 3887 (2018)   
[10.1038/s41467-018-06169-2](https://doi.org/10.1038/s41467-018-06169-2)

* [3] Chmiela, S., Sauceda, H. E., Poltavsky, I., Müller, K.-R., Tkatchenko, A.,
*sGDML: Constructing Accurate and Data Efficient Molecular Force Fields Using Machine Learning.*
Computer Physics Communications, 240, 38-45 (2019)
[10.1016/j.cpc.2019.02.007](https://doi.org/10.1016/j.cpc.2019.02.007)

* [4] Chmiela, S., Vassilev-Galindo, V., Unke, O. T., Kabylda, A., Sauceda, H. E., Tkatchenko, A., Müller, K.-R.,
*Accurate Global Machine Learning Force Fields for Molecules With Hundreds of Atoms.*
Science Advances, 9(2), eadf0873 (2023)   
[10.1126/sciadv.adf0873](https://doi.org/10.1126/sciadv.adf0873)

================================================
FILE: pyproject.toml
================================================
[tool.black]
skip-string-normalization = true
skip-numeric-underscore-normalization = true


================================================
FILE: scripts/sgdml_dataset_from_aims.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2018-2022 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import print_function

import argparse
import os
import sys

import numpy as np

from sgdml.utils import io, ui


def read_reference_data(f):  # noqa C901
    eV_to_kcalmol = 0.036749326 / 0.0015946679

    e_next, f_next, geo_next = False, False, False
    n_atoms = None
    R, z, E, F = [], [], [], []

    geo_idx = 0
    for line in f:
        if n_atoms:
            cols = line.split()
            if e_next:
                E.append(float(cols[5]))
                e_next = False
            elif f_next:
                a = int(cols[1]) - 1
                F.append(list(map(float, cols[2:5])))
                if a == n_atoms - 1:
                    f_next = False
            elif geo_next:
                if 'atom' in cols:
                    a_count += 1  # noqa: F821
                    R.append(list(map(float, cols[1:4])))

                    if geo_idx == 0:
                        z.append(io._z_str_to_z_dict[cols[4]])

                    if a_count == n_atoms:
                        geo_next = False
                        geo_idx += 1
            elif 'Energy and forces in a compact form:' in line:
                e_next = True
            elif 'Total atomic forces (unitary forces cleaned) [eV/Ang]:' in line:
                f_next = True
            elif (
                'Atomic structure (and velocities) as used in the preceding time step:'
                in line
            ):
                geo_next = True
                a_count = 0
        elif 'The structure contains' in line and 'atoms,  and a total of' in line:
            n_atoms = int(line.split()[3])
            print('Number atoms per geometry:      {:>7d}'.format(n_atoms))
            continue

        if geo_idx > 0 and geo_idx % 1000 == 0:
            sys.stdout.write("\rNumber geometries found so far: {:>7d}".format(geo_idx))
            sys.stdout.flush()
    sys.stdout.write("\rNumber geometries found so far: {:>7d}".format(geo_idx))
    sys.stdout.flush()
    print(
        '\n'
        + ui.color_str('[INFO]', bold=True)
        + ' Energies and forces have been converted from eV to kcal/mol(/Ang)'
    )

    R = np.array(R).reshape(-1, n_atoms, 3)
    z = np.array(z)
    E = np.array(E) * eV_to_kcalmol
    F = np.array(F).reshape(-1, n_atoms, 3) * eV_to_kcalmol

    f.close()
    return (R, z, E, F)


parser = argparse.ArgumentParser(description='Creates a dataset from FHI-aims format.')
parser.add_argument(
    'dataset',
    metavar='<dataset>',
    type=argparse.FileType('r'),
    help='path to FHI-aims output file',
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing dataset file',
)
args = parser.parse_args()
dataset = args.dataset

name = os.path.splitext(os.path.basename(dataset.name))[0]
dataset_file_name = name + '.npz'

dataset_exists = os.path.isfile(dataset_file_name)
if dataset_exists and args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')
if not dataset_exists or args.overwrite:
    print('Writing dataset to \'%s\'...' % dataset_file_name)
else:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \'%s\' already exists.' % dataset_file_name
    )

R, z, E, F = read_reference_data(dataset)

# Prune all arrays to same length.
n_mols = min(min(R.shape[0], F.shape[0]), E.shape[0])
if n_mols != R.shape[0] or n_mols != F.shape[0] or n_mols != E.shape[0]:
    print(
        ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True)
        + ' Incomplete output detected: Final dataset was pruned to %d points.' % n_mols
    )
R = R[:n_mols, :, :]
F = F[:n_mols, :, :]
E = E[:n_mols]

# Base variables contained in every dataset file.
base_vars = {
    'type': 'd',
    'R': R,
    'z': z,
    'E': E[:, None],
    'F': F,
    'e_unit': 'kcal/mol',
    'r_unit': 'Ang',
    'name': name,
    'theory': 'unknown',
}

base_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel())
base_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel())

base_vars['E_min'], base_vars['E_max'] = np.min(E), np.max(E)
base_vars['E_mean'], base_vars['E_var'] = np.mean(E), np.var(E)

base_vars['md5'] = io.dataset_md5(base_vars)

np.savez_compressed(dataset_file_name, **base_vars)
print(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))


================================================
FILE: scripts/sgdml_dataset_from_extxyz.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2018-2022 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import print_function

import argparse
import os
import sys

try:
    from ase.io import read
except ImportError:
    raise ImportError('Optional ASE dependency not found! Please run \'pip install sgdml[ase]\' to install it.')

import numpy as np

from sgdml import __version__
from sgdml.utils import io, ui

if sys.version[0] == '3':
    raw_input = input


# Note: assumes that the atoms in each molecule are in the same order.
def read_nonstd_ext_xyz(f):
    n_atoms = None

    R, z, E, F = [], [], [], []
    for i, line in enumerate(f):
        line = line.strip()
        if not n_atoms:
            n_atoms = int(line)
            print('Number atoms per geometry: {:,}'.format(n_atoms))

        file_i, line_i = divmod(i, n_atoms + 2)

        if line_i == 1:
            try:
                e = float(line)
            except ValueError:
                pass
            else:
                E.append(e)

        cols = line.split()
        if line_i >= 2:
            R.append(list(map(float, cols[1:4])))
            if file_i == 0:  # first molecule
                z.append(io._z_str_to_z_dict[cols[0]])
            F.append(list(map(float, cols[4:7])))

        if file_i % 1000 == 0:
            sys.stdout.write('\rNumber geometries found so far: {:,}'.format(file_i))
            sys.stdout.flush()
    sys.stdout.write('\rNumber geometries found so far: {:,}'.format(file_i))
    sys.stdout.flush()
    print()

    R = np.array(R).reshape(-1, n_atoms, 3)
    z = np.array(z)
    E = None if not E else np.array(E)
    F = np.array(F).reshape(-1, n_atoms, 3)

    if F.shape[0] != R.shape[0]:
        sys.exit(
            ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
            + ' Force labels are missing from dataset or are incomplete!'
        )

    f.close()
    return (R, z, E, F)

# Parses the key=value pairs in the comment (info) line of each frame.
def extract_info_from_extxyz(file_path):
    infos = []

    with open(file_path) as f:
        lines = f.readlines()

    i = 0
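    while i < len(lines):
        # Skip blank lines (e.g. a trailing newline at the end of the file).
        if not lines[i].strip():
            i += 1
            continue

        try:
            n_atoms = int(lines[i])
        except ValueError:
            raise ValueError(f"Invalid atom count at line {i + 1}")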

        if i + 1 >= len(lines):
            break

        comment_line = lines[i + 1].strip()
        info = {}
        for token in comment_line.split():
            if "=" in token:
                key, val = token.split("=", 1)
                val = val.strip('"')
                try:
                    val = float(val)
                except ValueError:
                    pass
                info[key] = val
        infos.append(info)

        i += 2 + n_atoms

    return infos


parser = argparse.ArgumentParser(
    description='Creates a dataset from extended XYZ format.'
)
parser.add_argument(
    'dataset',
    metavar='<dataset>',
    type=argparse.FileType('r'),
    help='path to extended xyz dataset file',
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing dataset file',
)
args = parser.parse_args()
dataset = args.dataset


name = os.path.splitext(os.path.basename(dataset.name))[0]
dataset_file_name = name + '.npz'

dataset_exists = os.path.isfile(dataset_file_name)
if dataset_exists and args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')
if not dataset_exists or args.overwrite:
    print('Writing dataset to \'{}\'...'.format(dataset_file_name))
else:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
        + ' Dataset \'{}\' already exists.'.format(dataset_file_name)
    )

lattice, R, z, E, F = None, None, None, None, None

mols = read(dataset.name, format='extxyz', index=':')
#calc = mols[0].get_calculator() # deprecated
calc = mols[0].calc
is_extxyz = calc is not None
if is_extxyz:

    print("\rNumber geometries found: {:,}\n".format(len(mols)))

    if 'forces' not in calc.results:
        sys.exit(
            ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
            + ' Forces are missing in the input file!'
        )

    lattice = np.array(mols[0].get_cell().T)
    if not np.any(lattice): # all zeros
        print(
            ui.color_str('[INFO]', bold=True)
            + ' No lattice vectors specified in extended XYZ file.'
        )
        lattice = None

    Z = np.array([mol.get_atomic_numbers() for mol in mols])
    all_z_the_same = (Z == Z[0]).all()
    if not all_z_the_same:
        sys.exit(
            ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
            + ' Order of atoms changes across dataset.'
        )

    R = np.array([mol.get_positions() for mol in mols])
    z = Z[0]

    # ASE did not parse the info string; fall back to parsing it manually.
    if not mols[0].info:

        print(
            ui.color_str('[INFO]', bold=True)
            + ' ASE did not parse the info string. Attempting to parse it manually.'
        )

        infos = extract_info_from_extxyz(dataset.name)
        for mol, info in zip(mols, infos):
            mol.info.update(info)

    if 'Energy' in mols[0].info:
        E = np.array([mol.info['Energy'] for mol in mols])
    if 'energy' in mols[0].info:
        E = np.array([mol.info['energy'] for mol in mols])
    F = np.array([mol.get_forces() for mol in mols])

else:  # legacy non-standard XYZ format

    with open(dataset.name) as f:
        R, z, E, F = read_nonstd_ext_xyz(f)

# Base variables contained in every dataset file.
base_vars = {
    'type': 'd',
    'code_version': __version__,
    'name': name,
    'theory': 'unknown',
    'R': R,
    'z': z,
    'F': F,
}

base_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel())
base_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel())

print('Please provide a description of the length unit used in your input file, e.g. \'Ang\' or \'au\': ')
print('Note: This string will be stored in the dataset file and passed on to model files for later reference.')
r_unit = raw_input('> ').strip()
if r_unit != '':
    base_vars['r_unit'] = r_unit

print('Please provide a description of the energy unit used in your input file, e.g. \'kcal/mol\' or \'eV\': ')
print('Note: This string will be stored in the dataset file and passed on to model files for later reference.')
e_unit = raw_input('> ').strip()
if e_unit != '':
    base_vars['e_unit'] = e_unit

if E is not None:
    base_vars['E'] = E
    base_vars['E_min'], base_vars['E_max'] = np.min(E), np.max(E)
    base_vars['E_mean'], base_vars['E_var'] = np.mean(E), np.var(E)
else:
    print(ui.color_str('[INFO]', bold=True) + ' No energy labels found in dataset.')

if lattice is not None:
    base_vars['lattice'] = lattice

base_vars['md5'] = io.dataset_md5(base_vars)
np.savez_compressed(dataset_file_name, **base_vars)
print(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))


================================================
FILE: scripts/sgdml_dataset_from_ipi.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2018 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import print_function

import argparse
import os
import sys

import numpy as np

from sgdml.utils import io, ui


def raw_input_float(prompt):
    while True:
        try:
            return float(input(prompt))
        except ValueError:
            print(ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' That is not a valid float.')


# Assumes that the atoms in each molecule are in the same order.
def read_concat_xyz(f):
    n_atoms = None

    R, z = [], []
    for i, line in enumerate(f):
        line = line.strip()
        if not n_atoms:
            n_atoms = int(line)
            print('Number atoms per geometry:      {:>7d}'.format(n_atoms))

        file_i, line_i = divmod(i, n_atoms + 2)

        cols = line.split()
        if line_i >= 2:
            if file_i == 0:  # first molecule
                z.append(io._z_str_to_z_dict[cols[0]])
            R.append(list(map(float, cols[1:4])))

        if file_i % 1000 == 0:
            sys.stdout.write("\rNumber geometries found so far: {:>7d}".format(file_i))
            sys.stdout.flush()
    sys.stdout.write("\rNumber geometries found so far: {:>7d}\n".format(file_i))
    sys.stdout.flush()

    # Only keep complete entries.
    R = R[: int(n_atoms * np.floor(len(R) / float(n_atoms)))]

    R = np.array(R).reshape(-1, n_atoms, 3)
    z = np.array(z)

    f.close()
    return (R, z)


def read_out_file(f, col):

    E = []
    for i, line in enumerate(f):
        line = line.strip()
        if line and not line.startswith('#'):  # Ignore blank lines and comments.
            E.append(float(line.split()[col]))
        if i % 1000 == 0:
            sys.stdout.write("\rNumber lines processed so far:  {:>7d}".format(len(E)))
            sys.stdout.flush()
    sys.stdout.write("\rNumber lines processed so far:  {:>7d}\n".format(len(E)))
    sys.stdout.flush()

    return np.array(E)


parser = argparse.ArgumentParser(
    description='Creates a dataset from i-PI output files (geometries, forces and energies).'
)
parser.add_argument(
    'geometries',
    metavar='<geometries>',
    type=argparse.FileType('r'),
    help='path to XYZ geometry file',
)
parser.add_argument(
    'forces',
    metavar='<forces>',
    type=argparse.FileType('r'),
    help='path to XYZ force file',
)
parser.add_argument(
    'energies',
    metavar='<energies>',
    type=argparse.FileType('r'),
    help='path to energy output file (whitespace-separated columns)',
)
parser.add_argument(
    'energy_col',
    metavar='<energy_col>',
    type=lambda x: io.is_strict_pos_int(x),
    help='which column to parse from energy file (zero based)',
    nargs='?',
    default=0,
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing dataset file',
)
args = parser.parse_args()
geometries = args.geometries
forces = args.forces
energies = args.energies
energy_col = args.energy_col

name = os.path.splitext(os.path.basename(geometries.name))[0]
dataset_file_name = name + '.npz'

dataset_exists = os.path.isfile(dataset_file_name)
if dataset_exists and args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')
if not dataset_exists or args.overwrite:
    print('Writing dataset to \'%s\'...' % dataset_file_name)
else:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \'%s\' already exists.' % dataset_file_name
    )


print('Reading geometries...')
R, z = read_concat_xyz(geometries)

print('Reading forces...')
F, _ = read_concat_xyz(forces)

print('Reading energies from column %d...' % energy_col)
E = read_out_file(energies, energy_col)

# Prune all arrays to same length.
n_mols = min(min(R.shape[0], F.shape[0]), E.shape[0])
if n_mols != R.shape[0] or n_mols != F.shape[0] or n_mols != E.shape[0]:
    print(
        ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True)
        + ' Incomplete output detected: Final dataset was pruned to %d points.' % n_mols
    )
R = R[:n_mols, :, :]
F = F[:n_mols, :, :]
E = E[:n_mols]

print(
    ui.color_str('[INFO]', bold=True)
    + ' Geometries, forces and energies must have consistent units.'
)
R_conv_fact = raw_input_float('Unit conversion factor for geometries: ')
R = R * R_conv_fact
F_conv_fact = raw_input_float('Unit conversion factor for forces: ')
F = F * F_conv_fact
E_conv_fact = raw_input_float('Unit conversion factor for energies: ')
E = E * E_conv_fact

# Base variables contained in every dataset file.
base_vars = {
    'type': 'd',
    'R': R,
    'z': z,
    'E': E[:, None],
    'F': F,
    'name': name,
    'theory': 'unknown',
}
base_vars['md5'] = io.dataset_md5(base_vars)

np.savez_compressed(dataset_file_name, **base_vars)
print(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))


================================================
FILE: scripts/sgdml_dataset_to_extxyz.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2018-2019 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import print_function

import argparse
import os
import sys

import numpy as np

from sgdml.utils import io, ui


parser = argparse.ArgumentParser(
    description='Converts a native dataset file to extended XYZ format.'
)
parser.add_argument(
    'dataset',
    metavar='<dataset>',
    type=lambda x: io.is_file_type(x, 'dataset'),
    help='path to dataset file',
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing xyz dataset file',
)

args = parser.parse_args()
dataset_path, dataset = args.dataset

name = os.path.splitext(os.path.basename(dataset_path))[0]
dataset_file_name = name + '.xyz'

xyz_exists = os.path.isfile(dataset_file_name)
if xyz_exists and args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing xyz dataset file.')
if not xyz_exists or args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Writing dataset to \'{}\'...'.format(dataset_file_name))
else:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \'{}\' already exists.'.format(dataset_file_name)
    )

R = dataset['R']
z = dataset['z']
F = dataset['F']

lattice = dataset['lattice'] if 'lattice' in dataset else None

try:
    with open(dataset_file_name, 'w') as file:

        n = R.shape[0]
        for i, r in enumerate(R):

            e = np.squeeze(dataset['E'][i]) if 'E' in dataset else None
            f = F[i, :, :]
            ext_xyz_str = io.generate_xyz_str(r, z, e=e, f=f, lattice=lattice) + '\n'

            file.write(ext_xyz_str)

            ui.callback(i, n - 1, disp_str='Exporting %d data points...' % n)

except IOError:
    sys.exit("ERROR: Writing xyz file failed.")

print()


================================================
FILE: scripts/sgdml_dataset_via_ase.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2018-2022 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import print_function

import argparse
import os
import sys

try:
    from ase.io import read
except ImportError:
    raise ImportError('Optional ASE dependency not found! Please run \'pip install sgdml[ase]\' to install it.')

import numpy as np

from sgdml import __version__
from sgdml.utils import io, ui

if sys.version[0] == '3':
    raw_input = input


parser = argparse.ArgumentParser(
    description='Creates a dataset from any input format supported by ASE.'
)
parser.add_argument(
    'dataset',
    metavar='<dataset>',
    type=argparse.FileType('r'),
    help='path to input dataset file',
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing dataset file',
)
args = parser.parse_args()
dataset = args.dataset


name = os.path.splitext(os.path.basename(dataset.name))[0]
dataset_file_name = name + '.npz'

dataset_exists = os.path.isfile(dataset_file_name)
if dataset_exists and args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')
if not dataset_exists or args.overwrite:
    print('Writing dataset to \'{}\'...'.format(dataset_file_name))
else:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
        + ' Dataset \'{}\' already exists.'.format(dataset_file_name)
    )

mols = read(dataset.name, index=':')

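# Drop frames without attached calculator results (e.g. incomplete outputs in a trajectory).
mols = [mol for mol in mols if mol.calc is not None]
if not mols:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
        + ' No geometries with calculator results found in the input file!'
    )

lattice, R, z, E, F = None, None, None, None, None

calc = mols[0].calc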

print("\rNumber of geometries: {:,}".format(len(mols)))
#print("\rAvailable properties: " + ', '.join(calc.results))
print()

if 'forces' not in calc.results:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
        + ' Forces are missing in the input file!'
    )

lattice = np.array(mols[0].get_cell().T)
if not np.any(lattice):
    print(
        ui.color_str('[INFO]', bold=True)
        + ' No lattice vectors specified.'
    )
    lattice = None

Z = np.array([mol.get_atomic_numbers() for mol in mols])
all_z_the_same = (Z == Z[0]).all()
if not all_z_the_same:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
        + ' Order of atoms changes across dataset.'
    )

R = np.array([mol.get_positions() for mol in mols])
z = Z[0]

if 'Energy' in mols[0].info:
    E = np.array([float(mol.info['Energy']) for mol in mols])
else:
    E = np.array([mol.get_potential_energy() for mol in mols])
F = np.array([mol.get_forces() for mol in mols])

print('Please provide a name for this dataset. Otherwise the original filename will be reused.')
custom_name = raw_input('> ').strip()
if custom_name != '':
    name = custom_name

print('Please provide a descriptor for the level of theory used to create this dataset.')
theory = raw_input('> ').strip()
if theory == '':
    theory = 'unknown'

# Base variables contained in every dataset file.
base_vars = {
    'type': 'd',
    'code_version': __version__,
    'name': name,
    'theory': theory,
    'R': R,
    'z': z,
    'F': F,
}

print('If you want to convert your original length unit, please provide a conversion factor (default: 1.0): ')
R_to_new_unit = raw_input('> ').strip()
if R_to_new_unit != '':
    R_to_new_unit = float(R_to_new_unit)
else:
    R_to_new_unit = 1.0

print('If you want to convert your original energy unit, please provide a conversion factor (default: 1.0): ')
E_to_new_unit = raw_input('> ').strip()
if E_to_new_unit != '':
    E_to_new_unit = float(E_to_new_unit)
else:
    E_to_new_unit = 1.0

print('Please provide a description of the length unit, e.g. \'Ang\' or \'au\': ')
print('Note: This string will be stored in the dataset file and passed on to model files for later reference.')
r_unit = raw_input('> ').strip()
if r_unit != '':
    base_vars['r_unit'] = r_unit

print('Please provide a description of the energy unit, e.g. \'kcal/mol\' or \'eV\': ')
print('Note: This string will be stored in the dataset file and passed on to model files for later reference.')
e_unit = raw_input('> ').strip()
if e_unit != '':
    base_vars['e_unit'] = e_unit

# Convert units first, so that the summary statistics below describe the stored values.
# Forces have units of energy/length.
base_vars['R'] = R * R_to_new_unit
base_vars['F'] = F * (E_to_new_unit / R_to_new_unit)

F_conv = base_vars['F'].ravel()
base_vars['F_min'], base_vars['F_max'] = np.min(F_conv), np.max(F_conv)
base_vars['F_mean'], base_vars['F_var'] = np.mean(F_conv), np.var(F_conv)

if E is not None:
    E_conv = E * E_to_new_unit
    base_vars['E'] = E_conv
    base_vars['E_min'], base_vars['E_max'] = np.min(E_conv), np.max(E_conv)
    base_vars['E_mean'], base_vars['E_var'] = np.mean(E_conv), np.var(E_conv)
else:
    print(ui.color_str('[INFO]', bold=True) + ' No energy labels found in dataset.')

if lattice is not None:
    base_vars['lattice'] = lattice

base_vars['md5'] = io.dataset_md5(base_vars)
np.savez_compressed(dataset_file_name, **base_vars)
print(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))

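The script above drops frames without calculator results and aborts if the atom order changes between frames. The sketch below shows that logic without ASE. `Frame` and `validate_frames` are hypothetical stand-ins for `ase.Atoms` and the script's inline checks; they are not part of sGDML, and "no forces" here models "no calculator attached".

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Frame:
    """Hypothetical stand-in for an ase.Atoms object read from a trajectory."""

    numbers: List[int]  # atomic numbers, in file order
    forces: Optional[List[List[float]]]  # None: no calculator results attached


def validate_frames(frames: List[Frame]) -> Tuple[List[Frame], List[int]]:
    """Drop incomplete frames and require the same atom order in all others."""
    complete = [f for f in frames if f.forces is not None]
    if not complete:
        raise ValueError('no frames with force labels')
    z = complete[0].numbers
    if any(f.numbers != z for f in complete):
        raise ValueError('order of atoms changes across dataset')
    return complete, z


frames = [
    Frame([8, 1, 1], [[0.0, 0.0, 0.1], [0.0, 0.1, 0.0], [0.1, 0.0, 0.0]]),
    Frame([8, 1, 1], None),  # e.g. a truncated final step
]
kept, z = validate_frames(frames)
print(len(kept), z)
```

Requiring an identical `z` for every frame is what allows the script to store one `z` vector alongside a stacked `R` array.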

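The unit conversion scales positions by the length factor, energies by the energy factor, and forces by their ratio, since a force is an energy per length. For example, converting from eV and Å to kcal/mol and bohr would use factors of roughly 23.06 (energy) and 1.889 (length). The stdlib sketch below uses hypothetical helper names and exact toy factors so the arithmetic is easy to check. It computes the summary statistics on the converted values so they agree with the stored arrays.

```python
def convert_units(R, E, F, r_factor=1.0, e_factor=1.0):
    """Scale positions, energies and forces to new units.

    R and F are lists of geometries, each a list of [x, y, z] triples;
    E is a list of energies or None.
    """
    f_factor = e_factor / r_factor  # forces: energy / length
    R_new = [[[c * r_factor for c in atom] for atom in geo] for geo in R]
    F_new = [[[c * f_factor for c in atom] for atom in geo] for geo in F]
    E_new = None if E is None else [e * e_factor for e in E]
    return R_new, E_new, F_new


def summary_stats(values):
    """Min, max, mean and population variance (like np.min/max/mean/var)."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n
    return min(values), max(values), mean, var


# One geometry with a single atom; toy factors 2.0 (length) and 10.0 (energy).
R = [[[0.0, 0.0, 1.0]]]
E = [-2.0]
F = [[[0.0, 0.0, 0.5]]]
R2, E2, F2 = convert_units(R, E, F, r_factor=2.0, e_factor=10.0)
flat_F = [c for geo in F2 for atom in geo for c in atom]
print(summary_stats(flat_F))
```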
================================================
FILE: scripts/sgdml_datasets_from_model.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2018 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import print_function

import argparse
import os
import sys

import numpy as np

from sgdml.utils import io, ui

parser = argparse.ArgumentParser(
    description='Extracts the training and test data subsets from a dataset that were used to construct a model.'
)
parser.add_argument(
    'model',
    metavar='<model_file>',
    type=lambda x: io.is_file_type(x, 'model'),
    help='path to model file',
)
parser.add_argument(
    'dataset',
    metavar='<dataset_file>',
    type=lambda x: io.is_file_type(x, 'dataset'),
    help='path to dataset file referenced in model',
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing files',
)
args = parser.parse_args()

model_path, model = args.model
dataset_path, dataset = args.dataset


for s in ['train', 'valid']:

    if dataset['md5'] != model['md5_' + s]:
        sys.exit(
            ui.fail_str('[FAIL]')
            + ' Dataset fingerprint does not match the one referenced in model for \'%s\'.'
            % s
        )

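    idxs = model['idxs_' + s]
    R = dataset['R'][idxs, :, :]
    F = dataset['F'][idxs, :, :]

    base_vars = {
        'type': 'd',
        'name': dataset['name'].astype(str),
        'theory': dataset['theory'].astype(str),
        'z': dataset['z'],
        'R': R,
        'F': F,
    }

    # Energy labels are optional in sGDML datasets.
    if 'E' in dataset:
        base_vars['E'] = dataset['E'][idxs]

    # Carry over optional metadata (periodic cell, unit descriptions), if present.
    for key in ['lattice', 'r_unit', 'e_unit']:
        if key in dataset:
            base_vars[key] = dataset[key]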
    base_vars['md5'] = io.dataset_md5(base_vars)

    subset_file_name = '%s_%s.npz' % (
        os.path.splitext(os.path.basename(dataset_path))[0],
        s,
    )
    file_exists = os.path.isfile(subset_file_name)
    if file_exists and args.overwrite:
        print(ui.info_str('[INFO]') + ' Overwriting existing dataset file.')
    if not file_exists or args.overwrite:
        np.savez_compressed(subset_file_name, **base_vars)
        ui.callback(1, disp_str='Extracted %s dataset saved to \'%s\'' % (s, subset_file_name)) # DONE
    else:
        print(
            ui.warn_str('[WARN]')
            + ' %s dataset \'%s\' already exists.' % (s.capitalize(), subset_file_name)
            + '\n       Run \'python %s -o %s %s\' to overwrite.\n'
            % (os.path.basename(__file__), model_path, dataset_path)
        )
        sys.exit()


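This script selects rows by the index arrays stored in the model and refuses to proceed if the dataset fingerprint differs from the one recorded at training time. A minimal stdlib sketch of that idea follows. `fingerprint` and `extract_subset` are hypothetical helpers, and the hashing scheme is an assumption for illustration, not the format produced by `io.dataset_md5`.

```python
import hashlib
import struct


def fingerprint(z, R, F):
    """Hypothetical MD5 fingerprint over atomic numbers, positions and forces."""
    h = hashlib.md5()
    h.update(struct.pack('{}i'.format(len(z)), *z))
    for block in (R, F):
        flat = [c for geo in block for atom in geo for c in atom]
        h.update(struct.pack('{}d'.format(len(flat)), *flat))
    return h.hexdigest()


def extract_subset(dataset, idxs, expected_md5):
    """Select geometries by index after checking the source fingerprint."""
    if fingerprint(dataset['z'], dataset['R'], dataset['F']) != expected_md5:
        raise ValueError('dataset fingerprint does not match the model')
    return {
        'z': dataset['z'],
        'R': [dataset['R'][i] for i in idxs],
        'F': [dataset['F'][i] for i in idxs],
    }


data = {
    'z': [1, 1],
    'R': [[[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.80]]],
    'F': [[[0.0, 0.0, 0.1], [0.0, 0.0, -0.1]], [[0.0, 0.0, -0.2], [0.0, 0.0, 0.2]]],
}
md5 = fingerprint(data['z'], data['R'], data['F'])
train = extract_subset(data, [1], md5)
```

Any change to a single label changes the hash, so a model can only be paired with the exact dataset it indexes into.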
================================================
FILE: setup.cfg
================================================
[flake8]
max-complexity = 12
ignore = E501,W503,E741
select = C,E,F,W

[isort]
multi_line_output = 3
include_trailing_comma = 1
line_length = 85
sections = FUTURE,STDLIB,TYPING,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
known_typing = typing, typing_extensions
no_lines_before = TYPING


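The `[flake8]` section ignores E501 (line too long), W503 (line break before binary operator) and E741 (ambiguous variable name), and caps McCabe complexity at 12. The `[isort]` section uses vertical hanging indents (`multi_line_output = 3`) and adds a custom `TYPING` section after the standard library. To show how these values parse, the snippet below reads the same text with the standard-library `configparser`. This is only an illustration: flake8 and isort use their own config loaders.

```python
import configparser

SETUP_CFG = """
[flake8]
max-complexity = 12
ignore = E501,W503,E741
select = C,E,F,W

[isort]
multi_line_output = 3
include_trailing_comma = 1
line_length = 85
sections = FUTURE,STDLIB,TYPING,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
known_typing = typing, typing_extensions
no_lines_before = TYPING
"""

parser = configparser.ConfigParser()
parser.read_string(SETUP_CFG)

# Comma-separated values arrive as plain strings and must be split by the consumer.
ignored = [code.strip() for code in parser['flake8']['ignore'].split(',')]
max_complexity = parser['flake8'].getint('max-complexity')
line_length = parser['isort'].getint('line_length')
sections = parser['isort']['sections'].split(',')
typing_modules = [m.strip() for m in parser['isort']['known_typing'].split(',')]
```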
================================================
FILE: setup.py
================================================
import os
import re
from io import open
from os import path
from setuptools import setup, find_packages


def get_property(property, package):
    # Read a top-level string assignment (e.g. __version__ = '1.0.3') from the package's __init__.py.
    with open(package + '/__init__.py', encoding='utf8') as f:
        result = re.search(
            r'{}\s*=\s*[\'"]([^\'"]*)[\'"]'.format(property),
            f.read(),
        )
    return result.group(1)


this_dir = path.abspath(path.dirname(__file__))
with open(path.join(this_dir, 'README.md'), encoding='utf8') as f:
    long_description = f.read()

# Scripts
scripts = []
for dirname, dirnames, filenames in os.walk('scripts'):
    for filename in filenames:
        if filename.endswith('.py'):
            scripts.append(os.path.join(dirname, filename))

setup(
    name='sgdml',
    version=get_property('__version__', 'sgdml'),
    description='Reference implementation of the GDML and sGDML force field models.',
    long_description=long_description,
    long_description_content_type='text/markdown',
    classifiers=[
        'Development Status :: 4 - Beta',
        'Environment :: Console',
        'Intended Audience :: Science/Research',
        'Intended Audience :: Education',
        'Intended Audience :: Developers',
        'License :: OSI Approved :: MIT License',
        'Operating System :: MacOS :: MacOS X',
        'Operating System :: POSIX :: Linux',
        'Programming Language :: Python :: 3.7',
        'Topic :: Scientific/Engineering :: Chemistry',
        'Topic :: Scientific/Engineering :: Physics',
        'Topic :: Software Development :: Libraries :: Python Modules',
    ],
    url='http://www.sgdml.org',
    author='Stefan Chmiela',
    author_email='sgdml@chmiela.com',
    license='MIT',
    packages=find_packages(),
    install_requires=['torch >= 1.8', 'numpy >= 1.19.0', 'scipy >= 1.1.0', 'psutil', 'future'],
    entry_points={
        'console_scripts': ['sgdml=sgdml.cli:main', 'sgdml-get=sgdml.get:main']
    },
    extras_require={'ase': ['ase >= 3.16.2']},
    scripts=scripts,
    include_package_data=True,
    zip_safe=False,
)


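`get_property` extracts the version string from `sgdml/__init__.py` with a regular expression instead of importing the package. A common reason for this pattern is that importing a package at build time may require dependencies that are not installed yet. The sketch below applies the same regex to an in-memory string. Unlike `setup.py`, it raises a clear error when the property is missing instead of failing with an `AttributeError` on `None`.

```python
import re


def get_property(prop, source):
    """Return the quoted value assigned to `prop` in `source` (same regex as setup.py)."""
    match = re.search(r'{}\s*=\s*[\'"]([^\'"]*)[\'"]'.format(prop), source)
    if match is None:
        raise ValueError('{} not found'.format(prop))
    return match.group(1)


init_source = "__version__ = '1.0.3'\n\nMAX_PRINT_WIDTH = 100\n"
version = get_property('__version__', init_source)
```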
================================================
FILE: sgdml/__init__.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2019-2025 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

__version__ = '1.0.3'

MAX_PRINT_WIDTH = 100
LOG_LEVELNAME_WIDTH = 7  # do not modify

# more descriptive callback status
DONE = 1
NOT_DONE = 0


# Logging

import copy
import logging
import re
import textwrap

from .utils import ui


class ColoredFormatter(logging.Formatter):

    LEVEL_COLORS = {
        'DEBUG': (ui.CYAN, ui.BLACK),
        'INFO': (ui.WHITE, ui.BLACK),
        'DONE': (ui.GREEN, ui.BLACK),
        'WARNING': (ui.YELLOW, ui.BLACK),
        'ERROR': (ui.RED, ui.BLACK),
        'CRITICAL': (ui.BLACK, ui.RED),
    }

    LEVEL_NAMES = {
        'DEBUG': '[DEBG]',
        'INFO': '[INFO]',
        'DONE': '[DONE]',
        'WARNING': '[WARN]',
        'ERROR': '[FAIL]',
        'CRITICAL': '[CRIT]',
    }

    def __init__(self, msg, use_color=True):

        logging.Formatter.__init__(self, msg)
        self.use_color = use_color

    def format(self, record):

        _record = copy.copy(record)
        levelname = _record.levelname
        msg = _record.msg

        levelname = ui.color_str(
            self.LEVEL_NAMES[levelname],
            self.LEVEL_COLORS[levelname][0],
            self.LEVEL_COLORS[levelname][1],
            bold=True,
        )

        if _record.levelname != 'CRITICAL':
            # Wrap long messages, except CRITICAL ones: those carry exceptions with a pre-formatted traceback.
            msg = ui.wrap_str(msg)

        # indent multiline strings after the first line
        msg = ui.indent_str(msg, LOG_LEVELNAME_WIDTH)[LOG_LEVELNAME_WIDTH:]

        _record.levelname = levelname
        _record.msg = msg
        return logging.Formatter.format(self, _record)


class ColoredLogger(logging.Logger):
    def __init__(self, name):

        logging.Logger.__init__(self, name, logging.DEBUG)

        # add 'DONE' logging level
        logging.DONE = logging.INFO + 1
        logging.addLevelName(logging.DONE, 'DONE')

        # only display levelname and message
        formatter = ColoredFormatter('%(levelname)s %(message)s')

        # this handler will write to sys.stderr by default
        hd = logging.StreamHandler()
        hd.setFormatter(formatter)
        hd.setLevel(logging.INFO)  # control logging level here

        self.addHandler(hd)
        return

    def done(self, msg, *args, **kwargs):

        if self.isEnabledFor(logging.DONE):
            self._log(logging.DONE, msg, args, **kwargs)


logging.setLoggerClass(ColoredLogger)


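`ColoredLogger` registers a custom `DONE` level just above `INFO` and adds a `done()` convenience method. `ColoredFormatter` swaps level names for fixed-width tags such as `[DONE]`. The color-free sketch below reproduces that mechanism with only the standard `logging` module. `DoneLogger` and `TagFormatter` are hypothetical names. The logger is instantiated directly rather than through `logging.setLoggerClass`, which avoids changing global logging state.

```python
import io
import logging

DONE = logging.INFO + 1
logging.addLevelName(DONE, 'DONE')

TAGS = {'DEBUG': '[DEBG]', 'INFO': '[INFO]', 'DONE': '[DONE]',
        'WARNING': '[WARN]', 'ERROR': '[FAIL]', 'CRITICAL': '[CRIT]'}


class TagFormatter(logging.Formatter):
    def format(self, record):
        # Work on a copy so other handlers still see the original level name.
        record = logging.makeLogRecord(record.__dict__)
        record.levelname = TAGS.get(record.levelname, record.levelname)
        return super().format(record)


class DoneLogger(logging.Logger):
    def done(self, msg, *args, **kwargs):
        if self.isEnabledFor(DONE):
            self._log(DONE, msg, args, **kwargs)


stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(TagFormatter('%(levelname)s %(message)s'))

log = DoneLogger('demo')
log.addHandler(handler)
log.setLevel(logging.INFO)
log.done('Training finished.')
log.debug('not shown')  # below INFO, filtered out
```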
================================================
FILE: sgdml/cli.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2018-2022 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import print_function

import logging
import multiprocessing as mp
import argparse
import os
import shutil
import psutil
import sys
import traceback
import time
from functools import partial

import numpy as np
import scipy as sp

try:
    import torch
except ImportError:
    _has_torch = False
else:
    _has_torch = True

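try:
    _torch_mps_is_available = torch.backends.mps.is_available()
except (NameError, AttributeError):  # torch not installed, or too old for MPS
    _torch_mps_is_available = False
_torch_mps_is_available = False  # MPS support is currently force-disabled.

try:
    _torch_cuda_is_available = torch.cuda.is_available()
except (NameError, AttributeError):  # torch not installed
    _torch_cuda_is_available = False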

try:
    import ase
except ImportError:
    _has_ase = False
else:
    _has_ase = True

from . import __version__, DONE, NOT_DONE, MAX_PRINT_WIDTH
from .predict import GDMLPredict
from .train import GDMLTrain
from .utils import io, ui

# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = 'sgdml'

log = logging.getLogger(__name__)


class AssistantError(Exception):
    pass


def _print_splash(max_memory, max_processes, use_torch):

    logo_str = r"""         __________  __  _____
   _____/ ____/ __ \/  |/  / /
  / ___/ / __/ / / / /|_/ / /
 (__  ) /_/ / /_/ / /  / / /___
/____/\____/_____/_/  /_/_____/"""

    can_update, latest_version = _check_update()

    version_str = __version__
    version_str += (
        ' '
        + ui.color_str(
            ' Latest: ' + latest_version + ' ',
            fore_color=ui.BLACK,
            back_color=ui.YELLOW,
            bold=True,
        )
        if can_update
        else ''
    )

    max_memory_str = '{:d} GB(s) memory'.format(max_memory)
    max_processes_str = '{:d} CPU(s)'.format(max_processes)
    hardware_str = 'using {}, {}'.format(max_memory_str, max_processes_str)

    if use_torch and _has_torch:

        if _torch_cuda_is_available:
            num_gpu = torch.cuda.device_count()
            if num_gpu > 0:
                hardware_str += ', {:d} GPU(s)'.format(num_gpu)
        elif _torch_mps_is_available:
            hardware_str += ', MPS enabled'

    logo_str_split = logo_str.splitlines()
    print('\n'.join(logo_str_split[:-1]))
    ui.print_two_column_str(logo_str_split[-1] + '  ' + version_str, hardware_str)

    # Print update notice.
    if can_update:
        print(
            '\n'
            + ui.color_str(
                ' UPDATE AVAILABLE ',
                fore_color=ui.BLACK,
                back_color=ui.YELLOW,
                bold=True,
            )
            + '\n'
            + '-' * MAX_PRINT_WIDTH
        )
        print(
            'A new stable release version {} of this software is available.'.format(
                latest_version
            )
        )
        print(
            'You can update your installation by running \'pip install sgdml --upgrade\'.'
        )

    _print_billboard()


def _check_update():

    try:
        from urllib.request import urlopen
    except ImportError:
        from urllib2 import urlopen

    base_url = 'http://api.sgdml.org/'
    url = '{}update.php?v={}'.format(base_url, __version__)

    can_update, must_update = '0', '0'
    latest_version = ''
    try:
        response = urlopen(url, timeout=1)
        can_update, must_update, latest_version = response.read().decode().split(',')
        response.close()
    except Exception:
        pass

    return can_update == '1', latest_version


def _print_billboard():

    try:
        from urllib.request import urlopen
    except ImportError:
        from urllib2 import urlopen

    base_url = 'http://api.sgdml.org/'
    url = '{}billboard.php'.format(base_url)

    resp_str = ''
    try:
        response = urlopen(url, timeout=1)
        resp_str = response.read().decode()
        response.close()
    except Exception:
        pass

    bbs = None
    try:
        import json

        bbs = json.loads(resp_str)
    except Exception:
        pass

    if bbs is not None:

        for bb in bbs:

            back_color = ui.WHITE
            if bb['color'] == 'YELLOW':
                back_color = ui.YELLOW
            elif bb['color'] == 'GREEN':
                back_color = ui.GREEN
            elif bb['color'] == 'RED':
                back_color = ui.RED
            elif bb['color'] == 'CYAN':
                back_color = ui.CYAN

            print(
                '\n'
                + ui.color_str(
                    ' {} '.format(bb['title']),
                    fore_color=ui.BLACK,
                    back_color=back_color,
                    bold=True,
                )
                + '\n'
                + '-' * MAX_PRINT_WIDTH
            )

            print(ui.wrap_str(bb['text'], width=MAX_PRINT_WIDTH - 2))


def _print_dataset_properties(dataset, title_str='Dataset properties'):

    print(ui.color_str(title_str, bold=True))

    n_mols, n_atoms, _ = dataset['R'].shape
    print('  {:<18} \'{}\''.format('Name:', ui.unicode_str(dataset['name'])))
    print('  {:<18} \'{}\''.format('Theory level:', ui.unicode_str(dataset['theory'])))
    print('  {:<18} {:<d}'.format('Atoms:', n_atoms))

    print('  {:<18} {:,} data points'.format('Size:', n_mols))

    ui.print_lattice(dataset['lattice'] if 'lattice' in dataset else None)

    if 'perms' in dataset:
        ui.print_two_column_str(
            '  {:<18} {}'.format('Symmetries:', len(dataset['perms'])),
            'This dataset contains precomputed permutations.',
        )

    if 'E' in dataset:

        e_unit = 'unknown unit'
        if 'e_unit' in dataset:
            e_unit = ui.unicode_str(dataset['e_unit'])

        print('  Energies [{}]'.format(e_unit))
        if 'E_min' in dataset and 'E_max' in dataset:
            E_min, E_max = dataset['E_min'], dataset['E_max']
        else:
            E_min, E_max = np.min(dataset['E']), np.max(dataset['E'])
        E_range_str = ui.gen_range_str(E_min, E_max)
        ui.print_two_column_str(
            '    {:<16} {}'.format('Range:', E_range_str), 'min |-- range --| max'
        )

        E_mean = dataset['E_mean'] if 'E_mean' in dataset else np.mean(dataset['E'])
        print('    {:<16} {:<.3f}'.format('Mean:', E_mean))

        E_var = dataset['E_var'] if 'E_var' in dataset else np.var(dataset['E'])
        print('    {:<16} {:<.3f}'.format('Variance:', E_var))
    else:
        print('  {:<18} {}'.format('Energies:', 'n/a'))

    f_unit = 'unknown unit'
    if 'r_unit' in dataset and 'e_unit' in dataset:
        f_unit = (
            ui.unicode_str(dataset['e_unit']) + '/' + ui.unicode_str(dataset['r_unit'])
        )

    print('  Forces [{}]'.format(f_unit))

    if 'F_min' in dataset and 'F_max' in dataset:
        F_min, F_max = dataset['F_min'], dataset['F_max']
    else:
        F_min, F_max = np.min(dataset['F'].ravel()), np.max(dataset['F'].ravel())
    F_range_str = ui.gen_range_str(F_min, F_max)
    ui.print_two_column_str(
        '    {:<16} {}'.format('Range:', F_range_str), 'min |-- range --| max'
    )

    F_mean = dataset['F_mean'] if 'F_mean' in dataset else np.mean(dataset['F'].ravel())
    print('    {:<16} {:<.3f}'.format('Mean:', F_mean))

    F_var = dataset['F_var'] if 'F_var' in dataset else np.var(dataset['F'].ravel())
    print('    {:<16} {:<.3f}'.format('Variance:', F_var))

    print('  {:<18} {}'.format('Fingerprint:', ui.unicode_str(dataset['md5'])))

    # if 'code_version' in dataset:
    #    print('  {:<18} sGDML {}'.format('Created with:', ui.unicode_str(dataset['code_version'])))

    idx = np.random.choice(n_mols, 1)[0]
    r = dataset['R'][idx, :, :]
    e = np.squeeze(dataset['E'][idx]) if 'E' in dataset else None
    f = dataset['F'][idx, :, :]
    lattice = dataset['lattice'] if 'lattice' in dataset else None

    print(
        '\n'
        + ui.color_str('Example geometry', fore_color=ui.WHITE, bold=True)
        + ' (point no. {:,}, chosen randomly)'.format(idx + 1)
    )

    xyz_info_str = 'Copy & paste the string below into Jmol (www.jmol.org), Avogadro (www.avogadro.cc), etc. to visualize one of the geometries from this dataset. A new example will be drawn on each run.'
    xyz_info_str = ui.wrap_str(xyz_info_str, width=MAX_PRINT_WIDTH - 2)
    xyz_info_str = ui.indent_str(xyz_info_str, 2)
    print(xyz_info_str + '\n')

    xyz_str = io.generate_xyz_str(r, dataset['z'], e=e, f=f, lattice=lattice)
    xyz_str = ui.indent_str(xyz_str, 2)

    cut_str = '---- COPY HERE '
    cut_str_reps = int(np.floor((MAX_PRINT_WIDTH - 6) / len(cut_str)))
    cutline_str = ui.color_str(
        '  -' + cut_str * cut_str_reps + '-----', fore_color=ui.GRAY
    )

    print(cutline_str)
    print(xyz_str)
    print(cutline_str)


def _print_task_properties_reduced(
    use_sym, use_E, use_E_cstr, title_str='Task properties'
):

    print(ui.color_str(title_str, bold=True))

    energy_fix_str = (
        (
            'pointwise energy constraints'
            if use_E_cstr
            else 'global integration constant'
        )
        if use_E
        else 'none'
    )
    print('  {:<16} {}'.format('Energy offset:', energy_fix_str))

    print(
        '  {:<16} {}'.format(
            'Symmetries:', 'include (sGDML)' if use_sym else 'ignore (GDML)'
        )
    )


def _print_task_properties(task, title_str='Task properties'):

    print(ui.color_str(title_str, bold=True))

    print('  {:<18}'.format('Dataset'))
    print('    {:<16} \'{}\''.format('Name:', ui.unicode_str(task['dataset_name'])))
    print(
        '    {:<16} \'{}\''.format(
            'Theory level:', ui.unicode_str(task['dataset_theory'])
        )
    )

    n_atoms = len(task['z'])
    print('    {:<16} {:<d}'.format('Atoms:', n_atoms))

    ui.print_lattice(task['lattice'] if 'lattice' in task else None, inset=True)

    print('  {:<18} {:<d}'.format('Symmetries:', len(task['perms'])))

    print('  {:<18}'.format('Hyper-parameters'))
    print('    {:<16} {:<d}'.format('Length scale:', task['sig']))

    if 'lam' in task:
        print('    {:<16} {:<.0e}'.format('Regularization:', task['lam']))

    # if 'solver_name' in task:
    #     print('  {:<18}'.format('Solver configuration'))
    #     print('    {:<16} \'{}\''.format('Type:', task['solver_name']))

    #     if task['solver_name'] == 'cg':

    #         if 'solver_tol' in task:
    #             print('    {:<16} {:<.0e}'.format('Tolerance:', task['solver_tol']))

    #         if 'n_inducing_pts_init' in task:
    #             print(
    #                 '    {:<16} {:<d}'.format(
    #                     'Inducing points:', task['n_inducing_pts_init']
    #                 )
    #             )
    # else:
    #     print('  {:<18} {}'.format('Solver:', 'unknown'))

    n_train = len(task['idxs_train'])
    ui.print_two_column_str(
        '  {:<18} {:,} points'.format('Train on:', n_train),
        'from \'' + ui.unicode_str(task['md5_train']) + '\'',
    )

    n_valid = len(task['idxs_valid'])
    ui.print_two_column_str(
        '  {:<18} {:,} points'.format('Validate on:', n_valid),
        'from \'' + ui.unicode_str(task['md5_valid']) + '\'',
    )

    # print('  {:<18}'.format('Estimated memory requirement (min.)'))

    # mem_kernel_mat_const = 0
    # mem_precond_const = 0
    # print(
    #    '    {:<16} {}'.format(
    #        'CPU:', ui.gen_memory_str(mem_kernel_mat_const + mem_precond_const)
    #    )
    # )
    # print('      {:<14} {}'.format('Kernel matrix:', ui.gen_memory_str(mem_kernel_mat_const))
    # print('      {:<14} {}'.format('Precond. factor:', ui.gen_memory_str(mem_precond_const)))

    # mem_torch_assemble = 0
    # mem_torch_eval = 0
    # print(
    #    '    {:<16} {}'.format(
    #        'GPU:', ui.gen_memory_str(mem_torch_assemble + mem_torch_eval)
    #    )
    # )
    # print('      {:<14} {}'.format('Kernel matrix assembly:', ui.gen_memory_str(mem_torch_assemble)))
    # print('      {:<14} {}'.format('Model evaluation:', ui.gen_memory_str(mem_torch_eval)))


def _print_model_properties(model, title_str='Model properties'):

    print(ui.color_str(title_str, bold=True))

    print('  {:<18}'.format('Dataset'))
    print('    {:<16} \'{}\''.format('Name:', ui.unicode_str(model['dataset_name'])))
    print(
        '    {:<16} \'{}\''.format(
            'Theory level:', ui.unicode_str(model['dataset_theory'])
        )
    )

    n_atoms = len(model['z'])
    print('    {:<16} {:<d}'.format('Atoms:', n_atoms))

    ui.print_lattice(model['lattice'] if 'lattice' in model else None, inset=True)

    print('  {:<18} {:<d}'.format('Symmetries:', len(model['perms'])))

    print('  {:<18}'.format('Hyper-parameters'))
    print('    {:<16} {:<d}'.format('Length scale:', model['sig']))

    if 'lam' in model:
        print('    {:<16} {:<.0e}'.format('Regularization:', model['lam']))

    if 'solver_name' in model:
        print('  {:<18}'.format('Solver'))
        print('    {:<16} \'{}\''.format('Type:', model['solver_name']))

        if model['solver_name'] == 'cg':

            if 'solver_tol' in model:
                ui.print_two_column_str(
                    '    {:<16} {:<.0e}'.format('Tolerance:', model['solver_tol']),
                    'iterate until: norm(K*alpha - y) <= tol*norm(y) = {:<.0e}'.format(
                        model['solver_tol'] * model['norm_y_train']
                    ),
                )

                if 'solver_resid' in model:
                    is_conv = (
                        model['solver_resid']
                        <= model['solver_tol'] * model['norm_y_train']
                    )
                    print(
                        '    {:<16} {:<.0e}{}'.format(
                            'Converged to:',
                            model['solver_resid'],
                            '' if is_conv else ' (NOT CONVERGED)',
                        )
                    )

            if 'solver_iters' in model:
                print('    {:<16} {:<d}'.format('Iterations:', model['solver_iters']))

            if 'inducing_pts_idxs' in model:
                n_inducing_pts = len(model['inducing_pts_idxs']) // (3 * n_atoms)
                ui.print_two_column_str(
                    '    {:<16} {:<d}'.format('Inducing points:', n_inducing_pts),
                    'inducing columns: {:<d} (multiplied by DOF)'.format(
                        n_inducing_pts * n_atoms * 3
                    ),
                )
    else:
        print('  {:<18} {}'.format('Solver:', 'unknown'))

    n_train = len(model['idxs_train'])
    ui.print_two_column_str(
        '  {:<18} {:,} points'.format('Trained on:', n_train),
        'from \'' + ui.unicode_str(model['md5_train']) + '\'',
    )

    use_E_cstr = 'alphas_E' in model
    print(
        '    {:<16} {}'.format(
            'Energy offset',
            '[{}] global integration constant'.format('x' if not use_E_cstr else ' '),
        )
    )
    ui.print_two_column_str(
        '                     {:<16}'.format(
            '[{}] pointwise energy constraints'.format('x' if use_E_cstr else ' ')
        ),
        'using \'--E_cstr\'',
    )

    if model['use_E']:
        e_err = model['e_err'].item()
    f_err = model['f_err'].item()

    n_valid = len(model['idxs_valid'])
    is_valid = not np.isnan(f_err['mae']) and not np.isnan(f_err['rmse'])
    ui.print_two_column_str(
        '  {:<18} {}{:,} points'.format(
            'Validated on:', '' if is_valid else '[pending] ', n_valid
        ),
        'from \'' + ui.unicode_str(model['md5_valid']) + '\'',
    )

    n_test = int(model['n_test'])
    is_test = n_test > 0
    if is_test:
        ui.print_two_column_str(
            '  {:<18} {:,} points'.format('Tested on:', n_test),
            'from \'' + ui.unicode_str(model['md5_test']) + '\'',
        )
    else:
        print('  {:<18} {}'.format('Test:', '[pending]'))

    e_unit = 'unknown unit'
    f_unit = 'unknown unit'
    if 'r_unit' in model and 'e_unit' in model:
        e_unit = ui.unicode_str(model['e_unit'])
        f_unit = ui.unicode_str(model['e_unit']) + '/' + ui.unicode_str(model['r_unit'])

    if is_valid:
        action_str = 'Validation' if not is_valid else 'Expected test'
        print('  {:<18}'.format('{} errors (MAE/RMSE)'.format(action_str)))
        if model['use_E']:
            print(
                '    {:<16} {:>.4f}/{:>.4f} [{}]'.format(
                    'Energy:', e_err['mae'], e_err['rmse'], e_unit
                )
            )
        print(
            '    {:<16} {:>.4f}/{:>.4f} [{}]'.format(
                'Forces:', f_err['mae'], f_err['rmse'], f_unit
            )
        )


def _print_next_step(
    prev_step, task_dir=None, model_dir=None, model_files=None, dataset_path=None
):

    if prev_step == 'create':

        assert task_dir is not None

        ui.print_step_title(
            'NEXT STEP',
            '{} train {} <valid_dataset_file>'.format(PACKAGE_NAME, task_dir),
            underscore=False,
        )

    elif prev_step == 'train' or prev_step == 'validate' or prev_step == 'resume':

        assert model_dir is not None and model_files is not None

        if dataset_path is None:
            dataset_path = '<test_dataset_file>'

        n_models = len(model_files)
        if n_models == 1:
            model_file_path = os.path.join(model_dir, model_files[0])
            ui.print_step_title(
                'NEXT STEP',
                '{} test {} {} [<n_test>]'.format(
                    PACKAGE_NAME, model_file_path, dataset_path
                ),
                underscore=False,
            )
        else:
            ui.print_step_title(
                'NEXT STEP',
                '{} select {}'.format(PACKAGE_NAME, model_dir),
                underscore=False,
            )

    elif prev_step == 'select':

        assert model_files is not None

        ui.print_step_title(
            'NEXT STEP',
            '{} test {} <test_dataset_file> [<n_test>]'.format(
                PACKAGE_NAME, model_files[0]
            ),
            underscore=False,
        )

    else:
        raise AssistantError('Unexpected previous step string.')


def all(
    dataset,
    valid_dataset,
    test_dataset,
    n_train,
    n_valid,
    n_test,
    sigs,
    gdml,
    use_E,
    use_E_cstr,
    lazy_training,
    overwrite,
    max_memory,
    max_processes,
    use_torch,
    task_dir=None,
    model_file=None,
    perms_from_arg=None,
    **kwargs
):

    print(
        '\n'
        + ui.color_str(' STEP 0 ', fore_color=ui.BLACK, back_color=ui.WHITE, bold=True)
        + ' Dataset(s)\n'
        + '-' * MAX_PRINT_WIDTH
    )

    _, dataset_extracted = dataset
    _print_dataset_properties(dataset_extracted, title_str='Properties')

    if valid_dataset is None:
        valid_dataset = dataset
    else:
        _, valid_dataset_extracted = valid_dataset
        print()
        _print_dataset_properties(
            valid_dataset_extracted, title_str='Properties (validation dataset)'
        )

        if not np.array_equal(dataset_extracted['z'], valid_dataset_extracted['z']):
            raise AssistantError(
                'Atom composition or order in validation dataset does not match the one in bulk dataset.'
            )

    if test_dataset is None:
        test_dataset = dataset
    else:
        _, test_dataset_extracted = test_dataset
        print()
        _print_dataset_properties(
            test_dataset_extracted, title_str='Properties (test dataset)'
        )

        if not np.array_equal(dataset_extracted['z'], test_dataset_extracted['z']):
            raise AssistantError(
                'Atom composition or order in test dataset does not match the one in bulk dataset.'
            )

    ui.print_step_title('STEP 1', 'Cross-validation task creation')
    task_dir = create(
        dataset,
        valid_dataset,
        n_train,
        n_valid,
        sigs,
        gdml,
        use_E,
        use_E_cstr,
        overwrite,
        task_dir,
        perms_from_arg=perms_from_arg,
        **kwargs
    )

    ui.print_step_title('STEP 2', 'Training and validation')
    task_dir_arg = io.is_dir_with_file_type(task_dir, 'task')
    model_dir_or_file_path = train(
        task_dir_arg,
        valid_dataset,
        lazy_training,
        overwrite,
        max_memory,
        max_processes,
        use_torch,
        **kwargs
    )

    model_dir_arg = io.is_dir_with_file_type(
        model_dir_or_file_path, 'model', or_file=True
    )

    _, model_file_names = model_dir_arg
    if len(model_file_names) == 0:
        raise AssistantError(
            'No trained models found!'
            + ('\nTry turning off \'--lazy\'-mode.' if lazy_training else '')
        )

    ui.print_step_title('STEP 3', 'Hyper-parameter selection')
    model_file_name = select(model_dir_arg, overwrite, model_file, **kwargs)

    # Have all tasks been trained?
    _, task_file_names = task_dir_arg
    if len(task_file_names) > len(model_file_names):
        log.warning(
            'Not all training tasks have been completed! The model selected here might not be optimal.'
            + ('\nTry turning off \'--lazy\'-mode.' if lazy_training else '')
        )

    ui.print_step_title('STEP 4', 'Testing')
    model_dir_arg = io.is_dir_with_file_type(model_file_name, 'model', or_file=True)
    test(
        model_dir_arg,
        test_dataset,
        n_test,
        overwrite=False,
        max_memory=max_memory,
        max_processes=max_processes,
        use_torch=use_torch,
        **kwargs
    )

    print(
        '\n'
        + ui.color_str('  DONE  ', fore_color=ui.BLACK, back_color=ui.GREEN, bold=True)
        + ' Training assistant finished successfully.'
    )
    print('         This is your model file: \'{}\''.format(model_file_name))


# if training job exists and is a subset of the requested cv range, add new tasks
# otherwise, if new range is different or smaller, fail
def create(  # noqa: C901
    dataset,
    valid_dataset,
    n_train,
    n_valid,
    sigs,
    gdml,
    use_E,
    use_E_cstr,
    overwrite,
    task_dir=None,
    perms_from_arg=None,
    command=None,
    **kwargs
):

    has_valid_dataset = not (valid_dataset is None or valid_dataset == dataset)

    dataset_path, dataset = dataset
    n_data = dataset['F'].shape[0]

    func_called_directly = (
        command == 'create'
    )  # has this function been called from command line or from 'all'?
    if func_called_directly:
        ui.print_step_title('TASK CREATION')
        _print_dataset_properties(dataset)
        print()

    _print_task_properties_reduced(use_sym=not gdml, use_E=use_E, use_E_cstr=use_E_cstr)
    print()

    if n_data < n_train:
        raise AssistantError(
            'Dataset only contains {} points, can not train on {}.'.format(
                n_data, n_train
            )
        )

    if not has_valid_dataset:
        valid_dataset_path, valid_dataset = dataset_path, dataset
        if n_data - n_train < n_valid:
            raise AssistantError(
                'Dataset only contains {} points, can not train on {} and validate on {}.'.format(
                    n_data, n_train, n_valid
                )
            )
    else:
        valid_dataset_path, valid_dataset = valid_dataset
        n_valid_data = valid_dataset['R'].shape[0]
        if n_valid_data < n_valid:
            raise AssistantError(
                'Validation dataset only contains {} points, can not validate on {}.'.format(
                    n_valid_data, n_valid
                )
            )

    if sigs is None:
        log.info(
            'Kernel hyper-parameter sigma (length scale) was automatically set to the default range 10, 20, ..., 90.'
        )
        sigs = list(range(10, 100, 10))  # default range

    if task_dir is None:
        task_dir = io.train_dir_name(
            dataset,
            n_train,
            use_sym=not gdml,
            use_E=use_E,
            use_E_cstr=use_E_cstr,
        )

    task_file_names = []
    if os.path.exists(task_dir):
        if overwrite:
            log.info('Overwriting existing training directory')
            shutil.rmtree(task_dir, ignore_errors=True)
            os.makedirs(task_dir)
        else:
            if io.is_task_dir_resumeable(
                task_dir, dataset, valid_dataset, n_train, n_valid, sigs, gdml
            ):
                log.info(
                    'Resuming existing hyper-parameter search in \'{}\'.'.format(
                        task_dir
                    )
                )

                # Get all task file names.
                try:
                    _, task_file_names = io.is_dir_with_file_type(task_dir, 'task')
                except Exception:
                    pass
            else:
                raise AssistantError(
                    'Unfinished hyper-parameter search found in \'{}\'.\n'.format(
                        task_dir
                    )
                    + 'Run \'%s %s -o %s %d %d -s %s\' to overwrite.'
                    % (
                        PACKAGE_NAME,
                        command,
                        dataset_path,
                        n_train,
                        n_valid,
                        ' '.join(str(s) for s in sigs),
                    )
                )
    else:
        os.makedirs(task_dir)

    if task_file_names:

        with np.load(
            os.path.join(task_dir, task_file_names[0]), allow_pickle=True
        ) as task:
            tmpl_task = dict(task)
    else:
        if not use_E:
            log.info(
                'Energy labels will be ignored for training.\n'
                + 'Note: If available in the dataset file, the energy labels will still be used to generate stratified training, test and validation datasets. Otherwise, random sampling is used.'
            )

        if 'E' not in dataset:
            log.warning(
                'Training dataset will be sampled with no guidance from energy labels (i.e. randomly)!'
            )

        if 'E' not in valid_dataset:
            log.warning(
                'Validation dataset will be sampled with no guidance from energy labels (i.e. randomly)!\n'
                + 'Note: Larger validation datasets are recommended due to slower convergence of the error.'
            )

        if ('lattice' in dataset) ^ ('lattice' in valid_dataset):
            log.error('One of the datasets specifies lattice vectors and one does not!')
            # TODO: stop program?

        if 'lattice' in dataset or 'lattice' in valid_dataset:
            log.info(
                'Lattice vectors found in dataset: applying periodic boundary conditions.'
            )

        perms = None
        if perms_from_arg is not None:

            _, perms_from = perms_from_arg
            if 'perms' in perms_from:
                perms = perms_from['perms']
            else:
                raise AssistantError(
                    'Provided permutation file does not contain any permutations (looking for the \'perms\' key).'
                )

        gdml_train = (
            GDMLTrain()
        )  # No process-number or memory restrictions necessary here.
        try:
            tmpl_task = gdml_train.create_task(
                dataset,
                n_train,
                valid_dataset,
                n_valid,
                sig=1,
                perms=perms,
                use_sym=not gdml,
                use_E=use_E,
                use_E_cstr=use_E_cstr,
                callback=ui.callback,
            )  # template task
        except:
            print()
            log.critical(traceback.format_exc())
            print()
            os._exit(1)

    n_written = 0
    for sig in sigs:
        tmpl_task['sig'] = sig
        task_file_name = io.task_file_name(tmpl_task)
        task_path = os.path.join(task_dir, task_file_name)

        if os.path.isfile(task_path):
            log.info('Skipping existing task \'{}\'.'.format(task_file_name))
        else:
            np.savez_compressed(task_path, **tmpl_task)
            n_written += 1
    if n_written > 0:
        log.done(
            'Writing {:d}/{:d} task(s) with m={} training points each'.format(
                n_written, len(sigs), tmpl_task['R_train'].shape[0]
            )
        )

    if func_called_directly:
        _print_next_step('create', task_dir=task_dir)

    return task_dir


def train(
    task_dir,
    valid_dataset,
    lazy_training,
    overwrite,
    max_memory,
    max_processes,
    use_torch,
    command=None,
    **kwargs
):

    task_dir, task_file_names = task_dir
    n_tasks = len(task_file_names)

    func_called_directly = (
        command == 'train'
    )  # Has this function been called from command line or from 'all'?
    if func_called_directly:
        ui.print_step_title('MODEL TRAINING')

    def save_progr_callback(
        unconv_model, unconv_model_path=None
    ):  # Saves current (unconverged) model during iterative training

        if unconv_model_path is None:
            log.critical(
                'Path for unconverged model not set in \'save_progr_callback\'.'
            )
            print()
            os._exit(1)

        np.savez_compressed(unconv_model_path, **unconv_model)

    try:
        gdml_train = GDMLTrain(
            max_memory=max_memory, max_processes=max_processes, use_torch=use_torch
        )
    except:
        print()
        log.critical(traceback.format_exc())
        print()
        os._exit(1)

    prev_valid_err = -1
    has_converged_once = False

    for i, task_file_name in enumerate(task_file_names):

        task_file_path = os.path.join(task_dir, task_file_name)
        with np.load(task_file_path, allow_pickle=True) as task:

            if n_tasks > 1:
                if i > 0:
                    print()

                n_train = len(task['idxs_train'])
                n_valid = len(task['idxs_valid'])
                ui.print_two_column_str(
                    ui.color_str('Task {:d} of {:d}'.format(i + 1, n_tasks), bold=True),
                    '{:,} + {:,} points (training + validation), sigma (length scale): {}'.format(
                        n_train, n_valid, task['sig']
                    ),
                )

            model_file_name = io.model_file_name(task, is_extended=False)
            model_file_path = os.path.join(task_dir, model_file_name)

            if not overwrite and os.path.isfile(
                model_file_path
            ):  # Train model found, validate if necessary
                log.info(
                    'Model \'{}\' already exists.'.format(model_file_name)
                    + (
                        '\nRun \'{} train -o {}\' to overwrite.'.format(
                            PACKAGE_NAME, task_file_path
                        )
                        if func_called_directly
                        else ''
                    )
                )

                model_path = os.path.join(task_dir, model_file_name)
                _, model = io.is_file_type(model_path, 'model')

                e_err = {'mae': 0.0, 'rmse': 0.0}
                if model['use_E']:
                    e_err = model['e_err'].item()
                f_err = model['f_err'].item()

                is_conv = True
                if 'solver_resid' in model:
                    is_conv = (
                        model['solver_resid']
                        <= model['solver_tol'] * model['norm_y_train']
                    )

                is_model_validated = not (
                    np.isnan(f_err['mae']) or np.isnan(f_err['rmse'])
                )
                if is_model_validated:

                    disp_str = (
                        'energy %.3f/%.3f, ' % (e_err['mae'], e_err['rmse'])
                        if model['use_E']
                        else ''
                    )
                    disp_str += 'forces %.3f/%.3f' % (f_err['mae'], f_err['rmse'])
                    disp_str = 'Validation errors (MAE/RMSE): ' + disp_str
                    ui.callback(1, 1, disp_str=disp_str)

                    valid_errs = [f_err['rmse']]

            else:  # Train and validate model

                # Check if training this task has been attempted before.
                if lazy_training and n_tasks > 1:
                    if 'tried_training' in task and task['tried_training']:
                        log.warning(
                            'Skipping task, because it has been tried before (without success).'
                        )
                        continue

                # Record in task file that there was a training attempt.
                task = dict(task)
                task['tried_training'] = True
                np.savez_compressed(task_file_path, **task)

                n_train, n_atoms = task['R_train'].shape[:2]

                unconv_model_file = '_unconv_{}'.format(model_file_name)
                unconv_model_path = os.path.join(task_dir, unconv_model_file)

                try:
                    model = gdml_train.train(
                        task,
                        partial(
                            save_progr_callback, unconv_model_path=unconv_model_path
                        ),
                        ui.callback,
                    )
                except:
                    print()
                    log.critical(traceback.format_exc())
                    print()
                    os._exit(1)
                else:
                    if func_called_directly:
                        log.done('Writing model to file \'{}\''.format(model_file_path))
                    np.savez_compressed(model_file_path, **model)

                    # Delete temporary model, if one exists.
                    unconv_model_exists = os.path.isfile(unconv_model_path)
                    if unconv_model_exists:
                        os.remove(unconv_model_path)

                is_model_validated = False

            if not is_model_validated:

                if (
                    n_tasks == 1
                ):  # Only validate if there is more than one training task.
                    log.info(
                        'Skipping validation step as there is only one model to validate.'
                    )
                    break

                # Validate model.
                model_dir = (task_dir, [model_file_name])
                valid_errs = test(
                    model_dir,
                    valid_dataset,
                    -1,  # n_test = -1 -> validation mode
                    overwrite,
                    max_memory,
                    max_processes,
                    use_torch,
                    command,
                    **kwargs
                )

                is_conv = True
                if 'solver_resid' in model:
                    is_conv = (
                        model['solver_resid']
                        <= model['solver_tol'] * model['norm_y_train']
                    )

            has_converged_once = has_converged_once or is_conv
            if (
                has_converged_once
                and prev_valid_err != -1
                and prev_valid_err < valid_errs[0]
            ):
                print()
                log.info(
                    'Skipping remaining training tasks, as validation error is rising again.'
                )
                break

            prev_valid_err = valid_errs[0]

    model_dir_or_file_path = model_file_path if n_tasks == 1 else task_dir
    if func_called_directly:

        model_dir_arg = io.is_dir_with_file_type(
            model_dir_or_file_path, 'model', or_file=True
        )
        model_dir, model_files = model_dir_arg
        _print_next_step('train', model_dir=model_dir, model_files=model_files)

    return model_dir_or_file_path  # model directory or file

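
# Illustrative sketch, not used by the CLI: the stopping rule applied across
# the sigma sweep in `train` above. A model counts as converged when
# solver_resid <= solver_tol * norm_y_train (models without a recorded
# residual are treated as converged there). Once any model has converged, the
# sweep stops as soon as a task's validation force RMSE exceeds that of the
# previous task. The helper name and the tuple layout of `results` are
# hypothetical.
def _sketch_n_tasks_trained(results):
    """Return how many tasks of the sweep get trained before it stops.

    `results` holds one (valid_f_rmse, solver_resid, solver_tol, norm_y_train)
    tuple per task, in sweep order.
    """
    prev_valid_err = -1
    has_converged_once = False
    for i, (valid_err, resid, tol, norm_y) in enumerate(results):
        is_conv = resid <= tol * norm_y
        has_converged_once = has_converged_once or is_conv
        if (
            has_converged_once
            and prev_valid_err != -1
            and prev_valid_err < valid_err
        ):
            return i + 1  # task i was trained, but its rising error ends the sweep
        prev_valid_err = valid_err
    return len(results)
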

def _batch(iterable, n=1):
    """Yield consecutive slices of `iterable` with at most `n` items each."""
    n_items = len(iterable)
    for ndx in range(0, n_items, n):
        yield iterable[ndx : min(ndx + n, n_items)]


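def _online_err(err, size, n, mae_n_sum, rmse_n_sum):
    """Update running MAE/RMSE sums with one batch of errors.

    `err` holds the errors of the current batch, `size` is the number of
    error components per sample (e.g. 1 for energies, 3 * n_atoms for force
    components) and `n` is the total number of samples seen so far,
    including this batch. The running sums accumulate per-sample mean
    absolute and mean squared errors, so the returned `mae` and `rmse` are
    the statistics over all samples processed so far.
    """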

    err = np.abs(err)

    mae_n_sum += np.sum(err) / size
    mae = mae_n_sum / n

    rmse_n_sum += np.sum(err**2) / size
    rmse = np.sqrt(rmse_n_sum / n)

    return mae, mae_n_sum, rmse, rmse_n_sum

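
# Illustrative sketch, not used by the CLI: because `_online_err` above sums
# per-sample mean errors and divides by the number of samples seen so far,
# the MAE/RMSE it reports after the last batch equal the statistics over the
# whole error array (every sample has the same number of components `size`).
# The helper name below is hypothetical; it replays that accumulation.
def _sketch_batched_mae_rmse(err, batch_size):
    """Return (mae, rmse) of `err`, accumulated batch by batch."""
    import numpy as np  # local import keeps this sketch self-contained

    err = np.asarray(err, dtype=float)
    err = err.reshape(err.shape[0], -1)  # one row per sample
    size = err.shape[1]
    mae_n_sum, rmse_n_sum, n = 0.0, 0.0, 0
    mae, rmse = float('nan'), float('nan')
    for start in range(0, err.shape[0], batch_size):
        batch = np.abs(err[start : start + batch_size])
        n += batch.shape[0]
        mae_n_sum += np.sum(batch) / size
        rmse_n_sum += np.sum(batch**2) / size
        mae, rmse = mae_n_sum / n, np.sqrt(rmse_n_sum / n)
    return mae, rmse
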

def resume(
    model,
    dataset,
    valid_dataset,
    overwrite,
    max_memory,
    max_processes,
    use_torch,
    command=None,
    **kwargs
):

    model_path, model = model
    dataset_path, dataset = dataset

    valid_dataset_arg = valid_dataset
    valid_dataset_path, valid_dataset = valid_dataset

    ui.print_step_title('RESUME TRAINING')
    _print_model_properties(model, title_str='Model properties (initial)')
    print()

    if dataset['md5'] != model['md5_train']:
        raise AssistantError(
            'Fingerprint of provided training dataset does not match the one specified in model file.'
        )
    if valid_dataset['md5'] != model['md5_valid']:
        raise AssistantError(
            'Fingerprint of provided validation dataset does not match the one specified in model file.'
        )

    if model['solver_name'] == 'analytic':
        raise AssistantError(
            'This model was trained using a matrix decomposition method and has thus already converged to the highest possible accuracy. Resuming training does not make sense in this case.'
        )
    elif 'solver_resid' in model and 'solver_tol' in model:
        if model['solver_resid'] > model['solver_tol'] * model['norm_y_train']:

            gdml_train = GDMLTrain(
                max_memory=max_memory, max_processes=max_processes, use_torch=use_torch
            )
            try:
                task = gdml_train.create_task_from_model(
                    model,
                    dataset,
                )
            except:
                print()
                log.critical(traceback.format_exc())
                print()
                os._exit(1)
            del gdml_train

            def save_progr_callback(
                unconv_model,
            ):  # saves current (unconverged) model during iterative training
                np.savez_compressed(model_path, **unconv_model)

            try:
                gdml_train = GDMLTrain(
                    max_memory=max_memory,
                    max_processes=max_processes,
                    use_torch=use_torch,
                )
            except:
                print()
                log.critical(traceback.format_exc())
                print()
                os._exit(1)

            try:
                model = gdml_train.train(
                    task, save_progr_callback=save_progr_callback, callback=ui.callback
                )
            except:
                print()
                log.critical(traceback.format_exc())
                print()
                os._exit(1)
            else:
                log.done('Model parameters have been updated.')
                np.savez_compressed(model_path, **model)

        else:
            log.warning('Model is already converged to the specified tolerance.')

    # Validate model.
    model_dir, model_file_name = os.path.split(model_path)
    model_dir_arg = (model_dir, [model_file_name])

    valid_errs = test(
        model_dir_arg,
        valid_dataset_arg,
        -1,  # n_test = -1 -> validation mode
        overwrite,
        max_memory,
        max_processes,
        use_torch,
        command,
        **kwargs
    )

    _print_next_step('resume', model_dir=model_dir, model_files=[model_file_name])

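
# Illustrative sketch, not used by the CLI: the "normalized cosine error" in
# `test` further down is the angle between each predicted and reference
# per-atom force vector, divided by pi so that it lies in [0, 1] (0: same
# direction, 1: opposite direction). The helper name is hypothetical; like
# the CLI code, it does not guard against zero-length force vectors.
def _sketch_force_angle_err(f_pred, f_ref):
    """Return the per-atom angular force error in units of pi."""
    import numpy as np  # local import keeps this sketch self-contained

    f_pred = np.asarray(f_pred, dtype=float).reshape(-1, 3)
    f_ref = np.asarray(f_ref, dtype=float).reshape(-1, 3)
    f_pred_unit = f_pred / np.linalg.norm(f_pred, axis=1)[:, None]
    f_ref_unit = f_ref / np.linalg.norm(f_ref, axis=1)[:, None]
    # Clipping guards arccos against dot products slightly outside [-1, 1].
    cos = np.clip(np.einsum('ij,ij->i', f_pred_unit, f_ref_unit), -1.0, 1.0)
    return np.arccos(cos) / np.pi
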

def validate(
    model_dir,
    valid_dataset,
    overwrite,
    max_memory,
    max_processes,
    use_torch,
    command=None,
    **kwargs
):

    dataset_path_extracted, dataset_extracted = valid_dataset

    func_called_directly = (
        command == 'validate'
    )  # has this function been called from command line or from 'all'?
    if func_called_directly:
        ui.print_step_title('MODEL VALIDATION')
        _print_dataset_properties(dataset_extracted)

    test(
        model_dir,
        valid_dataset,
        -1,  # n_test = -1 -> validation mode
        overwrite,
        max_memory,
        max_processes,
        use_torch,
        command,
        **kwargs
    )

    if func_called_directly:

        model_dir, model_files = model_dir
        n_models = len(model_files)
        _print_next_step('validate', model_dir=model_dir, model_files=model_files)

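
# Illustrative sketch, not used by the CLI: how `test` below builds a test
# set. Indices already used for training and/or validation (when the test
# dataset has the same MD5 fingerprint as those datasets) are excluded, and
# n_test == 0 means "all remaining points". Assumption: this sketch draws
# uniformly at random, whereas the CLI uses the energy-stratified
# `GDMLTrain.draw_strat_sample` whenever energy labels are available. The
# helper name and signature are hypothetical.
def _sketch_draw_test_idxs(n_data, excl_idxs, n_test, seed=None):
    """Draw up to `n_test` unique indices from range(n_data) minus `excl_idxs`."""
    import numpy as np  # local import keeps this sketch self-contained

    pool = np.setdiff1d(np.arange(n_data), np.asarray(excl_idxs, dtype=int))
    if n_test == 0 or n_test > len(pool):
        n_test = len(pool)  # cap the request at the number of unused points
    rng = np.random.default_rng(seed)
    return rng.choice(pool, size=n_test, replace=False)
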

def test(
    model_dir,
    test_dataset,
    n_test,
    overwrite,
    max_memory,
    max_processes,
    use_torch,
    command=None,
    **kwargs
):  # noqa: C901

    # NOTE: this function runs a validation if n_test < 0, and a test on all
    # points not used for training or validation if n_test == 0.

    model_dir, model_file_names = model_dir
    n_models = len(model_file_names)

    n_test = 0 if n_test is None else n_test
    is_validation = n_test < 0
    is_test = n_test >= 0

    dataset_path, dataset = test_dataset

    func_called_directly = (
        command == 'test'
    )  # has this function been called from command line or from 'all'?
    if func_called_directly:
        ui.print_step_title('MODEL TEST')
        _print_dataset_properties(dataset)

    F_rmse = []

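    # Debug output: set DEBUG_WRITE = True to dump predicted and reference
    # structures to extended XYZ files ('test_*.xyz') in the working directory.
    DEBUG_WRITE = False

    if DEBUG_WRITE:
        if os.path.exists('test_pred.xyz'):
            os.remove('test_pred.xyz')
        if os.path.exists('test_ref.xyz'):
            os.remove('test_ref.xyz')
        if os.path.exists('test_diff.xyz'):
            os.remove('test_diff.xyz')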

    num_workers, batch_size = -1, -1
    gdml_train = None
    for i, model_file_name in enumerate(model_file_names):

        model_path = os.path.join(model_dir, model_file_name)
        _, model = io.is_file_type(model_path, 'model')

        if i == 0 and command != 'all':
            print()
            _print_model_properties(model)
            print()

        if not np.array_equal(model['z'], dataset['z']):
            raise AssistantError(
                'Atom composition or order in dataset does not match the one in model.'
            )

        if ('lattice' in model) != ('lattice' in dataset):
            if 'lattice' in model:
                raise AssistantError(
                    'Model contains lattice vectors, but dataset does not.'
                )
            elif 'lattice' in dataset:
                raise AssistantError(
                    'Dataset contains lattice vectors, but model does not.'
                )

        if model['use_E']:
            e_err = model['e_err'].item()
        f_err = model['f_err'].item()

        is_model_validated = not (np.isnan(f_err['mae']) or np.isnan(f_err['rmse']))

        if n_models > 1:
            if i > 0:
                print()
            print(
                ui.color_str(
                    '%s model %d of %d'
                    % ('Testing' if is_test else 'Validating', i + 1, n_models),
                    bold=True,
                )
            )

        if is_validation:
            if is_model_validated and not overwrite:
                log.info(
                    'Skipping already validated model \'{}\'.'.format(model_file_name)
                    + (
                        '\nRun \'{} validate -o {} {}\' to overwrite.'.format(
                            PACKAGE_NAME, model_path, dataset_path
                        )
                        if command == 'validate'
                        else ''
                    )
                )
                continue

            if dataset['md5'] != model['md5_valid']:
                raise AssistantError(
                    'Fingerprint of provided validation dataset does not match the one specified in model file.'
                )

        test_idxs = model['idxs_valid']
        if is_test:

            # exclude training and/or validation points from the test set if the datasets match
            excl_idxs = np.empty((0,), dtype=np.uint)
            if dataset['md5'] == model['md5_train']:
                excl_idxs = np.concatenate([excl_idxs, model['idxs_train']]).astype(
                    np.uint
                )
            if dataset['md5'] == model['md5_valid']:
                excl_idxs = np.concatenate([excl_idxs, model['idxs_valid']]).astype(
                    np.uint
                )

            n_data = dataset['F'].shape[0]
            n_data_eff = n_data - len(excl_idxs)

            if (
                n_test == 0 and n_data_eff != 0
            ):  # test on all data points that have not been used for training or validation
                n_test = n_data_eff
                log.info(
                    'Test set size was automatically set to {:,} points.'.format(n_test)
                )

            if n_test == 0 or n_data_eff == 0:
                log.warning(
                    'Skipping! No unused points left for testing in the provided dataset.'
                )
                return
            elif n_data_eff < n_test:
                n_test = n_data_eff
                log.warning(
                    'Test size reduced to {:d}. Not enough unused points in the provided dataset.'.format(
                        n_test
                    )
                )

            if 'E' in dataset:
                if gdml_train is None:
                    gdml_train = GDMLTrain(
                        max_memory=max_memory, max_processes=max_processes
                    )
                test_idxs = gdml_train.draw_strat_sample(
                    dataset['E'], n_test, excl_idxs=excl_idxs
                )
            else:
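                # draw n_test of the unused points at random (without replacement)
                test_idxs = np.random.choice(
                    np.delete(np.arange(n_data), excl_idxs), n_test, replace=False
                )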

                log.warning(
                    'Test dataset will be sampled with no guidance from energy labels (randomly)!\n'
                    + 'Note: Larger test datasets are recommended due to slower convergence of the error.'
                )
        # shuffle to improve convergence of online error
        np.random.shuffle(test_idxs)

        # Debug output: write points in dataset order.
        if DEBUG_WRITE:
            test_idxs = np.sort(test_idxs)

        z = dataset['z']
        R = dataset['R'][test_idxs, :, :]
        F = dataset['F'][test_idxs, :, :]

        if model['use_E']:
            E = dataset['E'][test_idxs]

        try:
            gdml_predict = GDMLPredict(
                model,
                max_memory=max_memory,
                max_processes=max_processes,
                use_torch=use_torch,
            )
        except:
            print()
            log.critical(traceback.format_exc())
            print()
            os._exit(1)

        b_size = min(1000, len(test_idxs))

        if not use_torch:
            if num_workers == -1 or chunk_size == -1:  # parallelism not tuned yet
                ui.callback(NOT_DONE, disp_str='Optimizing parallelism')

                gps, is_from_cache = gdml_predict.prepare_parallel(
                    n_bulk=b_size, return_is_from_cache=True
                )
                num_workers, chunk_size, bulk_mp = (
                    gdml_predict.num_workers,
                    gdml_predict.chunk_size,
                    gdml_predict.bulk_mp,
                )

                sec_disp_str = 'no chunking'
                if chunk_size != gdml_predict.n_train:
                    sec_disp_str = 'chunks of {:d}'.format(chunk_size)

                if num_workers == 0:
                    sec_disp_str = 'no workers / ' + sec_disp_str
                else:
                    sec_disp_str = (
                        '{:d} workers {}/ '.format(
                            num_workers, '[MP] ' if bulk_mp else ''
                        )
                        + sec_disp_str
                    )

                ui.callback(
                    DONE,
                    disp_str='Optimizing parallelism'
                    + (' (from cache)' if is_from_cache else ''),
                    sec_disp_str=sec_disp_str,
                )
            else:
                gdml_predict._set_num_workers(num_workers)
                gdml_predict._set_chunk_size(chunk_size)
                gdml_predict._set_bulk_mp(bulk_mp)

        n_atoms = z.shape[0]

        if model['use_E']:
            e_mae_sum, e_rmse_sum = 0, 0
        f_mae_sum, f_rmse_sum = 0, 0
        cos_mae_sum, cos_rmse_sum = 0, 0
        mag_mae_sum, mag_rmse_sum = 0, 0

        n_done = 0
        t = time.time()
        for b_range in _batch(list(range(len(test_idxs))), b_size):

            n_done_step = len(b_range)
            n_done += n_done_step

            r = R[b_range].reshape(n_done_step, -1)
            e_pred, f_pred = gdml_predict.predict(r)

            # energy error
            if model['use_E']:
                e = E[b_range]
                e_mae, e_mae_sum, e_rmse, e_rmse_sum = _online_err(
                    np.squeeze(e) - e_pred, 1, n_done, e_mae_sum, e_rmse_sum
                )

            # force component error
            f = F[b_range].reshape(n_done_step, -1)
            f_mae, f_mae_sum, f_rmse, f_rmse_sum = _online_err(
                f - f_pred, 3 * n_atoms, n_done, f_mae_sum, f_rmse_sum
            )

            # magnitude error
            f_pred_mags = np.linalg.norm(f_pred.reshape(-1, 3), axis=1)
            f_mags = np.linalg.norm(f.reshape(-1, 3), axis=1)
            mag_mae, mag_mae_sum, mag_rmse, mag_rmse_sum = _online_err(
                f_pred_mags - f_mags, n_atoms, n_done, mag_mae_sum, mag_rmse_sum
            )

            # normalized cosine error
            f_pred_norm = f_pred.reshape(-1, 3) / f_pred_mags[:, None]
            f_norm = f.reshape(-1, 3) / f_mags[:, None]
            cos_err = (
                np.arccos(np.clip(np.einsum('ij,ij->i', f_pred_norm, f_norm), -1, 1))
                / np.pi
            )
            cos_mae, cos_mae_sum, cos_rmse, cos_rmse_sum = _online_err(
                cos_err, n_atoms, n_done, cos_mae_sum, cos_rmse_sum
            )

            # Debugging aid: append predictions, reference values and their differences to extended XYZ files.

            if is_test and DEBUG_WRITE:

                try:
                    with open('test_pred.xyz', 'a') as file:

                        for i, ri in enumerate(r):

                            r_out = ri.reshape(-1, 3)
                            e_out = e_pred[i]
                            f_out = f_pred[i].reshape(-1, 3)

                            ext_xyz_str = (
                                io.generate_xyz_str(r_out, model['z'], e=e_out, f=f_out)
                                + '\n'
                            )

                            file.write(ext_xyz_str)

                except IOError:
                    sys.exit("ERROR: Writing xyz file failed.")

                try:
                    with open('test_ref.xyz', 'a') as file:

                        for i, ri in enumerate(r):

                            r_out = ri.reshape(-1, 3)
                            e_out = (
                                None
                                if not model['use_E']
                                else np.squeeze(E[b_range][i])
                            )
                            f_out = f[i].reshape(-1, 3)

                            ext_xyz_str = (
                                io.generate_xyz_str(r_out, model['z'], e=e_out, f=f_out)
                                + '\n'
                            )
                            file.write(ext_xyz_str)

                except IOError:
                    sys.exit("ERROR: Writing xyz file failed.")

                try:
                    with open('test_diff.xyz', 'a') as file:

                        for i, ri in enumerate(r):

                            r_out = ri.reshape(-1, 3)
                            e_out = (
                                None
                                if not model['use_E']
                                else (np.squeeze(E[b_range][i]) - e_pred[i])
                            )
                            f_out = (f[i] - f_pred[i]).reshape(-1, 3)

                            ext_xyz_str = (
                                io.generate_xyz_str(r_out, model['z'], e=e_out, f=f_out)
                                + '\n'
                            )
                            file.write(ext_xyz_str)

                except IOError:
                    sys.exit("ERROR: Writing xyz file failed.")

            sps = n_done / (time.time() - t)  # examples per second
            disp_str = 'energy %.3f/%.3f, ' % (e_mae, e_rmse) if model['use_E'] else ''
            disp_str += 'forces %.3f/%.3f' % (f_mae, f_rmse)
            disp_str = (
                '{} errors (MAE/RMSE): '.format('Test' if is_test else 'Validation')
                + disp_str
            )
            sec_disp_str = '@ %.1f geo/s' % sps if b_range is not None else ''

            ui.callback(
                n_done,
                len(test_idxs),
                disp_str=disp_str,
                sec_disp_str=sec_disp_str,
                newline_when_done=False,
            )

        if is_test:
            ui.callback(
                DONE,
                disp_str='Testing on {:,} points'.format(n_test),
                sec_disp_str=sec_disp_str,
            )
        else:
            ui.callback(DONE, disp_str=disp_str, sec_disp_str=sec_disp_str)

        if model['use_E']:
            e_rmse_pct = (e_rmse / e_err['rmse'] - 1.0) * 100
        f_rmse_pct = (f_rmse / f_err['rmse'] - 1.0) * 100

        if is_test and n_models == 1:
            n_train = len(model['idxs_train'])
            n_valid = len(model['idxs_valid'])
            print()
            ui.print_two_column_str(
                ui.color_str('Test errors (MAE/RMSE)', bold=True),
                '{:,} + {:,} points (training + validation), sigma (length scale): {}'.format(
                    n_train, n_valid, model['sig']
                ),
            )

            r_unit = 'unknown unit'
            e_unit = 'unknown unit'
            f_unit = 'unknown unit'
            if 'r_unit' in dataset and 'e_unit' in dataset:
                r_unit = dataset['r_unit']
                e_unit = dataset['e_unit']
                f_unit = str(dataset['e_unit']) + '/' + str(dataset['r_unit'])

            format_str = '  {:<18} {:>.4f}/{:>.4f} [{}]'
            if model['use_E']:
                ui.print_two_column_str(
                    format_str.format('Energy:', e_mae, e_rmse, e_unit),
                    'relative to expected: {:+.1f}%'.format(e_rmse_pct),
                )

            ui.print_two_column_str(
                format_str.format('Forces:', f_mae, f_rmse, f_unit),
                'relative to expected: {:+.1f}%'.format(f_rmse_pct),
            )

            print(format_str.format('  Magnitude:', mag_mae, mag_rmse, f_unit))
            ui.print_two_column_str(
                format_str.format('  Angle:', cos_mae, cos_rmse, '0-1'),
                'lower is better',
            )
            print()

        model_mutable = dict(model)
        model.close()
        model = model_mutable

        model_needs_update = (
            overwrite
            or (is_test and model['n_test'] < len(test_idxs))
            or (is_validation and not is_model_validated)
        )
        if model_needs_update:

            if is_validation and overwrite:
                model['n_test'] = 0  # flag the model as not tested

            if is_test:
                model['n_test'] = len(test_idxs)
                model['md5_test'] = dataset['md5']

            if model['use_E']:
                model['e_err'] = {
                    'mae': e_mae.item(),
                    'rmse': e_rmse.item(),
                }

            model['f_err'] = {'mae': f_mae.item(), 'rmse': f_rmse.item()}
            np.savez_compressed(model_path, **model)

            if is_test and model['n_test'] > 0:
                log.info('Expected errors were updated in model file.')

        else:
            add_info_str = (
                'the same number of'
                if model['n_test'] == len(test_idxs)
                else 'only {:,}'.format(len(test_idxs))
            )
            log.warning(
                'This model has previously been tested on {:,} points, which is why the errors for the current test run with {} points have NOT been used to update the model file.\n'.format(
                    model['n_test'], add_info_str
                )
                + 'Run \'{} test -o {} {} {}\' to overwrite.'.format(
                    PACKAGE_NAME, os.path.relpath(model_path), dataset_path, n_test
                )
            )

        F_rmse.append(f_rmse)

    return F_rmse


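# Illustrative sketches (hypothetical helpers, not called by the CLI) of two
# of the error measures reported by 'test'.
#
# 1) 'test' predicts in batches and keeps running error sums across batches
#    (see the calls to '_online_err'). Summing the per-sample mean absolute
#    and mean squared errors and dividing by the number of samples seen so
#    far gives the MAE and RMSE over all samples processed. `errs` is assumed
#    to have one row per sample (1D input, e.g. energies, is treated as one
#    component per sample).
def _batched_mae_rmse_sketch(errs, batch_size):

    errs = np.asarray(errs, dtype=float).reshape(len(errs), -1)

    mae_sum, mse_sum, n_done = 0.0, 0.0, 0
    for start in range(0, errs.shape[0], batch_size):
        batch = errs[start : start + batch_size]
        n_done += batch.shape[0]

        # Per-sample means of |err| and err**2, summed over the batch.
        mae_sum += np.abs(batch).mean(axis=1).sum()
        mse_sum += (batch**2).mean(axis=1).sum()

    return mae_sum / n_done, np.sqrt(mse_sum / n_done)


# 2) The 'Angle' error is the angle between each predicted and reference
#    per-atom force vector, divided by pi so that it lies in [0, 1]
#    (0: same direction, 0.5: orthogonal, 1: opposite). Unlike the code in
#    'test', this variant assumes that pairs involving a zero-length force
#    vector should count as zero error instead of producing NaN.
def _force_angle_err_sketch(f_pred, f_ref):

    f_pred = np.asarray(f_pred, dtype=float).reshape(-1, 3)
    f_ref = np.asarray(f_ref, dtype=float).reshape(-1, 3)

    denom = np.linalg.norm(f_pred, axis=1) * np.linalg.norm(f_ref, axis=1)
    valid = denom > 0

    # Cosine of the angle between each pair of vectors (left at 1 where undefined).
    cos = np.ones_like(denom)
    cos[valid] = np.einsum('ij,ij->i', f_pred[valid], f_ref[valid]) / denom[valid]

    return np.arccos(np.clip(cos, -1.0, 1.0)) / np.pi

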
def select(model_dir, overwrite, model_file=None, command=None, **kwargs):  # noqa: C901

    func_called_directly = (
        command == 'select'
    )  # has this function been called from command line or from 'all'?
    if func_called_directly:
        ui.print_step_title('MODEL SELECTION')

    any_model_not_validated = False
    any_model_is_tested = False

    model_dir, model_file_names = model_dir
    if len(model_file_names) > 1:

        use_E = True

        rows = []
        data_names = ['sig', 'MAE', 'RMSE', 'MAE', 'RMSE']
        for i, model_file_name in enumerate(model_file_names):
            model_path = os.path.join(model_dir, model_file_name)
            _, model = io.is_file_type(model_path, 'model')

            use_E = model['use_E']

            if i == 0:
                idxs_train = set(model['idxs_train'])
                md5_train = model['md5_train']
                idxs_valid = set(model['idxs_valid'])
                md5_valid = model['md5_valid']
            else:
                if (
                    md5_train != model['md5_train']
                    or md5_valid != model['md5_valid']
                    or idxs_train != set(model['idxs_train'])
                    or idxs_valid != set(model['idxs_valid'])
                ):
                    raise AssistantError(
                        '{} contains models trained or validated on different datasets.'.format(
                            model_dir
                        )
                    )

            e_err = {'mae': 0.0, 'rmse': 0.0}
            if model['use_E']:
                e_err = model['e_err'].item()
            f_err = model['f_err'].item()

            is_model_validated = not (np.isnan(f_err['mae']) or np.isnan(f_err['rmse']))
            if not is_model_validated:
                any_model_not_validated = True

            is_model_tested = model['n_test'] > 0
            if is_model_tested:
                any_model_is_tested = True

            rows.append(
                [model['sig'], e_err['mae'], e_err['rmse'], f_err['mae'], f_err['rmse']]
            )

            model.close()

        if any_model_not_validated:
            log.warning(
                'One or more models in the given directory have not been validated.'
            )
            print()

        if any_model_is_tested:
            log.error(
                'One or more models in the given directory have already been tested, which means that their recorded expected errors are test errors, not validation errors. Model selection must never be based on the test error!\n'
                + 'Please run the validation command (again) with the overwrite option \'-o\', then run this selection command.'
            )
            return

        f_rmse_col = [row[4] for row in rows]
        best_idx = f_rmse_col.index(min(f_rmse_col))  # idx of row with lowest f_rmse
        best_sig = rows[best_idx][0]

        rows = sorted(rows, key=lambda col: col[0])  # sort according to sigma
        print(ui.color_str('Cross-validation errors', bold=True))
        print(' ' * 7 + 'Energy' + ' ' * 6 + 'Forces')
        print((' {:>3} ' + '{:>5} ' * 4).format(*data_names))
        print(' ' + '-' * 27)
        format_str = ' {:>3} ' + '{:5.2f} ' * 4
        format_str_no_E = ' {:>3}     -     - ' + '{:5.2f} ' * 2
        for row in rows:
            if use_E:
                row_str = format_str.format(*row)
            else:
                row_str = format_str_no_E.format(*[row[0], row[3], row[4]])

            if row[0] != best_sig:
                row_str = ui.color_str(row_str, fore_color=ui.GRAY)
            print(row_str)
        print()

        sig_col = [row[0] for row in rows]
        if best_sig == min(sig_col) or best_sig == max(sig_col):
            log.warning(
                'The optimal sigma (length scale) lies on the boundary of the search grid.\n'
                + 'Model performance might improve if the search grid is extended in direction sigma {} {:d}.'.format(
                    '<' if best_sig == min(sig_col) else '>', best_sig
                )
            )

    else:  # only one model available
        log.info('Skipping model selection step as there is only one model to select.')

        best_idx = 0

    best_model_path = os.path.join(model_dir, model_file_names[best_idx])

    if model_file is None:

        # generate model file name based on model properties
        best_model = np.load(best_model_path, allow_pickle=True)
        model_file = io.model_file_name(best_model, is_extended=True)
        best_model.close()

    model_exists = os.path.isfile(model_file)
    if model_exists and overwrite:
        log.info('Overwriting existing model file.')

    if not model_exists or overwrite:
        if func_called_directly:
            log.done('Writing model file \'{}\''.format(model_file))

        shutil.copy(best_model_path, model_file)
        shutil.rmtree(model_dir, ignore_errors=True)
    else:
        log.warning(
            'Model \'{}\' already exists.\n'.format(model_file)
            + 'Run \'{} select -o {}\' to overwrite.'.format(
                PACKAGE_NAME, os.path.relpath(model_dir)
            )
        )

    if func_called_directly:
        _print_next_step('select', model_files=[model_file])

    return model_file


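# Illustrative sketch (hypothetical helper, not used by the CLI): 'select'
# picks the sigma whose model has the lowest validation force RMSE and warns
# if that sigma sits at either end of the searched grid, since a better value
# may then lie outside it. Assumes two parallel sequences of sigmas and force
# RMSEs in arbitrary (not necessarily sorted) order.
def _select_sigma_sketch(sigs, f_rmses):

    best_idx = int(np.argmin(f_rmses))
    best_sig = sigs[best_idx]

    # Direction in which the sigma grid should be extended, if any.
    if best_sig == min(sigs):
        extend = '<'
    elif best_sig == max(sigs):
        extend = '>'
    else:
        extend = None

    return best_sig, extend

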
def show(file, command=None, **kwargs):

    ui.print_step_title('SHOW DETAILS')
    file_path, file = file

    if file['type'].astype(str) == 'd':
        _print_dataset_properties(file)

    if file['type'].astype(str) == 't':
        _print_task_properties(file)

    if file['type'].astype(str) == 'm':
        _print_model_properties(file)


def reset(command=None, **kwargs):

    if ui.yes_or_no('\nDo you really want to purge all caches and temporary files?'):

        pkg_dir = os.path.dirname(os.path.abspath(__file__))
        bmark_file = '_bmark_cache.npz'
        bmark_path = os.path.join(pkg_dir, bmark_file)

        if os.path.exists(bmark_path):
            try:
                os.remove(bmark_path)
            except OSError:
                print()
                log.critical('Exception: unable to delete benchmark cache.')
                print()
                os._exit(1)

            log.done('Benchmark cache deleted.')
        else:
            log.info('Benchmark cache was already empty.')
    else:
        print(' Cancelled.')


def main():
    def _add_argument_sample_size(parser, subset_str):
        parser.add_argument(
            'n_%s' % subset_str,
            metavar='<n_%s>' % subset_str,
            type=io.is_strict_pos_int,
            help='%s sample size' % subset_str,
        )

    def _add_argument_dir_with_file_type(parser, type, or_file=False):
        parser.add_argument(
            '%s_dir' % type,
            metavar='<%s_dir%s>' % (type, '_or_file' if or_file else ''),
            type=lambda x: io.is_dir_with_file_type(x, type, or_file=or_file),
            help='path to %s directory%s' % (type, ' or file' if or_file else ''),
        )

    # Available resources
    total_memory = psutil.virtual_memory().total // 2**30
    total_cpus = mp.cpu_count()

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--version',
        action='version',
        version='%(prog)s '
        + __version__
        + ' [Python {}, NumPy {}, SciPy {}'.format(
            '.'.join(map(str, sys.version_info[:3])), np.__version__, sp.__version__
        )
        + ', PyTorch {}'.format(torch.__version__ if _has_torch else 'N/A')
        + ', ASE {}'.format(ase.__version__ if _has_ase else 'N/A')
        + ']',
    )

    parent_parser = argparse.ArgumentParser(add_help=False)

    subparsers = parser.add_subparsers(title='commands', dest='command')
    subparsers.required = True
    parser_all = subparsers.add_parser(
        'all',
        help='reconstruct a force field from beginning to end',
        parents=[parent_parser],
    )
    parser_create = subparsers.add_parser(
        'create', help='create training task(s)', parents=[parent_parser]
    )
    parser_train = subparsers.add_parser(
        'train', help='train model(s) from task(s)', parents=[parent_parser]
    )
    parser_resume = subparsers.add_parser(
        'resume', help='resume training of a model', parents=[parent_parser]
    )
    parser_valid = subparsers.add_parser(
        'validate', help='validate model(s)', parents=[parent_parser]
    )
    parser_select = subparsers.add_parser(
        'select', help='select best performing model', parents=[parent_parser]
    )
    parser_test = subparsers.add_parser(
        'test', help='test a model', parents=[parent_parser]
    )
    parser_show = subparsers.add_parser(
        'show',
        help='print details about a dataset, task or model file',
        parents=[parent_parser],
    )
    subparsers.add_parser(
        'reset', help='delete all caches and temporary files', parents=[parent_parser]
    )

    for subparser in [parser_all, parser_create]:

        subparser.add_argument(
            'dataset',
            metavar='<dataset_file>',
            type=lambda x: io.is_file_type(x, 'dataset'),
            help='path to dataset file (train/validation/test subsets are sampled from here if no separate datasets are specified)',
        )

        _add_argument_sample_size(subparser, 'train')
        _add_argument_sample_size(subparser, 'valid')
        subparser.add_argument(
            '-v',
            '--validation_dataset',
            metavar='<valid_dataset_file>',
            dest='valid_dataset',
            type=lambda x: io.is_file_type(x, 'dataset'),
            help='path to separate validation dataset file',
        )
        subparser.add_argument(
            '-t',
            '--test_dataset',
            metavar='<test_dataset_file>',
            dest='test_dataset',
            type=lambda x: io.is_file_type(x, 'dataset'),
            help='path to separate test dataset file',
        )
        subparser.add_argument(
            '-s',
            '--sig',
            metavar=('<s1>', '<s2>'),
            dest='sigs',
            type=io.parse_list_or_range,
            help='integer list and/or range <start>:[<step>:]<stop> for the kernel hyper-parameter sigma (length scale)',
            nargs='+',
        )

        group = subparser.add_mutually_exclusive_group()
        group.add_argument(
            '--gdml',
            action='store_true',
            help='don\'t include symmetries in the model (GDML)',
        )

        group.add_argument(
            '--perms_from',
            metavar='<file>',
            dest='perms_from_arg',
            type=lambda x: io.is_valid_file_type(x),
            help='path to file to take permutations from (key: \'perms\')',
        )

        group = subparser.add_mutually_exclusive_group()
        group.add_argument(
            '--no_E',
            dest='use_E',
            action='store_false',
            help='only reconstruct force field w/o potential energy surface',
        )
        group.add_argument(
            '--E_cstr',
            dest='use_E_cstr',
            action='store_true',
            help='include pointwise energy constraints',
        )

        subparser.add_argument(
            '--task_dir',
            metavar='<task_dir>',
            dest='task_dir',
            help='user-defined task output dir name',
        )

    for subparser in [parser_all, parser_select]:
        subparser.add_argument(
            '--model_file',
            metavar='<model_file>',
            dest='model_file',
            help='user-defined model output file name',
        )

    for subparser in [parser_all, parser_train]:
        subparser.add_argument(
            '--lazy',
            dest='lazy_training',
            action='store_true',
            help='give up on unfinished tasks (if more than one)',
        )

    for subparser in [parser_valid, parser_test]:
        _add_argument_dir_with_file_type(subparser, 'model', or_file=True)

    parser_valid.add_argument(
        'valid_dataset',
        metavar='<valid_dataset_file>',
        type=lambda x: io.is_file_type(x, 'dataset'),
        help='path to validation dataset file',
    )
    parser_test.add_argument(
        'test_dataset',
        metavar='<test_dataset_file>',
        type=lambda x: io.is_file_type(x, 'dataset'),
        help='path to test dataset file',
    )

    for subparser in [parser_all, parser_test]:
        subparser.add_argument(
            'n_test',
            metavar='<n_test>',
            type=io.is_strict_pos_int,
            help='test sample size',
            nargs='?',
            default=None,
        )

    parser_resume.add_argument(
        'model',
        metavar='<model_file>',
        type=lambda x: io.is_file_type(x, 'model'),
        help='path to model file to complete training for',
    )
    parser_resume.add_argument(
        'dataset',
        metavar='<train_dataset_file>',
        type=lambda x: io.is_file_type(x, 'dataset'),
        help='path to original training dataset file',
    )

    _add_argument_dir_with_file_type(parser_train, 'task', or_file=True)

    for subparser in [parser_train, parser_resume]:
        subparser.add_argument(
            'valid_dataset',
            metavar='<valid_dataset_file>',
            type=lambda x: io.is_file_type(x, 'dataset'),
            help='path to validation dataset file',
        )

    _add_argument_dir_with_file_type(parser_select, 'model')

    parser_show.add_argument(
        'file',
        metavar='<file>',
        type=lambda x: io.is_valid_file_type(x),
        help='path to dataset, task or model file',
    )

    for subparser in [
        parser_all,
        parser_train,
        parser_resume,
        parser_valid,
        parser_test,
    ]:

        subparser.add_argument(
            '-m',
            '--max_memory',
            metavar='<max_memory>',
            type=int,
            help='limit memory usage (whenever possible) [GB]',
            choices=range(1, total_memory + 1),
            default=total_memory,
        )

        subparser.add_argument(
            '-p',
            '--max_processes',
            metavar='<max_processes>',
            type=int,
            help='limit number of processes',
            choices=range(1, total_cpus + 1),
            default=total_cpus,
        )

        subparser.add_argument(
            '--cpu',
            dest='use_torch',
            action='store_false',
            help='use CPU implementation (no PyTorch dependency)',
        )

    for subparser in [
        parser_all,
        parser_create,
        parser_train,
        parser_resume,
        parser_valid,
        parser_select,
        parser_test,
    ]:
        subparser.add_argument(
            '-o',
            '--overwrite',
            dest='overwrite',
            action='store_true',
            help='overwrite existing files',
        )

    args = parser.parse_args()

    # Post-processing for optional sig argument
    if 'sigs' in args and args.sigs is not None:
        args.sigs = np.hstack(
            args.sigs
        ).tolist()  # Flatten list, if (part of it) was generated using the range syntax
        args.sigs = sorted(list(set(args.sigs)))  # remove potential duplicates

    # Post-processing for optional model output file argument
    if 'model_file' in args and args.model_file is not None:
        if not args.model_file.endswith('.npz'):
            args.model_file += '.npz'

    # Check PyTorch GPU support.
    if getattr(args, 'use_torch', True):
        if _has_torch:
            if not (_torch_cuda_is_available or _torch_mps_is_available):
                print()  # TODO: print only if log level includes warning
                log.warning(
                    'Your PyTorch installation does not see any GPU(s) on your system and will thus run all calculations on the CPU! If this is what you want, we recommend bypassing PyTorch using \'--cpu\' for improved performance.'
                )
        else:
            print()
            log.critical(
                'PyTorch dependency not found! Please install it or use \'--cpu\' to bypass PyTorch and run everything on the CPU.'
            )
            print()
            os._exit(1)

    args = vars(args)

    _print_splash(
        args['max_memory'] if 'max_memory' in args else total_memory,
        args['max_processes'] if 'max_processes' in args else total_cpus,
        args['use_torch'] if 'use_torch' in args else True,
    )

    try:
        getattr(sys.modules[__name__], args['command'])(**args)
    except AssistantError as err:
        log.error(str(err))
        print()
        os._exit(1)
    except Exception:
        log.critical(traceback.format_exc())
        print()
        os._exit(1)
    print()


if __name__ == "__main__":
    main()


================================================
FILE: sgdml/get.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2018-2023 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import print_function

import argparse
import os
import re
import sys

from . import __version__
from .utils import ui

if sys.version[0] == '3':
    raw_input = input

try:
    from urllib.request import urlopen
except ImportError:
    from urllib2 import urlopen


def download(command, file_name):

    base_url = 'http://www.quantum-machine.org/gdml/' + (
        'data/npz/' if command == 'dataset' else 'models/'
    )
    request = urlopen(base_url + file_name)
    filesize = int(request.headers['Content-Length'])

    size = 0
    block_sz = 1024
    # Stream the response to disk in 1 KiB blocks and report progress.
    with open(file_name, 'wb') as file:
        while True:
            buffer = request.read(block_sz)
            if not buffer:
                break
            size += len(buffer)
            file.write(buffer)

            ui.callback(
                size,
                filesize,
                disp_str='Downloading: {}'.format(file_name),
                sec_disp_str='{:,} bytes'.format(filesize),
            )
    request.close()


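# Illustrative sketch (hypothetical helper, not used by 'main'): how the
# interactive selection prompt in 'main' is interpreted. 'all' selects every
# item; otherwise the input must consist of digits and spaces only, and
# duplicate indices are dropped. Unlike 'main', this version silently
# discards out-of-range indices instead of warning about them.
def _parse_selection_sketch(text, n_items):

    if 'all' in text.lower():
        return list(range(n_items))

    if not re.match(r'^ *[0-9][0-9 ]*$', text):
        return []  # 'main' prints ABORTED in this case

    idxs = {int(idx) for idx in re.split(r'\s+', text.strip())}
    return sorted(idx for idx in idxs if idx < n_items)

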
def main():

    base_url = 'http://www.quantum-machine.org/gdml/'

    parser = argparse.ArgumentParser()

    parent_parser = argparse.ArgumentParser(add_help=False)
    parent_parser.add_argument(
        '-o',
        '--overwrite',
        dest='overwrite',
        action='store_true',
        help='overwrite existing files',
    )

    subparsers = parser.add_subparsers(title='commands', dest='command')
    subparsers.required = True
    parser_dataset = subparsers.add_parser(
        'dataset', help='download benchmark dataset', parents=[parent_parser]
    )
    parser_model = subparsers.add_parser(
        'model', help='download pre-trained model', parents=[parent_parser]
    )

    for subparser in [parser_dataset, parser_model]:
        subparser.add_argument(
            'name',
            metavar='<name>',
            type=str,
            help='item name',
            nargs='?',
            default=None,
        )

    args = parser.parse_args()

    print("Contacting server (%s)..." % base_url)

    if args.name is not None:

        url = '%sget.php?version=%s&%s=%s' % (
            base_url,
            __version__,
            args.command,
            args.name,
        )
        response = urlopen(url)
        match, score = response.read().decode().split(',')
        response.close()

        if int(score) == 0 or ui.yes_or_no('Do you mean \'%s\'?' % match):
            download(args.command, match + '.npz')
            return

    response = urlopen(
        '%sget.php?version=%s&%s' % (base_url, __version__, args.command)
    )
    line = response.readlines()
    response.close()

    print()
    print('Available %ss:' % args.command)

    print('{:<2} {:<31}    {:>4}'.format('ID', 'Name', 'Size'))
    print('-' * 42)

    items = line[0].split(b';')
    for i, item in enumerate(items):
        name, size = item.split(b',')
        size = int(size) / 1024**2  # Bytes to MBytes

        print('{:>2d} {:<30} {:>5.1f} MB'.format(i, name.decode("utf-8"), size))
    print()

    down_list = raw_input(
        'Please list which %ss to download (e.g. 0 1 2 6) or type \'all\': '
        % args.command
    )
    down_idxs = []
    if 'all' in down_list.lower():
        down_idxs = list(range(len(items)))
    elif re.match(
        "^ *[0-9][0-9 ]*$", down_list
    ):  # only digits and spaces, at least one digit
        down_idxs = [int(idx) for idx in re.split(r'\s+', down_list.strip())]
        down_idxs = list(set(down_idxs))
    else:
        print(ui.color_str('ABORTED', fore_color=ui.RED, bold=True))

    for idx in down_idxs:
        if idx not in range(len(items)):
            print(
                ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True)
                + ' Index '
                + str(idx)
                + ' out of range, skipping.'
            )
        else:
            name = items[idx].split(b',')[0].decode("utf-8")
            file_name = name + '.npz'
            if os.path.exists(file_name) and not args.overwrite:
                print("'%s' exists, skipping." % (file_name))
                continue

            download(args.command, file_name)


if __name__ == "__main__":
    main()


================================================
FILE: sgdml/intf/__init__.py
================================================


================================================
FILE: sgdml/intf/ase_calc.py
================================================
# MIT License
#
# Copyright (c) 2018-2020 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import logging
import numpy as np

try:
    from ase.calculators.calculator import Calculator
    from ase.units import kcal, mol
except ImportError:
    raise ImportError(
        'Optional ASE dependency not found! Please run \'pip install sgdml[ase]\' to install it.'
    )

from ..predict import GDMLPredict


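# Illustrative sketch (hypothetical helper, not part of the calculator API):
# how the conversion factors used by SGDMLCalculator relate to each other.
# `E_to_eV` is the size of the model's energy unit in eV and
# `len_unit_in_Ang` the size of the model's length unit in Angstrom (both
# assumed inputs). For a model trained in Hartree and Bohr these would be
# ase.units.Hartree and ase.units.Bohr.
def _unit_factors_sketch(E_to_eV, len_unit_in_Ang):

    # The model's force unit is (energy unit) / (length unit).
    F_to_eV_Ang = E_to_eV / len_unit_in_Ang

    # Same ratio as `self.Ang_to_R` in SGDMLCalculator.__init__,
    # i.e. the number of model length units per Angstrom.
    Ang_to_R = F_to_eV_Ang / E_to_eV

    return F_to_eV_Ang, Ang_to_R

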
class SGDMLCalculator(Calculator):

    implemented_properties = ['energy', 'forces']

    def __init__(
        self,
        model_path,
        E_to_eV=kcal / mol,
        F_to_eV_Ang=kcal / mol,
        use_torch=False,
        *args,
        **kwargs
    ):
        """
        ASE calculator for the sGDML force field.

        A calculator takes atomic numbers and atomic positions from an Atoms object and calculates the energy and forces.

        Note
        ----
        ASE uses eV and Angstrom as energy and length unit, respectively. Unless the parameters `E_to_eV` and `F_to_eV_Ang` are specified, the sGDML model is assumed to use kcal/mol and Angstrom, and the conversion factors are set accordingly.
        The available unit constants are listed in the `ASE units <https://wiki.fysik.dtu.dk/ase/ase/units.html>`_ documentation.

        Parameters
        ----------
                model_path : :obj:`str`
                        Path to a sGDML model file
                E_to_eV : float, optional
                        Conversion factor from whatever energy unit is used by the model to eV. By default this parameter is set to convert from kcal/mol.
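                F_to_eV_Ang : float, optional
                        Conversion factor from whatever force unit is used by the model to eV/Angstrom. By default this parameter is set to convert from kcal/mol/Angstrom. The conversion factor for positions (Angstrom to the model length unit) is derived as `F_to_eV_Ang / E_to_eV`.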
                use_torch : boolean, optional
                        Use PyTorch to calculate predictions
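
        Examples
        --------
        A minimal usage sketch. The file names 'geometry.xyz' and 'model.npz'
        are hypothetical placeholders; the model must have been trained on the
        same molecule, with the atoms in the same order as in the `Atoms`
        object.

        >>> from ase.io import read
        >>> from sgdml.intf.ase_calc import SGDMLCalculator
        >>> atoms = read('geometry.xyz')
        >>> atoms.calc = SGDMLCalculator('model.npz')
        >>> e = atoms.get_potential_energy()  # eV
        >>> f = atoms.get_forces()  # eV/Ang, one row per atom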
        """

        super(SGDMLCalculator, self).__init__(*args, **kwargs)

        self.log = logging.getLogger(__name__)

        model = np.load(model_path, allow_pickle=True)
        self.gdml_predict = GDMLPredict(model, use_torch=use_torch)
        self.gdml_predict.prepare_parallel(n_bulk=1)

        self.log.warning(
            'Please remember to specify the proper conversion factors if your model does not use \'kcal/mol\' and \'Ang\' as units.'
        )

        # Converts energy from the unit used by the sGDML model to eV.
        self.E_to_eV = E_to_eV

        # Converts length from Angstrom (ASE) to the unit used by the sGDML model.
        self.Ang_to_R = F_to_eV_Ang / E_to_eV

        # Converts force from the unit used by the sGDML model to eV/Ang.
        self.F_to_eV_Ang = F_to_eV_Ang
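
        # Worked example for these factors (a sketch, assuming a hypothetical
        # model file 'model.npz' trained with energies in Hartree and forces
        # in Hartree/Bohr):
        #
        #     from ase.units import Hartree, Bohr
        #     calc = SGDMLCalculator('model.npz', E_to_eV=Hartree,
        #                            F_to_eV_Ang=Hartree / Bohr)
        #
        # Then Ang_to_R = (Hartree / Bohr) / Hartree = 1 / Bohr, i.e. positions
        # in Angstrom are divided by the Bohr radius (~0.529 Ang) to get Bohr.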

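    def calculate(self, atoms=None, *args, **kwargs):
        """
        Predict energy and forces for `atoms` and store them in `self.results`
        (in eV and eV/Angstrom, respectively).
        """

        super(SGDMLCalculator, self).calculate(atoms, *args, **kwargs)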

        # convert ASE positions (Angstrom) to the length unit used by the model
        r = np.array(atoms.get_positions()) * self.Ang_to_R

        e, f = self.gdml_predict.predict(r.ravel())

        # convert model units to ASE default units (eV and Ang)
        e *= self.E_to_eV
        f *= self.F_to_eV_Ang

        self.results = {'energy': e, 'forces': f.reshape(-1, 3)}


================================================
FILE: sgdml/predict.py
================================================
"""
This module contains all routines for evaluating GDML and sGDML models.
"""

# MIT License
#
# Copyright (c) 2018-2022 Stefan Chmiela, Gregory Fonseca
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

from __future__ import print_function

import sys
import logging
import os
import psutil

import multiprocessing as mp

Pool = mp.get_context('fork').Pool

import timeit
from functools import partial

try:
    import torch
except ImportError:
    _has_torch = False
else:
    _has_torch = True

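try:
    _torch_mps_is_available = torch.backends.mps.is_available()
except (NameError, AttributeError):  # NameError: torch is not installed
    _torch_mps_is_available = False
_torch_mps_is_available = False  # MPS backend is disabled (overrides the check above)

try:
    _torch_cuda_is_available = torch.cuda.is_available()
except (NameError, AttributeError):  # NameError: torch is not installed
    _torch_cuda_is_available = False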

import numpy as np

from . import __version__
from .utils.desc import Desc


def share_array(arr_np):
    """
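    Copy a NumPy array of type `float` into a ctypes array allocated from
    shared memory.

    Parameters
    ----------
            arr_np : :obj:`numpy.ndarray`
                    NumPy array.

    Returns
    -------
            array of :obj:`ctypes.c_double`
                    Flattened copy of the data in shared memory.
            tuple of int
                    Shape of the original array, needed to restore it via
                    `np.frombuffer(arr).reshape(shape)`.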
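
    Examples
    --------
    Round trip through shared memory, the same way the prediction workers
    read these arrays back with `np.frombuffer`:

    >>> import numpy as np
    >>> arr, shape = share_array(np.arange(6, dtype=float).reshape(2, 3))
    >>> np.frombuffer(arr).reshape(shape)
    array([[0., 1., 2.],
           [3., 4., 5.]])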
    """

    arr = mp.RawArray('d', arr_np.ravel())
    return arr, arr_np.shape


def _predict_wkr(
    r, r_desc_d_desc, lat_and_inv, glob_id, wkr_start_stop=None, chunk_size=None
):
    """
    Compute (part of) a prediction.

    Every prediction is a linear combination involving the training points used for
    this model. This function evaluates that combination for the range specified by
    `wkr_start_stop`. This workload can optionally be processed in chunks,
    which can be faster as it requires less memory to be allocated.

    Note
    ----
        It is sufficient to provide either the parameter `r` or `r_desc_d_desc`.
        The other one can be set to `None`.

    Parameters
    ----------
            r : :obj:`numpy.ndarray`
                    An array of size 3N containing the Cartesian
                    coordinates of each atom in the molecule.
            r_desc_d_desc : tuple of :obj:`numpy.ndarray`
                    A tuple made up of:
                        (1) An array of size D containing the descriptors
                        of dimension D for the molecule.
                        (2) An array of size D x 3N containing the
                        descriptor Jacobian for the molecule. It has dimension
                        D with 3N partial derivatives with respect to the 3N
                        Cartesian coordinates of each atom.
            lat_and_inv : tuple of :obj:`numpy.ndarray`
                    Tuple of a 3 x 3 matrix containing the lattice vectors
                    as columns and its inverse (`None` for non-periodic
                    systems).
            glob_id : int
                    Index of the global namespace (entry in `globs`) that
                    this function uses. It is zero if only a single
                    `GDMLPredict` instance has been created.
            wkr_start_stop : tuple of int, optional
                    Range defined by the indices of first and last (exclusive)
                    sum element. The full prediction is generated if this parameter
                    is not specified.
            chunk_size : int, optional
                    Chunk size. The whole linear combination is evaluated in a large
                    vector operation instead of looping over smaller chunks if this
                    parameter is left unspecified.

    Returns
    -------
            :obj:`numpy.ndarray`
                    Partial prediction: the energy (first element)
                    followed by all force components.
    """

    global globs
    glob = globs[glob_id]
    sig, n_perms = glob['sig'], glob['n_perms']

    desc_func = glob['desc_func']

    R_desc_perms = np.frombuffer(glob['R_desc_perms']).reshape(
        glob['R_desc_perms_shape']
    )
    R_d_desc_alpha_perms = np.frombuffer(glob['R_d_desc_alpha_perms']).reshape(
        glob['R_d_desc_alpha_perms_shape']
    )

    if 'alphas_E_lin' in glob:
        alphas_E_lin = np.frombuffer(glob['alphas_E_lin']).reshape(
            glob['alphas_E_lin_shape']
        )

    r_desc, r_d_desc = r_desc_d_desc or desc_func.from_R(
        r, lat_and_inv, max_processes=1
    )  # no additional forking during parallelization

    n_train = int(R_desc_perms.shape[0] / n_perms)

    wkr_start, wkr_stop = (0, n_train) if wkr_start_stop is None else wkr_start_stop
    if chunk_size is None:
        chunk_size = n_train

    dim_d = desc_func.dim
    dim_i = desc_func.dim_i
    dim_c = chunk_size * n_perms

    # Pre-allocate memory.
    diff_ab_perms = np.empty((dim_c, dim_d))
    a_x2 = np.empty((dim_c,))
    mat52_base = np.empty((dim_c,))

    # Precompute reciprocals and constants to avoid (slower) divisions in the loop.
    sig_inv = 1.0 / sig
    mat52_base_fact = 5.0 / (3 * sig**3)
    diag_scale_fact = 5.0 / sig
    sqrt5 = np.sqrt(5.0)
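
    # These constants implement the Matern kernel with nu = 5/2. With
    # d = ||x - x'|| and n = sqrt(5) * d (`norm_ab_perms` below), the kernel is
    #
    #     k(d) = (1 + n / sig + n**2 / (3 * sig**2)) * exp(-n / sig)
    #          = (1 + sqrt(5) * d / sig + 5 * d**2 / (3 * sig**2)) * exp(-sqrt(5) * d / sig),
    #
    # which is the energy-energy term `K_ee` evaluated further down. The
    # force terms use (5 / (3 * sig**3)) * (n + sig) * exp(-n / sig), i.e.
    # `mat52_base` after it is multiplied by `norm_ab_perms + sig`.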

    E_F = np.zeros((dim_d + 1,))
    F = E_F[1:]

    wkr_start *= n_perms
    wkr_stop *= n_perms

    b_start = wkr_start
    for b_stop in list(range(wkr_start + dim_c, wkr_stop, dim_c)) + [wkr_stop]:

        rj_desc_perms = R_desc_perms[b_start:b_stop, :]
        rj_d_desc_alpha_perms = R_d_desc_alpha_perms[b_start:b_stop, :]

        # Resize pre-allocated memory for last iteration, if chunk_size is not a divisor of the training set size.
        # Note: It's faster to process equally sized chunks.
        c_size = b_stop - b_start
        if c_size < dim_c:
            diff_ab_perms = diff_ab_perms[:c_size, :]
            a_x2 = a_x2[:c_size]
            mat52_base = mat52_base[:c_size]

        np.subtract(
            np.broadcast_to(r_desc, rj_desc_perms.shape),
            rj_desc_perms,
            out=diff_ab_perms,
        )
        norm_ab_perms = sqrt5 * np.linalg.norm(diff_ab_perms, axis=1)

        np.exp(-norm_ab_perms * sig_inv, out=mat52_base)
        mat52_base *= mat52_base_fact
        np.einsum(
            'ji,ji->j', diff_ab_perms, rj_d_desc_alpha_perms, out=a_x2
        )  # row-wise dot product

        F += (a_x2 * mat52_base).dot(diff_ab_perms) * diag_scale_fact
        mat52_base *= norm_ab_perms + sig
        F -= mat52_base.dot(rj_d_desc_alpha_perms)

        # Note: Energies are automatically predicted with a flipped sign here (because -E are trained, instead of E)
        E_F[0] += a_x2.dot(mat52_base)

        # Additional contributions from the energy coefficients (only present if energy constraints were used in training).
        if 'alphas_E_lin' in glob:

            K_fe = diff_ab_perms * mat52_base[:, None]
            F += alphas_E_lin[b_start:b_stop].dot(K_fe)

            K_ee = (
                1 + (norm_ab_perms * sig_inv) * (1 + norm_ab_perms / (3 * sig))
            ) * np.exp(-norm_ab_perms * sig_inv)

            E_F[0] += K_ee.dot(alphas_E_lin[b_start:b_stop])

        b_start = b_stop

    out = E_F[: dim_i + 1]

    # The descriptor has fewer entries than 3N, so a larger output array than 'E_F' is needed.
    if dim_d < dim_i:
        out = np.empty((dim_i + 1,))
        out[0] = E_F[0]

    out[1:] = desc_func.vec_dot_d_desc(
        r_d_desc,
        F,
    )  # 'r_d_desc.T.dot(F)' for our special representation of 'r_d_desc'

    return out


class GDMLPredict(object):
    def __init__(
        self,
        model,
        batch_size=None,
        num_workers=None,
        max_memory=None,
        max_processes=None,
        use_torch=False,
        log_level=None,
    ):
        """
        Query trained sGDML force fields.

        This class is used to load a trained model and make energy and
        force predictions for new geometries. GPU support is provided
        through PyTorch (requires optional `torch` dependency to be
        installed).

        Note
        ----
                The parameters `batch_size` and `num_workers` are only
                relevant if this code runs on a CPU. Both can be set
                automatically via the function `prepare_parallel`.
                Note: Running calculations via PyTorch is only
                recommended with available GPU hardware. CPU calculations
                are faster with our NumPy implementation.

        Parameters
        ----------
                model : :obj:`dict`
                        Data structure that holds all parameters of the
                        trained model. This object is the output of
                        `GDMLTrain.train`
                batch_size : int, optional
                        Chunk size for processing parallel tasks
                num_workers : int, optional
                        Number of parallel workers (in addition to the main
                        process)
                max_memory : int, optional
                        Limit the max. memory usage [GB]. This is only a
                        soft limit that can not always be enforced.
                max_processes : int, optional
                        Limit the max. number of processes. Otherwise
                        all CPU cores are used. This parameters has no
                        effect if `use_torch=True`
                use_torch : boolean, optional
                        Use PyTorch to calculate predictions
                log_level : optional
                        Set custom logging level (e.g. `logging.CRITICAL`)
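
        Examples
        --------
        A minimal sketch, assuming 'model.npz' is a hypothetical model file
        produced by the sGDML training routines. Geometries are passed as
        flattened Cartesian coordinates (one row of length 3N per geometry)
        in the length unit the model was trained with; the random geometry
        below is only a placeholder.

        >>> import numpy as np
        >>> from sgdml.predict import GDMLPredict
        >>> model = np.load('model.npz', allow_pickle=True)
        >>> gdml = GDMLPredict(model)
        >>> gdml.prepare_parallel(n_bulk=1)
        >>> R = np.random.rand(1, 3 * gdml.n_atoms)  # placeholder geometry
        >>> E, F = gdml.predict(R)  # energies and (flattened) forces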
        """

        global globs
        if 'globs' not in globals():
            globs = []

        # Create a personal global space for this model at a new index
        # Note: do not delete entries from this list, since 'self.glob_id' is
        # static. Instead, setting them to None preserves positions while still
        # freeing up memory.
        globs.append({})
        self.glob_id = len(globs) - 1
        glob = globs[self.glob_id]

        self.log = logging.getLogger(__name__)
        if log_level is not None:
            self.log.setLevel(log_level)

        total_memory = psutil.virtual_memory().total // 2**30  # bytes to GB
        self.max_memory = (
            min(max_memory, total_memory) if max_memory is not None else total_memory
        )

        total_cpus = mp.cpu_count()
        self.max_processes = (
            min(max_processes, total_cpus) if max_processes is not None else total_cpus
        )

        if 'type' not in model or not (model['type'] == 'm' or model['type'] == b'm'):
            self.log.critical('The provided data structure is not a valid model.')
            sys.exit()

        self.n_atoms = model['z'].shape[0]

        self.desc = Desc(self.n_atoms, max_processes=max_processes)
        glob['desc_func'] = self.desc

        # Cache for iterative training mode.
        self.R_desc = None
        self.R_d_desc = None

        self.lat_and_inv = (
            (model['lattice'], np.linalg.inv(model['lattice']))
            if 'lattice' in model
            else None
        )

        self.n_train = model['R_desc'].shape[1]
        glob['sig'] = model['sig']

        self.std = model['std'] if 'std' in model else 1.0
        self.c = model['c']

        n_perms = model['perms'].shape[0]
        glob['n_perms'] = n_perms

        self.tril_perms_lin = model['tril_perms_lin']

        self.torch_predict = None
        self.use_torch = use_torch
        if use_torch:

            if not _has_torch:
                raise ImportError(
                    'Optional PyTorch dependency not found! Please run \'pip install sgdml[torch]\' to install it or disable the PyTorch option.'
                )

            from .torchtools import GDMLTorchPredict

            self.torch_predict = GDMLTorchPredict(
                model,
                self.lat_and_inv,
                max_memory=max_memory,
                max_processes=max_processes,
                log_level=self.log.level,
            )

            # Enable data parallelism
            n_gpu = torch.cuda.device_count()
            if n_gpu > 1:
                self.torch_predict = torch.nn.DataParallel(self.torch_predict)

            # Send model to device
            # self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
            if _torch_cuda_is_available:
                self.torch_device = 'cuda'
            elif _torch_mps_is_available:
                self.torch_device = 'mps'
            else:
                self.torch_device = 'cpu'

            while True:
                try:
                    self.torch_predict.to(self.torch_device)
                except RuntimeError as e:
                    if 'out of memory' in str(e):

                        if _torch_cuda_is_available:
                            torch.cuda.empty_cache()

                        model = self.torch_predict
                        if isinstance(self.torch_predict, torch.nn.DataParallel):
                            model = model.module

                        if (
                            model.get_n_perm_batches() == 1
                        ):  # model caches the permutations, this could be why it is too large
                            model.set_n_perm_batches(
                                model.get_n_perm_batches() + 1
                            )  # uncache
                            # self.torch_predict.to( # NOTE!
                            #    self.torch_device
                            # )  # try sending to device again
                            pass
                        else:
                            self.log.critical(
                                'Not enough memory on device (RAM or GPU memory). There is no hope!'
                            )
                            print()
                            os._exit(1)
                    else:
                        raise e
                else:
                    break
        else:

            # Precompute permuted training descriptors and its first derivatives multiplied with the coefficients.

            R_desc_perms = (
                np.tile(model['R_desc'].T, n_perms)[:, self.tril_perms_lin]
                .reshape(self.n_train, n_perms, -1, order='F')
                .reshape(self.n_train * n_perms, -1)
            )
            glob['R_desc_perms'], glob['R_desc_perms_shape'] = share_array(R_desc_perms)

            R_d_desc_alpha_perms = (
                np.tile(model['R_d_desc_alpha'], n_perms)[:, self.tril_perms_lin]
                .reshape(self.n_train, n_perms, -1, order='F')
                .reshape(self.n_train * n_perms, -1)
            )
            (
                glob['R_d_desc_alpha_perms'],
                glob['R_d_desc_alpha_perms_shape'],
            ) = share_array(R_d_desc_alpha_perms)

            if 'alphas_E' in model:
                alphas_E_lin = np.tile(model['alphas_E'][:, None], (1, n_perms)).ravel()
                glob['alphas_E_lin'], glob['alphas_E_lin_shape'] = share_array(
                    alphas_E_lin
                )

            # Parallel processing configuration

            self.bulk_mp = False  # Bulk predictions with multiple processes?

            self.pool = None

            # How many workers in addition to main process?
            num_workers = num_workers or (
                self.max_processes - 1
            )  # exclude main process
            self._set_num_workers(num_workers, force_reset=True)

            # Size of chunks in which each parallel task will be processed (unit: number of training samples)
            # This parameter should be as large as possible, but it depends on the size of available memory.
            self._set_chunk_size(batch_size)

    def __del__(self):

        global globs

        try:
            self.pool.terminate()
            self.pool.join()
            self.pool = None
        except:
            pass

        if 'globs' in globals() and globs is not None and self.glob_id < len(globs):
            globs[self.glob_id] = None

    ## Public ##

    # def set_R(self, R):
    #     """
    #     Store a reference to the training geometries.
    #     This function is used to avoid unnecessary copies of the
    #     traininig geometries when evaluation the training error
    #     (= gradient of the model's loss function).

    #     This routine is used during iterative model training.

    #     Parameters
    #     ----------
    #     R : :obj:`numpy.ndarray`
    #         Array containing the geometry for each training point.
    #     """

    #     # Add singleton dimension if input is (,3N).
    #     if R.ndim == 1:
    #         R = R[None, :]

    #     self.R = R

    #     # if self.use_torch:
    #     #     model = self.torch_predict
    #     #     if isinstance(self.torch_predict, torch.nn.DataParallel):
    #     #         model = model.module

    #     #     R_torch = torch.from_numpy(R.reshape(-1, self.n_atoms, 3)).to(self.torch_device)
    #     #     model.set_R(R_torch)

    def set_R_desc(self, R_desc):
        """
        Store a reference to the training geometry descriptors.

        This can accelerate iterative model training.

        Parameters
        ----------
            R_desc : :obj:`numpy.ndarray`
                    A 2D array of size M x D containing the
                    descriptors of dimension D for M
                    molecules.
        """

        self.R_desc = R_desc

    def set_R_d_desc(self, R_d_desc):
        """
        Store a reference to the training geometry descriptor Jacobians.
        This function must be called before `set_alphas()` can be used.

        This routine is used during iterative model training.

        Parameters
        ----------
            R_d_desc : :obj:`numpy.ndarray`
                    An array of size M x D x 3N containing the
                    descriptor Jacobians for M molecules. The descriptor
                    has dimension D with 3N partial derivatives with
                    respect to the 3N Cartesian coordinates of each atom.
        """

        self.R_d_desc = R_d_desc

        if self.use_torch:
            model = self.torch_predict
            if isinstance(self.torch_predict, torch.nn.DataParallel):
                model = model.module

            model.set_R_d_desc(R_d_desc)

    def set_alphas(self, alphas_F, alphas_E=None):
        """
        Reconfigure the current model with a new set of regression parameters.
        `R_d_desc` needs to be set for this function to work.

        This routine is used during iterative model training.

        Parameters
        ----------
                alphas_F : :obj:`numpy.ndarray`
                    1D array containing the new model parameters.
                alphas_E : :obj:`numpy.ndarray`, optional
                    1D array containing the additional new model parameters, if
                    energy constraints are used in the kernel (`use_E_cstr=True`)
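
        Examples
        --------
        The energy coefficients are repeated once per permutation, matching
        the row order of the permuted training descriptors (training point
        major, permutation minor). For two coefficients and three
        permutations:

        >>> import numpy as np
        >>> np.tile(np.array([1.0, 2.0])[:, None], (1, 3)).ravel()
        array([1., 1., 1., 2., 2., 2.])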
        """

        if self.use_torch:

            model = self.torch_predict
            if isinstance(self.torch_predict, torch.nn.DataParallel):
                model = model.module

            model.set_alphas(alphas_F, alphas_E=alphas_E)

        else:

            assert self.R_d_desc is not None

            global globs
            glob = globs[self.glob_id]

            dim_i = self.desc.dim_i
            R_d_desc_alpha = self.desc.d_desc_dot_vec(
                self.R_d_desc, alphas_F.reshape(-1, dim_i)
            )

            R_d_desc_alpha_perms_new = np.tile(R_d_desc_alpha, glob['n_perms'])[
                :, self.tril_perms_lin
            ].reshape(self.n_train, glob['n_perms'], -1, order='F')

            R_d_desc_alpha_perms = np.frombuffer(glob['R_d_desc_alpha_perms'])
            np.copyto(R_d_desc_alpha_perms, R_d_desc_alpha_perms_new.ravel())

            if alphas_E is not None:

                alphas_E_lin_new = np.tile(
                    alphas_E[:, None], (1, glob['n_perms'])
                ).ravel()

                alphas_E_lin = np.frombuffer(glob['alphas_E_lin'])
                np.copyto(alphas_E_lin, alphas_E_lin_new)

    def _set_num_workers(
        self, num_workers=None, force_reset=False
    ):  # TODO: complain if chunk or worker parameters do not fit training data (this causes issues with the caching)!!
        """
        Set number of processes to use during prediction.

        If `bulk_mp == True`, each worker generates the complete prediction for
        a single geometry (this is for querying multiple geometries at once).
        If `bulk_mp == False`, each worker may handle only part of a prediction
        (the ranges are defined in 'wkr_starts_stops'). In that scenario, multiple
        processes share the work of generating a single prediction.

        This number should not exceed the number of available CPU cores.

        Note
        ----
                This parameter can be optimally determined using
                `prepare_parallel`.

        Parameters
        ----------
                num_workers : int, optional
                    Number of processes (maximum value is set if `None`).
                force_reset : bool, optional
                    Force applying the new setting.
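
        Examples
        --------
        How the training points are split across workers when bulk mode is
        off, mirroring the 'wkr_starts_stops' computation below for
        `n_train = 10` and `num_workers = 3`:

        >>> import numpy as np
        >>> n_train, num_workers = 10, 3
        >>> step = int(np.ceil(float(n_train) / num_workers))
        >>> starts = list(range(0, n_train, step))
        >>> list(zip(starts, starts[1:] + [n_train]))
        [(0, 4), (4, 8), (8, 10)]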
        """

        if force_reset or self.num_workers != num_workers:

            if self.pool is not None:
                self.pool.terminate()
                self.pool.join()
                self.pool = None

            self.num_workers = 0
            if num_workers is None or num_workers > 0:
                self.pool = Pool(num_workers)
                self.num_workers = (
                    self.pool._processes
                )  # number of actual workers (not max_processes)

        # Data ranges for processes
        if self.bulk_mp or self.num_workers < 2:
            # wkr_starts = [self.n_train]
            wkr_starts = [0]
        else:
            wkr_starts = list(
                range(
                    0,
                    self.n_train,
                    int(np.ceil(float(self.n_train) / self.num_workers)),
                )
            )
        wkr_stops = wkr_starts[1:] + [self.n_train]

        self.wkr_starts_stops = list(zip(wkr_starts, wkr_stops))

    def _set_chunk_size(self, chunk_size=None):

        # TODO: complain if chunk or worker parameters do not fit training data (this causes issues with the caching)!!
        """
        Set chunk size for each worker process.

        Every prediction is generated as a linear combination of the training
        points that make up the model. If multiple workers are available
        (and bulk mode is disabled), each one processes an (approximately equal)
        part of those training points. The chunk size then determines how much of
        a process's workload is passed to NumPy's underlying low-level routines at
        once. If the chunk size is smaller than the number of points the worker is
        supposed to process, it processes them in multiple steps using a loop. This
        can sometimes be faster, depending on the available hardware.

        Note
        ----
                This parameter can be optimally determined using
                `prepare_parallel`.

        Parameters
        ----------
                chunk_size : int, optional
                        Chunk size (maximum value is set if `None`).
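
        Examples
        --------
        Chunk boundaries produced by the loop in `_predict_wkr` for a worker
        range of 10 rows and a chunk of 4 rows (one permutation per training
        point assumed); the last chunk is smaller:

        >>> wkr_start, wkr_stop, dim_c = 0, 10, 4
        >>> list(range(wkr_start + dim_c, wkr_stop, dim_c)) + [wkr_stop]
        [4, 8, 10]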
        """

        if chunk_size is None:
            chunk_size = self.n_train

        self.chunk_size = chunk_size

    def _set_batch_size(self, batch_size=None):  # deprecated
        """

        Warning
        -------
        Deprecated! Please use the function `_set_chunk_size` in future projects.

        Set chunk size for each worker process. A chunk is a subset
        of the training data points whose linear combination needs to
        be evaluated in order to generate a prediction.

        The chunk size determines how much of a process's workload will
        be passed to NumPy's underlying low-level routines at once.
        This parameter is highly hardware dependent.

        Note
        ----
                This parameter can be optimally determined using
                `prepare_parallel`.

        Parameters
        ----------
                batch_size : int
                        Chunk size (maximum value is set if `None`).
        """

        self._set_chunk_size(batch_size)

    def _set_bulk_mp(self, bulk_mp=False):
        """
        Toggles bulk prediction mode.

        If bulk prediction is enabled, the prediction is parallelized across
        input geometries, i.e. each worker generates the complete prediction for
        one query. Otherwise (depending on the number of available CPU cores) the
        input geometries are processed sequentially, but every one of them may be
        processed by multiple workers at once (in chunks).

        Note
        ----
                This parameter can be optimally determined using
                `prepare_parallel`.

        Parameters
        ----------
                bulk_mp : bool, optional
                        Enable or disable bulk prediction mode.
        """

        bulk_mp = bool(bulk_mp)
        if self.bulk_mp is not bulk_mp:
            self.bulk_mp = bulk_mp

            # Reset data ranges for processes stored in 'wkr_starts_stops'
            self._set_num_workers(self.num_workers)

    def set_opt_num_workers_and_batch_size_fast(self, n_bulk=1, n_reps=1):  # deprecated
        """
        Warning
        -------
        Deprecated! Please use the function `prepare_parallel` in future projects.

        Parameters
        ----------
                n_bulk : int, optional
                        Number of geometries that will be passed to the
                        `predict` function in each call (performance
                        will be optimized for that exact use case).
                n_reps : int, optional
                        Number of repetitions (bigger value: more
                        accurate, but also slower).

        Returns
        -------
                int
                        Force and energy prediction speed in geometries
                        per second.
        """

        return self.prepare_parallel(n_bulk, n_reps)

    def prepare_parallel(
        self, n_bulk=1, n_reps=1, return_is_from_cache=False
    ):  # noqa: C901
        """
        Find and set the optimal parallelization parameters for the
        currently loaded model, running on a particular system. The result
        also depends on the number of geometries `n_bulk` that will be
        passed at once when calling the `predict` function.

        This function runs a benchmark in which the prediction routine is
        repeatedly called `n_reps`-times (default: 1) with varying parameter
        configurations, while the runtime is measured for each one. The
        optimal parameters are then cached for fast retrieval in future
        calls of this function.

        We recommend calling this function after initialization of this
        class, as it can substantially improve the performance of the
        `predict` function.

        Note
        ----
                Depending on the parameter `n_reps`, this routine may take
                several seconds to minutes to complete. However, once a
                statistically significant number of benchmark results has
                been gathered for a particular configuration, it starts
                returning almost instantly.

        Parameters
        ----------
                n_bulk : int, optional
                        Number of geometries that will be passed to the
                        `predict` function in each call (performance
                        will be optimized for that exact use case).
                n_reps : int, optional
                        Number of repetitions (bigger value: more
                        accurate, but also slower).
                return_is_from_cache : bool, optional
                        If enabled, this function returns a second value
                        indicating if the returned results were obtained
                        from cache.

        Returns
        -------
                int
                        Force and energy prediction speed in geometries
                        per second.
                boolean, optional
                        Return, whether this function obtained the results
                        from cache.
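
        Examples
        --------
        Tune the parallelization for batches of 10 geometries before
        predicting in batches of that size. Here `gdml` is assumed to be an
        existing `GDMLPredict` instance created without `use_torch`; with
        PyTorch enabled, the benchmark is skipped and `None` is returned.

        >>> gps = gdml.prepare_parallel(n_bulk=10)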
        """

        # global globs
        # glob = globs[self.glob_id]
        # n_perms = glob['n_perms']

        # No benchmarking necessary if prediction is running on GPUs.
        if self.use_torch:
            self.log.info(
                'Skipping multi-CPU benchmark, since torch is enabled.'
            )  # TODO: clarity!
            return

        # Retrieve cached benchmark results, if available.
        bmark_result = self._load_cached_bmark_result(n_bulk)
        if bmark_result is not None:

            num_workers, chunk_size, bulk_mp, gps = bmark_result

            self._set_chunk_size(chunk_size)
            self._set_num_workers(num_workers)
            self._set_bulk_mp(bulk_mp)

            if return_is_from_cache:
                is_from_cache = True
                return gps, is_from_cache
            else:
                return gps

        warm_up_done = False

        best_results = []
        last_i = None

        best_gps = 0
        gps_min = 0.0

        best_params = None

        r_dummy = np.random.rand(n_bulk, self.n_atoms * 3)

        def _dummy_predict():
            self.predict(r_dummy)

        bulk_mp_rng = [True, False] if n_bulk > 1 else [False]
        for bulk_mp in bulk_mp_rng:
            self._set_bulk_mp(bulk_mp)

            if bulk_mp is False:
                last_i = 0

            num_workers_rng = list(range(0, self.max_processes))
            if bulk_mp:
                num_workers_rng.reverse()  # benchmark converges faster this way

            # num_workers_rng_sizes = [batch_size for batch_size in batch_size_rng if min_batch_size % batch_size == 0]

            # for num_workers in range(min_num_workers,self.max_processes+1):
            for num_workers in num_workers_rng:
                if not bulk_mp and num_workers != 0 and self.n_train % num_workers != 0:
                    continue

                self._set_num_workers(num_workers)

                best_gps = 0
                gps_rng = (np.inf, 0.0)  # min and max per num_workers

                min_chunk_size = (
                    min(self.n_train, n_bulk)
                    if bulk_mp or num_workers < 2
                    else int(np.ceil(self.n_train / num_workers))
                )
                chunk_size_rng = list(range(min_chunk_size, 0, -1))

                chunk_size_rng_sizes = [
                    chunk_size
                    for chunk_size in chunk_size_rng
                    if min_chunk_size % chunk_size == 0
                ]

                # print('batch_size_rng_sizes ' + str(bulk_mp))
                # print(batch_size_rng_sizes)

                i_done = 0
                i_dir = 1
                i = 0 if last_i is None else last_i
                # i = 0

                # print(batch_size_rng_sizes)
                while i >= 0 and i < len(chunk_size_rng_sizes):

                    chunk_size = chunk_size_rng_sizes[i]
                    self._set_chunk_size(chunk_size)

                    i_done += 1

                    if not warm_up_done:
                        timeit.timeit(_dummy_predict, number=10)
                        warm_up_done = True

                    gps = n_bulk * n_reps / timeit.timeit(_dummy_predict, number=n_reps)

                    # print(
                    #  '{:2d}@{:d} {:d} | {:7.2f} gps'.format(
                    #      num_workers, chunk_size, bulk_mp, gps
                    #  )
                    # )

                    gps_rng = (
                        min(gps_rng[0], gps),
                        max(gps_rng[1], gps),
                    )  # min and max per num_workers

                    # gps_min_max = min(gps_min_max[0], gps), max(gps_min_max[1], gps)

                    # print('     best_gps ' + str(best_gps))

                    # NEW

                    # if gps > best_gps and gps > gps_min: # gps is still going up, everything is good
                    #     best_gps = gps
                    #     best_params = num_workers, batch_size, bulk_mp
                    # else:
                    #     break

                    # if gps > best_gps: # gps is still going up, everything is good
                    #     best_gps = gps
                    #     best_params = num_workers, batch_size, bulk_mp
                    # else: # gps did not go up wrt. to previous step

                    #     # can we switch the search direction?
                    #     #   did we already?
                    #     #   we checked two consecutive configurations
                    #     #   are bigger batch sizes possible?

                    #     print(batch_size_rng_sizes)

                    #     turn_search_dir = i_dir > 0 and i_done == 2 and batch_size != batch_size_rng_sizes[1]

                    #     # only turn, if the current gps is not lower than the lowest overall
                    #     if turn_search_dir and gps >= gps_min:
                    #         i -= 2 * i_dir
                    #         i_dir = -1
                    #         print('><')
                    #         continue
                    #     else:
                    #         print('>>break ' + str(i_done))
                    #         break

                    # NEW

                    # gps still going up?
                    # AND: gps not lower than the lowest overall?
                    # if gps < best_gps and gps >= gps_min:
                    if gps < best_gps:
                        if (
                            i_dir > 0
                            and i_done == 2
                            and chunk_size
                            != chunk_size_rng_sizes[
                                1
                            ]  # there is no point in turning if this is the second batch size in the range
                        ):  # do we turn?
                            i -= 2 * i_dir
                            i_dir = -1
                            # print('><')
                            continue
                        else:
                            if chunk_size == chunk_size_rng_sizes[1]:
                                i -= 1 * i_dir
                            # print('>>break ' + str(i_done))
                            break
                    else:
                        best_gps = gps
                        best_params = num_workers, chunk_size, bulk_mp

                    if (
                        not bulk_mp and n_bulk > 1
                    ):  # stop search early when multiple cpus are available and the 1 cpu case is tested
                        if (
                            gps < gps_min
                        ):  # if the batch size run is lower than the lowest overall, stop right here
                            # print('breaking here')
                            break

                    i += 1 * i_dir

                last_i = i - 1 * i_dir
                i_dir = 1

                if len(best_results) > 0:
                    overall_best_gps = max(best_results, key=lambda x: x[1])[1]
                    if best_gps < overall_best_gps:
                        # print('breaking, because best of last test was worse than overall best so far')
                        break

                    # if best_gps < gps_min:
                    #    print('breaking here3')
                    #    break

                gps_min = gps_rng[0]  # FIX me: is this the overall min?
                # print ('gps_min ' + str(gps_min))

                # print ('best_gps')
                # print (best_gps)

                best_results.append(
                    (best_params, best_gps)
                )  # best results per num_workers

        (num_workers, chunk_size, bulk_mp), gps = max(best_results, key=lambda x: x[1])

        # Cache benchmark results.
        self._save_cached_bmark_result(n_bulk, num_workers, chunk_size, bulk_mp, gps)

        self._set_chunk_size(chunk_size)
        self._set_num_workers(num_workers)
        self._set_bulk_mp(bulk_mp)

        if return_is_from_cache:
            is_from_cache = False
            return gps, is_from_cache
        else:
            return gps

    def _save_cached_bmark_result(self, n_bulk, num_workers, chunk_size, bulk_mp, gps):

        pkg_dir = os.path.dirname(os.path.abspath(__file__))
        bmark_file = '_bmark_cache.npz'
        bmark_path = os.path.join(pkg_dir, bmark_file)

        bkey = '{}-{}-{}-{}'.format(
            self.n_atoms, self.n_train, n_bulk, self.max_processes
        )

        if os.path.exists(bmark_path):

            with np.load(bmark_path, allow_pickle=True) as bmark:
                bmark = dict(bmark)

                bmark['runs'] = np.append(bmark['runs'], bkey)
                bmark['num_workers'] = np.append(bmark['num_workers'], num_workers)
                bmark['batch_size'] = np.append(bmark['batch_size'], chunk_size)
                bmark['bulk_mp'] = np.append(bmark['bulk_mp'], bulk_mp)
                bmark['gps'] = np.append(bmark['gps'], gps)
        else:
            bmark = {
                'code_version': __version__,
                'runs': [bkey],
                'gps': [gps],
                'num_workers': [num_workers],
                'batch_size': [chunk_size],
                'bulk_mp': [bulk_mp],
            }

        np.savez_compressed(bmark_path, **bmark)

    def _load_cached_bmark_result(self, n_bulk):

        pkg_dir = os.path.dirname(os.path.abspath(__file__))
        bmark_file = '_bmark_cache.npz'
        bmark_path = os.path.join(pkg_dir, bmark_file)

        bkey = '{}-{}-{}-{}'.format(
            self.n_atoms, self.n_train, n_bulk, self.max_processes
        )

        if not os.path.exists(bmark_path):
            return None

        with np.load(bmark_path, allow_pickle=True) as bmark:

            # Keep collecting benchmark runs until we have at least three.
            run_idxs = np.where(bmark['runs'] == bkey)[0]
            if len(run_idxs) >= 3:

                config_keys = []
                for run_idx in run_idxs:
                    config_keys.append(
                        '{}-{}-{}'.format(
                            bmark['num_workers'][run_idx],
                            bmark['batch_size'][run_idx],
                            bmark['bulk_mp'][run_idx],
                        )
                    )

                values, uinverse = np.unique(config_keys, return_index=True)

                best_mean = -1
                best_gps = 0
                for i, config_key in enumerate(zip(values, uinverse)):
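                    # Positions within 'config_keys' refer to 'run_idxs', so they
                    # have to be mapped back to global indices into bmark['gps'].
                    mean_gps = np.mean(
                        bmark['gps'][
                            run_idxs[np.where(np.array(config_keys) == config_key[0])[0]]
                        ]
                    )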

                    if best_gps == 0 or best_gps < mean_gps:
                        best_mean = i
                        best_gps = mean_gps

                best_idx = run_idxs[uinverse[best_mean]]
                num_workers = bmark['num_workers'][best_idx]
                chunk_size = bmark['batch_size'][best_idx]
                bulk_mp = bmark['bulk_mp'][best_idx]

                return num_workers, chunk_size, bulk_mp, best_gps

        return None

    def get_GPU_batch(self):
        """
        Get batch size used by the GPU implementation to process bulk
        predictions (predictions for multiple input geometries at once).

        This value is determined on-the-fly depending on the available GPU
        memory.
        """

        if self.use_torch:

            model = self.torch_predict
            if isinstance(model, torch.nn.DataParallel):
                model = model.module

            return model._batch_size()

    def predict(self, R=None, return_E=True):
        """
        Predict energy and forces for multiple geometries. This function
        can run on the GPU, if the optional PyTorch dependency is
        installed and `use_torch=True` was specified during
        initialization of this class.

        If the descriptors and descriptor Jacobians of the training
        geometries are already available from previous calculations,
        they can instead be set via `set_R_desc()` and `set_R_d_desc()`;
        this function can then be called without `R` to predict the
        training geometries.

        Note
        ----
                The order of the atoms in `R` is not arbitrary and must
                be the same as used for training the model.

        Parameters
        ----------
                R : :obj:`numpy.ndarray`, optional
                        A 2D array of size M x 3N containing the
                        Cartesian coordinates of each atom of M
                        molecules (a 1D array of size 3N is treated as
                        M = 1). If this parameter is omitted, predictions
                        for the training geometries are returned (e.g. to
                        compute the training error). Note that the training
                        geometry descriptors and Jacobians need to be set
                        right after initialization (using `set_R_desc()`
                        and `set_R_d_desc()`) for this to work.
                return_E : boolean, optional
                        If false (default: true), only the forces are returned.

        Returns
        -------
                :obj:`numpy.ndarray`
                        Energies stored in a 1D array of size M (unless `return_E == False`).
                :obj:`numpy.ndarray`
                        Forces stored in a 2D array of size M x 3N.

                The results are always returned as a tuple: `(E, F)`, or
                `(F,)` if `return_E == False`.
        """

        # Add singleton dimension if input is of shape (3N,).
        if R is not None and R.ndim == 1:
            R = R[None, :]

        if self.use_torch:  # multi-GPU (or CPU if no GPUs are available)

            R_torch = torch.arange(self.n_train)
            if R is None:
                if self.R_d_desc is None:
                    self.log.critical(
                        'A reference to the training geometry descriptors needs to be set (using \'set_R_d_desc()\') for this function to work without arguments (using PyTorch).'
                    )
                    print()
                    os._exit(1)
            else:
                R_torch = (
                    torch.from_numpy(R.reshape(-1, self.n_atoms, 3))
                    .type(torch.float32)
                    .to(self.torch_device)
                )

            model = self.torch_predict
            if R_torch.shape[0] < torch.cuda.device_count() and isinstance(
                model, torch.nn.DataParallel
            ):
                model = self.torch_predict.module
            E_torch_F_torch = model.forward(R_torch, return_E=return_E)

            if return_E:
                E_torch, F_torch = E_torch_F_torch
                E = E_torch.cpu().numpy()
            else:
                (F_torch,) = E_torch_F_torch

            F = F_torch.cpu().numpy().reshape(-1, 3 * self.n_atoms)

        else:  # multi-CPU

            # Use precomputed descriptors in training mode.
            is_desc_in_cache = self.R_desc is not None and self.R_d_desc is not None

            if R is None and not is_desc_in_cache:
                self.log.critical(
                    'A reference to the training geometry descriptors and Jacobians needs to be set for this function to work without arguments.'
                )
                print()
                os._exit(1)

            assert is_desc_in_cache or R is not None

            dim_i = 3 * self.n_atoms
            n_pred = self.R_desc.shape[0] if R is None else R.shape[0]

            E_F = np.empty((n_pred, dim_i + 1))

            if (
                self.bulk_mp and self.num_workers > 0
            ):  # One whole prediction per worker (and multiple workers).

                _predict_wo_r_or_desc = partial(
                    _predict_wkr,
                    lat_and_inv=self.lat_and_inv,
                    glob_id=self.glob_id,
                    wkr_start_stop=None,
                    chunk_size=self.chunk_size,
                )

                for i, e_f in enumerate(
                    self.pool.imap(
                        partial(_predict_wo_r_or_desc, None)
                        if is_desc_in_cache
                        else partial(_predict_wo_r_or_desc, r_desc_d_desc=None),
                        zip(self.R_desc, self.R_d_desc) if is_desc_in_cache else R,
                    )
                ):
                    E_F[i, :] = e_f

            else:  # Multiple workers per prediction (or just one worker).

                for i in range(n_pred):

                    if is_desc_in_cache:
                        r_desc, r_d_desc = self.R_desc[i], self.R_d_desc[i]
                    else:
                        r_desc, r_d_desc = self.desc.from_R(R[i], self.lat_and_inv)

                    _predict_wo_wkr_starts_stops = partial(
                        _predict_wkr,
                        None,
                        (r_desc, r_d_desc),
                        self.lat_and_inv,
                        self.glob_id,
                        chunk_size=self.chunk_size,
                    )

                    if self.num_workers == 0:
                        E_F[i, :] = _predict_wo_wkr_starts_stops()
                    else:
                        E_F[i, :] = sum(
                            self.pool.imap_unordered(
                                _predict_wo_wkr_starts_stops, self.wkr_starts_stops
                            )
                        )

            E_F *= self.std
            F = E_F[:, 1:]
            E = E_F[:, 0] + self.c

        ret = (F,)
        if return_E:
            ret = (E,) + ret

        return ret
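For a given worker count, `prepare_parallel` only tries chunk sizes that evenly divide an upper bound (named `min_chunk_size` in the code), walking them from largest to smallest. The following is a minimal sketch of how that candidate list is formed; `chunk_size_candidates` is a hypothetical helper written for illustration, not part of sGDML.

```python
import math


def chunk_size_candidates(n_train, n_bulk, num_workers, bulk_mp):
    """Hypothetical helper: chunk sizes tried by the benchmark for one setting."""
    # Upper bound on the chunk size (named 'min_chunk_size' in prepare_parallel).
    if bulk_mp or num_workers < 2:
        upper = min(n_train, n_bulk)
    else:
        # Training points are split evenly across the workers.
        upper = int(math.ceil(n_train / num_workers))
    # Only divisors of the upper bound are tried, largest first.
    return [c for c in range(upper, 0, -1) if upper % c == 0]


print(chunk_size_candidates(n_train=100, n_bulk=1, num_workers=4, bulk_mp=False))
# [25, 5, 1]
```

Restricting the search to divisors presumably keeps the number of timed configurations small while still covering a wide range of chunk sizes.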

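Cached benchmark results are only reused once at least three runs exist for the same key (`n_atoms-n_train-n_bulk-max_processes`); the setting with the highest mean throughput across those runs is then selected. Below is a minimal sketch of that rule on plain Python lists, assuming the cache arrays are aligned as written by `_save_cached_bmark_result`; `select_best_config` is a hypothetical helper. Note that positions within the matching runs must be mapped back to global indices before looking up `gps`, since the cache can also contain runs for other keys.

```python
import numpy as np


def select_best_config(runs, num_workers, chunk_size, bulk_mp, gps, bkey, min_runs=3):
    """Hypothetical helper: pick the cached setting with the highest mean gps."""
    run_idxs = np.where(np.asarray(runs) == bkey)[0]  # global indices of matching runs
    if len(run_idxs) < min_runs:
        return None  # not enough runs yet, benchmark again

    gps = np.asarray(gps, dtype=float)
    config_keys = np.array(
        ['{}-{}-{}'.format(num_workers[i], chunk_size[i], bulk_mp[i]) for i in run_idxs]
    )

    best_idx, best_gps = None, -np.inf
    for key in np.unique(config_keys):
        idxs = run_idxs[config_keys == key]  # local positions -> global indices
        mean_gps = gps[idxs].mean()
        if mean_gps > best_gps:
            best_idx, best_gps = idxs[0], mean_gps

    return (
        int(num_workers[best_idx]),
        int(chunk_size[best_idx]),
        bool(bulk_mp[best_idx]),
        float(best_gps),
    )


runs = ['a', 'b', 'a', 'a', 'b', 'a']
num_workers = [2, 4, 2, 1, 4, 1]
chunk_size = [10, 5, 10, 20, 5, 20]
bulk_mp = [True, False, True, False, False, False]
gps = [100.0, 999.0, 120.0, 90.0, 999.0, 95.0]

print(select_best_config(runs, num_workers, chunk_size, bulk_mp, gps, 'a'))
# (2, 10, True, 110.0)
```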

================================================
FILE: sgdml/solvers/__init__.py
================================================


================================================
FILE: sgdml/solvers/analytic.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2020-2022 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os
import sys
import logging
import warnings
from functools import partial

import numpy as np
import scipy as sp
import timeit

from .. import DONE, NOT_DONE


class Analytic(object):
    def __init__(self, gdml_train, desc, callback=None):

        self.log = logging.getLogger(__name__)

        self.gdml_train = gdml_train
        self.desc = desc

        self.callback = callback

    # from memory_profiler import profile
    # @profile
    def solve(self, task, R_desc, R_d_desc, tril_perms_lin, y):

        sig = task['sig']
        lam = task['lam']
        use_E_cstr = task['use_E_cstr']

        n_train, dim_d = R_d_desc.shape[:2]
        n_atoms = int((1 + np.sqrt(8 * dim_d + 1)) / 2)
        dim_i = 3 * n_atoms

        if self.callback is not None:
            self.callback = partial(
                self.callback,
                disp_str='Assembling kernel matrix',
            )

        K = -self.gdml_train._assemble_kernel_mat(
            R_desc,
            R_d_desc,
            tril_perms_lin,
            sig,
            self.desc,
            use_E_cstr=use_E_cstr,
            callback=self.callback,
        )  # Flip sign to make the matrix positive (semi-)definite

        start = timeit.default_timer()

        with warnings.catch_warnings():
            warnings.simplefilter('ignore')

            if K.shape[0] == K.shape[1]:

                K[np.diag_indices_from(K)] += lam  # Regularize

                if self.callback is not None:
                    self.callback = partial(
                        self.callback,
                        disp_str='Solving linear system (Cholesky factorization)',
                    )
                    self.callback(NOT_DONE)

                try:

                    # Cholesky (do not overwrite K in case we need to retry)
                    L, lower = sp.linalg.cho_factor(
                        K, overwrite_a=False, check_finite=False
                    )
                    alphas = -sp.linalg.cho_solve(
                        (L, lower), y, overwrite_b=False, check_finite=False
                    )

                except np.linalg.LinAlgError:  # Try a solver that makes fewer assumptions

                    if self.callback is not None:
                        self.callback = partial(
                            self.callback,
                            disp_str='Solving linear system (LU factorization)      ',  # Keep whitespaces!
                        )
                        self.callback(NOT_DONE)

                    try:
                        # LU
                        alphas = -sp.linalg.solve(
                            K, y, overwrite_a=True, overwrite_b=True, check_finite=False
                        )
                    except MemoryError:
                        self.log.critical(
                            'Not enough memory to train this system using a closed form solver.'
                        )
                        print()
                        os._exit(1)

                except MemoryError:
                    self.log.critical(
                        'Not enough memory to train this system using a closed form solver.'
                    )
                    print()
                    os._exit(1)
            else:

                if self.callback is not None:
                    self.callback = partial(
                        self.callback,
                        disp_str='Solving over-determined linear system (least squares approximation)',
                    )
                    self.callback(NOT_DONE)

                # Least squares for non-square K
                alphas = -np.linalg.lstsq(K, y, rcond=-1)[0]

        stop = timeit.default_timer()

        if self.callback is not None:
            dur_s = stop - start
            sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else ''
            self.callback(
                DONE,
                disp_str='Training on {:,} points'.format(n_train),
                sec_disp_str=sec_disp_str,
            )

        return alphas

    @staticmethod
    def est_memory_requirement(n_train, n_atoms):

        est_bytes = 3 * (n_train * 3 * n_atoms) ** 2 * 8  # K + factor(s) of K
        est_bytes += (n_train * 3 * n_atoms) * 8  # alpha

        return est_bytes
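`est_memory_requirement` assumes three dense float64 copies of the (3 N M) x (3 N M) kernel matrix (the matrix plus its factor(s)) and one coefficient vector, where M is the number of training points and N the number of atoms; rows for energy constraints are not included. A short worked example of how quickly this grows, restating the same formula standalone; the sizes (1000 points, 9 atoms) are only illustrative.

```python
def est_memory_requirement(n_train, n_atoms):
    # Same formula as Analytic.est_memory_requirement, restated standalone.
    dim = n_train * 3 * n_atoms  # rows/columns of the force kernel matrix
    est_bytes = 3 * dim**2 * 8  # K + factor(s) of K, float64
    est_bytes += dim * 8  # alphas
    return est_bytes


print(est_memory_requirement(1000, 9))  # 17496216000 bytes, i.e. about 17.5 GB
```

Doubling the number of training points roughly quadruples the estimate, which is why the closed-form solver runs out of memory quickly for larger training sets.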

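`Analytic.solve` flips the sign of the assembled kernel matrix, adds `lam` to its diagonal and tries a Cholesky solve first, falling back to an LU-based solve if the factorization fails (and to least squares for non-square matrices). A minimal standalone sketch of that fallback order with SciPy, assuming the assembled matrix is negative semi-definite so that the flipped, regularized matrix is positive definite; `solve_regularized` is a hypothetical name.

```python
import numpy as np
import scipy.linalg


def solve_regularized(K_assembled, y, lam):
    """Hypothetical helper: alphas = -(-K_assembled + lam * I)^{-1} y."""
    K = -K_assembled  # flip sign (assumed to yield a PSD matrix)
    if K.shape[0] != K.shape[1]:
        # Non-square: plain least squares, without regularization.
        return -np.linalg.lstsq(K, y, rcond=None)[0]

    K = K + lam * np.eye(K.shape[0])  # regularize
    try:
        c_and_lower = scipy.linalg.cho_factor(K, check_finite=False)
        return -scipy.linalg.cho_solve(c_and_lower, y, check_finite=False)
    except np.linalg.LinAlgError:
        # Cholesky failed (matrix not positive definite): fall back to LU.
        return -scipy.linalg.solve(K, y, check_finite=False)
```

The Cholesky path is roughly half the cost of LU for symmetric positive definite matrices, so it is the natural first attempt; LU only requires the matrix to be non-singular.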

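Both solvers recover the number of atoms from the descriptor length `dim_d` via `int((1 + np.sqrt(8 * dim_d + 1)) / 2)`. This inverts dim_d = N (N - 1) / 2, i.e. one descriptor entry per unordered atom pair: the positive root of N^2 - N - 2 dim_d = 0 is N = (1 + sqrt(1 + 8 dim_d)) / 2. A sketch of the same inversion using exact integer arithmetic (a choice made for this example, not by sGDML), with a check that `dim_d` really is a pair count:

```python
import math


def n_atoms_from_desc_dim(dim_d):
    """Invert dim_d = N * (N - 1) / 2 for the number of atoms N."""
    n_atoms = (1 + math.isqrt(8 * dim_d + 1)) // 2
    if n_atoms * (n_atoms - 1) // 2 != dim_d:
        raise ValueError('dim_d={} is not a number of atom pairs'.format(dim_d))
    return n_atoms


print(n_atoms_from_desc_dim(36))  # 9 (9 atoms -> 9 * 8 / 2 = 36 pairs)
```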
================================================
FILE: sgdml/solvers/iterative.py
================================================
#!/usr/bin/python

# MIT License
#
# Copyright (c) 2020-2025 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import os
import logging
from functools import partial
import inspect
import multiprocessing as mp

import numpy as np
import scipy as sp
import timeit
import collections

from .. import DONE, NOT_DONE
from ..utils import ui
from ..predict import GDMLPredict

try:
    import torch
except ImportError:
    _has_torch = False
else:
    _has_torch = True


CG_STEPS_HIST_LEN = (
    100  # number of past steps to consider when calculating solver effectiveness
)
EFF_RESTART_THRESH = 0  # if solver effectiveness is less than this percentage after 'CG_STEPS_HIST_LEN' steps, a solver restart is triggered (with a stronger preconditioner)

MAX_NUM_RESTARTS = 6
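`Iterative` never forms the kernel matrix explicitly: `_init_kernel_operator` wraps a matrix-vector product (one call to `GDMLPredict.predict` with the vector set as model coefficients) in a SciPy `LinearOperator` that an iterative solver can consume. The following minimal sketch shows the same pattern with `scipy.sparse.linalg.cg` on a small random symmetric positive definite stand-in matrix; the matrix and the call counter are illustrative assumptions. As far as we know, SciPy calls `matvec` once on a zero vector to infer the operator's dtype when none is given, which the `is_primed` flags in this module appear to short-circuit; passing `dtype` explicitly avoids that probe.

```python
import numpy as np
from scipy.sparse.linalg import LinearOperator, cg

rng = np.random.default_rng(0)
n = 50
H = rng.standard_normal((n, n))
A = H @ H.T + n * np.eye(n)  # illustrative SPD stand-in for a regularized kernel matrix

n_matvecs = [0]  # counts matrix-vector products


def matvec(v):
    # Matrix-free product; in _init_kernel_operator this is a GDMLPredict.predict() call.
    n_matvecs[0] += 1
    return A @ np.ravel(v)


# An explicit dtype means SciPy does not have to probe matvec to infer it.
A_op = LinearOperator((n, n), matvec=matvec, dtype=np.float64)
n_matvecs_after_init = n_matvecs[0]

b = rng.standard_normal(n)
x, info = cg(A_op, b, maxiter=1000)
print(info)  # 0 on convergence
```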


class CGRestartException(Exception):
    pass
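`_init_precon_operator` applies v -> (C^T C v - v) / lam, where C = `L_inv_K_mn` is returned by `_nystroem_cholesky_factor`. If B denotes K_nm multiplied by the inverse Cholesky factor of K_mm (so that B B^T = K_nm K_mm^{-1} K_mn), then C^T C = B (B^T B + lam I)^{-1} B^T, and by the Woodbury identity the operator equals -(B B^T + lam I)^{-1}, i.e. the inverse of the regularized Nyström approximation, with the sign presumably matching the negated system. The numerical check below is a sketch on random data (B is an assumed stand-in, not an sGDML kernel); it also exercises the QR fallback used when the second Cholesky factorization fails, and the `einsum` used for the leverage scores.

```python
import numpy as np
from scipy.linalg import cholesky, solve_triangular


def precon_factor(B, lam, use_qr=False):
    """Return C (m x n) such that C.T @ C == B @ inv(B.T @ B + lam * I) @ B.T."""
    n, m = B.shape
    if use_qr:
        # Fallback: R from QR([B; sqrt(lam) I]) satisfies R.T @ R == B.T @ B + lam * I.
        R = np.linalg.qr(np.vstack([B, np.sqrt(lam) * np.eye(m)]), mode='r')
    else:
        R = cholesky(B.T @ B + lam * np.eye(m), lower=False)  # upper: inner = R.T @ R
    return solve_triangular(R, B.T, trans='T', lower=False)  # C = R^{-T} B^T


def apply_precon(C, v, lam):
    # Same update as in _P_vec: (C^T C v - v) / lam
    return (C.T @ (C @ v) - v) / lam


rng = np.random.default_rng(0)
n, m, lam = 40, 5, 0.1
B = rng.standard_normal((n, m))  # assumed stand-in for the transformed K_nm
v = rng.standard_normal(n)

C = precon_factor(B, lam)
C_qr = precon_factor(B, lam, use_qr=True)
lev_scores = np.einsum('i...,i...->...', C, C)  # column-wise squared norms of C
```

Only an m x m matrix is ever factorized, so applying the preconditioner costs two products with an m x n matrix per iteration.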


class Iterative(object):
    def __init__(
        self,
        gdml_train,
        desc,
        max_memory,
        max_processes,
        use_torch,
        callback=None,
    ):

        self.log = logging.getLogger(__name__)

        self.gdml_train = gdml_train
        self.gdml_predict = None
        self.desc = desc

        self.callback = callback

        self._max_memory = max_memory
        self._max_processes = max_processes
        self._use_torch = use_torch

    def _init_precon_operator(
        self, task, R_desc, R_d_desc, tril_perms_lin, inducing_pts_idxs, callback=None
    ):

        lam = task['lam']
        lam_inv = 1.0 / lam

        sig = task['sig']

        use_E_cstr = task['use_E_cstr']

        L_inv_K_mn = self._nystroem_cholesky_factor(
            R_desc,
            R_d_desc,
            tril_perms_lin,
            sig,
            lam,
            use_E_cstr=use_E_cstr,
            col_idxs=inducing_pts_idxs,
            callback=callback,
        )

        L_inv_K_mn = np.ascontiguousarray(L_inv_K_mn)

        lev_scores = np.einsum(
            'i...,i...->...', L_inv_K_mn, L_inv_K_mn
        )  # compute leverage scores because it is basically free once we got the factor

        m, n = L_inv_K_mn.shape

        if self._use_torch and False:  # TURNED OFF!
            _torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
            L_inv_K_mn_torch = torch.from_numpy(L_inv_K_mn).to(_torch_device)

        global is_primed
        is_primed = False

        def _P_vec(v):

            global is_primed
            if not is_primed:
                is_primed = True
                return v

            if self._use_torch and False:  # TURNED OFF!

                v_torch = torch.from_numpy(v).to(_torch_device)[:, None]
                return (
                    L_inv_K_mn_torch.t().mm(L_inv_K_mn_torch.mm(v_torch)) - v_torch
                ).cpu().numpy() * lam_inv

            else:

                ret = L_inv_K_mn.T.dot(L_inv_K_mn.dot(v))
                ret -= v
                ret *= lam_inv

                return ret

        return sp.sparse.linalg.LinearOperator((n, n), matvec=_P_vec), lev_scores

    def _init_kernel_operator(
        self, task, R_desc, R_d_desc, tril_perms_lin, lam, n, callback=None
    ):

        n_train = R_desc.shape[0]

        # dummy alphas
        v_F = np.zeros((n - n_train, 1)) if task['use_E_cstr'] else np.zeros((n, 1))
        v_E = np.zeros((n_train, 1)) if task['use_E_cstr'] else None

        # Note: The standard deviation is set to 1.0, because we are predicting normalized labels here.
        model = self.gdml_train.create_model(
            task, 'cg', R_desc, R_d_desc, tril_perms_lin, 1.0, v_F, alphas_E=v_E
        )

        self.gdml_predict = GDMLPredict(
            model,
            max_memory=self._max_memory,
            max_processes=self._max_processes,
            use_torch=self._use_torch,
        )

        self.gdml_predict.set_R_desc(R_desc)  # only needed on CPU
        self.gdml_predict.set_R_d_desc(R_d_desc)

        if not self._use_torch:

            if callback is not None:
                callback = partial(callback, disp_str='Optimizing CPU parallelization')
                callback(NOT_DONE)

            self.gdml_predict.prepare_parallel(n_bulk=n_train)

            if callback is not None:
                callback(DONE)

        global is_primed
        is_primed = False

        def _K_vec(v):

            global is_primed
            if not is_primed:
                is_primed = True
                return v

            v_F, v_E = v, None
            if task['use_E_cstr']:
                v_F, v_E = v[:-n_train], v[-n_train:]

            self.gdml_predict.set_alphas(v_F, alphas_E=v_E)

            pred = self.gdml_predict.predict(return_E=task['use_E_cstr'])
            if task['use_E_cstr']:
                e_pred, f_pred = pred
                pred = np.hstack((f_pred.ravel(), -e_pred))
            else:
                pred = pred[0].ravel()

            pred -= lam * v
            return pred

        return sp.sparse.linalg.LinearOperator((n, n), matvec=_K_vec)

    def _nystroem_cholesky_factor(
        self,
        R_desc,
        R_d_desc,
        tril_perms_lin,
        sig,
        lam,
        use_E_cstr,
        col_idxs,
        callback_task_name='',
        callback=None,
    ):

        if callback_task_name != '':
            callback_task_name = ' ({})'.format(callback_task_name)

        if callback is not None:
            callback = partial(
                callback,
                disp_str='Assembling kernel [m x k]{}'.format(callback_task_name),
            )

        dim_d = R_desc.shape[1]
        n_atoms = int((1 + np.sqrt(8 * dim_d + 1)) / 2)
        n = R_desc.shape[0] * n_atoms * 3 + (R_desc.shape[0] if use_E_cstr else 0)
        m = len(
            range(*col_idxs.indices(n)) if isinstance(col_idxs, slice) else col_idxs
        )

        K_nmm = self.gdml_train._assemble_kernel_mat(
            R_desc,
            R_d_desc,
            tril_perms_lin,
            sig,
            self.desc,
            use_E_cstr=use_E_cstr,
            col_idxs=col_idxs,
            alloc_extra_rows=m,
            callback=callback,
        )

        # Store (psd) copy of K_mm in lower part of this oversized K_(n+m)m matrix.
        K_nmm[-m:, :] = -K_nmm[col_idxs, :]

        K_nm = K_nmm[:-m, :]
        K_mm = K_nmm[-m:, :]

        if callback is not None:
            callback = partial(
                callback,
                disp_str='Cholesky fact. (1/2) [k x k]{}'.format(callback_task_name),
            )
            callback(NOT_DONE)

        # Additional regularization is almost always necessary here (hence pre_reg=True).
        K_mm, lower = self._cho_factor_stable(K_mm, pre_reg=True)  # overwrites input!
        L_mm = K_mm
        # del K_mm

        if callback is not None:
            callback(DONE)
            callback = partial(
                callback,
                disp_str='m tri. solves (1/2) [k x k]{}'.format(callback_task_name),
            )
            callback(0, n)

        b_start, b_size = 0, int(n / 4)  # report progress in steps of 25%
        for b_stop in list(range(b_size, n, b_size)) + [n]:

            K_nm[b_start:b_stop, :] = sp.linalg.solve_triangular(
                L_mm,
                K_nm[b_start:b_stop, :].T,
                lower=lower,
                trans='T',
                overwrite_b=True,
                check_finite=False,
            ).T
            b_start = b_stop

            if callback is not None:
                callback(b_stop, n)

        del L_mm

        K_nmm[-m:, :] = K_nm.T.dot(K_nm)
        K_nmm[-m:, :][np.diag_indices_from(K_nmm[-m:, :])] += lam
        inner = K_nmm[-m:, :]

        if callback is not None:
            callback = partial(
                callback,
                disp_str='Cholesky fact. (2/2) [k x k]{}'.format(callback_task_name),
            )
            callback(NOT_DONE)

        L_lower = self._cho_factor_stable(
            inner, eps_mag_max=-14
        )  # Do not regularize more than 1e-14.
        if L_lower is not None:
            K_nmm[-m:, :], lower = L_lower
            L = K_nmm[-m:, :]
            del inner
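        else:
            # Fallback: QR factorization of [K_nm; sqrt(lam) * I] (K_nm already
            # transformed above); its R factor satisfies R^T R = K_nm^T K_nm + lam * I.
            # Note: _cho_factor_stable currently terminates the process instead of
            # returning None, so this branch is not reached as written.
            if callback is not None:
                callback = partial(
                    callback,
                    disp_str='QR fact. (alt.) [k x k]{}'.format(callback_task_name),
                )
                callback(NOT_DONE)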

            K_nmm[-m:, :] = 0
            K_nmm[-m:, :][np.diag_indices(m)] = np.sqrt(lam)

            K_nmm[-m:, :] = np.linalg.qr(K_nmm, mode='r')
            L = K_nmm[-m:, :]
            lower = False

        if callback is not None:
            callback(DONE)
            callback = partial(
                callback,
                disp_str='m tri. solves (2/2) [k x k]{}'.format(callback_task_name),
            )
            callback(0, n)

        b_start, b_size = 0, max(1, n // 4)  # update in steps of 25 % (at least one row per block)
        for b_stop in list(range(b_size, n, b_size)) + [n]:

            K_nm[b_start:b_stop, :] = sp.linalg.solve_triangular(
                L,
                K_nm[b_start:b_stop, :].T,
                lower=lower,
                trans='T',
                overwrite_b=True,
                check_finite=False,
            ).T  # Note: Overwrites K_nm to save memory
            b_start = b_stop

            if callback is not None:
                callback(b_stop, n)
        del L

        return K_nm.T

    def _lev_scores(
        self,
        R_desc,
        R_d_desc,
        tril_perms_lin,
        sig,
        lam,
        use_E_cstr,
        n_inducing_pts,
        callback=None,
    ):

        n_train, dim_d = R_d_desc.shape[:2]
        dim_i = 3 * int((1 + np.sqrt(8 * dim_d + 1)) / 2)

        # Convert from training points to actual columns.
        # dim_m = (
        #    np.maximum(1, n_inducing_pts // 4) * dim_i
        # )  # only use 1/4 of inducing points for leverage score estimate
        dim_m = dim_i * min(n_inducing_pts, 10)

        # Which columns to use for leverage score approximation?
        lev_approx_idxs = np.sort(
            np.random.choice(
                n_train * dim_i + (n_train if use_E_cstr else 0), dim_m, replace=False
            )
        )  # random subset of columns
        # lev_approx_idxs = np.sort(np.random.choice(n_train*dim_i, dim_m, replace=False)) # random subset of columns

        # lev_approx_idxs = np.s_[
        #    :dim_m
        # ]  # first 'dim_m' columns (faster kernel construction)

        L_inv_K_mn = self._nystroem_cholesky_factor(
            R_desc,
            R_d_desc,
            tril_perms_lin,
            sig,
            lam,
            use_E_cstr=use_E_cstr,
            col_idxs=lev_approx_idxs,
            callback_task_name='lev. scores',
            callback=callback,
        )

        lev_scores = np.einsum('i...,i...->...', L_inv_K_mn, L_inv_K_mn)
        return lev_scores

    def inducing_pts_from_lev_scores(self, lev_scores, N):

        # Sample 'N' columns with probabilities proportional to the leverage scores.
        inducing_pts_idxs = np.random.choice(
            np.arange(lev_scores.size),
            N,
            replace=False,
            p=lev_scores / lev_scores.sum(),
        )

        return np.sort(inducing_pts_idxs)

    # Performs a Cholesky decomposition of a matrix, but regularizes the matrix (if needed) until it is positive definite.
    def _cho_factor_stable(self, M, pre_reg=False, eps_mag_max=1):
        """
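        Performs a Cholesky decomposition of a matrix, regularizing it
        as needed until it is positive definite.

        Parameters
        ----------
            M : :obj:`numpy.ndarray`
                Matrix to factorize. Its diagonal is modified in place
                whenever regularization is added.
            pre_reg : boolean, optional
                Regularize M right away (machine precision), before
                trying to factorize it (default: False).
            eps_mag_max : int, optional
                Decimal exponent of the largest regularization step
                added to the diagonal (default: 1, i.e. 10).

        Returns
        -------
            :obj:`numpy.ndarray`
                Matrix whose upper or lower triangle contains the Cholesky
                factor of M. Other parts of the matrix contain arbitrary data.
            boolean
                Flag indicating whether the factor is in the lower or upper
                triangle.

        If the factorization still fails after the strongest regularization,
        an error is logged and the process is terminated via os._exit(1).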
        """

        eps = np.finfo(float).eps
        eps_mag = int(np.floor(np.log10(eps)))

        if pre_reg:
            M[np.diag_indices_from(M)] += eps
            eps_mag += 1  # if additional regularization is necessary, start from the next order of magnitude

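        # Regularize more and more aggressively. Regularization accumulates: each failed
        # attempt adds the next step to the diagonal. Note that the final step
        # (10**eps_mag_max) is added after the last attempt and is never tried.
        for reg in 10.0 ** np.arange(eps_mag, eps_mag_max + 1):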
            try:

                L, lower = sp.linalg.cho_factor(
                    M, overwrite_a=False, check_finite=False
                )

            except np.linalg.LinAlgError as e:

                if 'not positive definite' in str(e):
                    self.log.debug(
                        'Cholesky solver needs more aggressive regularization (adding {} to diagonal)'.format(
                            reg
                        )
                    )
                    M[np.diag_indices_from(M)] += reg
                else:
                    raise e
            else:
                return L, lower

        self.log.critical(
            'Failed to factorize despite strong regularization (max: {})!\nYou could try a larger sigma.'.format(
                10.0**eps_mag_max
            )
        )
        print()
        os._exit(1)

    def solve(
        self,
        task,
        R_desc,
        R_d_desc,
        tril_perms_lin,
        y,
        y_std,
        tol=1e-4,
        save_progr_callback=None,
    ):

        global num_iters, start, resid, avg_tt, m  # , P_t

        n_train, n_atoms = task['R_train'].shape[:2]
        dim_i = 3 * n_atoms

        sig = task['sig']
        lam = task['lam']

        # these keys are only present if the task was created from an existing model
        alphas0_F = task['alphas0_F'] if 'alphas0_F' in task else None
        alphas0_E = task['alphas0_E'] if 'alphas0_E' in task else None
        num_iters0 = task['solver_iters'] if 'solver_iters' in task else 0

        # Number of inducing points to use for Nystrom approximation.
        max_memory_bytes = self._max_memory * 1024**3
        max_n_inducing_pts = Iterative.max_n_inducing_pts(
            n_train, n_atoms, max_memory_bytes
        )
        n_inducing_pts = min(n_train, max_n_inducing_pts)
        n_inducing_pts_init = (
            len(task['inducing_pts_idxs']) // (3 * n_atoms)
            if 'inducing_pts_idxs' in task
            else None
        )

        if self.callback is not None:
            self.callback = partial(
                self.callback,
                disp_str='Building preconditioner (k={} ind. point{})'.format(
                    n_inducing_pts, 's' if n_inducing_pts > 1 else ''
                ),
            )
        subtask_callback = (
            partial(ui.sec_callback, main_callback=self.callback)
            if self.callback is not None
            else None
        )

        lev_scores = None
        if n_inducing_pts_init is not None and n_inducing_pts_init == n_inducing_pts:
            inducing_pts_idxs = task['inducing_pts_idxs']  # reuse old inducing points
        else:
            # Determine good inducing points.
            lev_scores = self._lev_scores(
                R_desc,
                R_d_desc,
                tril_perms_lin,
                sig,
                lam,
                task['use_E_cstr'],
                n_inducing_pts,
                callback=subtask_callback,
            )

            dim_m = n_inducing_pts * dim_i
            inducing_pts_idxs = self.inducing_pts_from_lev_scores(lev_scores, dim_m)

        start = timeit.default_timer()
        P_op, lev_scores = self._init_precon_operator(
            task,
            R_desc,
            R_d_desc,
            tril_perms_lin,
            inducing_pts_idxs,
            callback=subtask_callback,
        )
        stop = timeit.default_timer()

        if self.callback is not None:
            dur_s = stop - start
            sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else ''
            self.callback(DONE, sec_disp_str=sec_disp_str)

            self.callback = partial(
                self.callback,
                disp_str='Initializing solver',
            )
        subtask_callback = (
            partial(ui.sec_callback, main_callback=self.callback)
            if self.callback is not None
            else None
        )

        n = P_op.shape[0]
        K_op = self._init_kernel_operator(
            task, R_desc, R_d_desc, tril_perms_lin, lam, n, callback=subtask_callback
        )

        num_iters = int(num_iters0)

        if self.callback is not None:

            num_devices = (
                mp.cpu_count() if self._max_processes is None else self._max_processes
            )
            if self._use_torch:
            
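In `_lev_scores`, the atom count is not passed in. It is recovered from the descriptor length `dim_d` with `int((1 + np.sqrt(8 * dim_d + 1)) / 2)`, which is the positive root of n^2 - n - 2*dim_d = 0. The inversion therefore assumes one descriptor entry per unordered atom pair, dim_d = n(n-1)/2. `dim_i = 3 * n_atoms` is then the number of force components per geometry. The helper below is a hypothetical stand-alone version, used only to sanity-check that round trip:

```python
import numpy as np


def n_atoms_from_desc_dim(dim_d):
    # Invert dim_d = n * (n - 1) / 2 via the positive root of n^2 - n - 2 * dim_d = 0.
    return int((1 + np.sqrt(8 * dim_d + 1)) / 2)


for n_atoms in range(2, 200):
    dim_d = n_atoms * (n_atoms - 1) // 2
    # 8 * dim_d + 1 == (2 * n_atoms - 1) ** 2, so the square root is exact here.
    assert n_atoms_from_desc_dim(dim_d) == n_atoms
    # Number of force components per geometry, as used for dim_i.
    assert 3 * n_atoms_from_desc_dim(dim_d) == 3 * n_atoms
```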
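`_nystroem_cholesky_factor` builds its result from two Cholesky factorizations, each followed by a batch of triangular solves. The sketch below restates that sequence with dense NumPy/SciPy arrays so it can run in isolation. It is not the sGDML implementation. The following are all assumptions made for illustration:

- the Gaussian kernel and the random point cloud;
- the lower-triangular convention;
- the fixed jitter (sGDML assembles its own negated kernel and regularizes via `_cho_factor_stable`).

With B = K_nm L_mm^-T, the function returns R = L^-1 B^T. The squared column norms of R are what the `einsum` in `_lev_scores` computes. In this sketch they equal diag(K~ (K~ + lam I)^-1) for the Nyström approximation K~ = K_nm K_mm^-1 K_mn, by the push-through identity, and this can be checked against a dense reference.

```python
import numpy as np
import scipy.linalg as sla


def gaussian_kernel(X, sig):
    # Hypothetical stand-in kernel; sGDML assembles its own (negated) force kernel.
    d2 = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)
    return np.exp(-d2 / (2.0 * sig**2))


def nystroem_factor(K_nm, K_mm, lam, jitter=1e-10):
    """Return R = L^-1 B^T, where K_mm + jitter*I = L_mm L_mm^T,
    B = K_nm L_mm^-T and L L^T = B^T B + lam*I."""
    n, m = K_nm.shape
    # Stage 1: Cholesky factor of the (slightly jittered) inducing block.
    L_mm = sla.cholesky(K_mm + jitter * np.eye(m), lower=True)
    B = np.empty_like(K_nm)
    b_size = max(1, n // 4)  # row blocks of about a quarter of the rows
    for start in range(0, n, b_size):
        stop = min(start + b_size, n)
        # Solve L_mm X = K_nm[start:stop]^T, so B[start:stop] = K_nm[start:stop] L_mm^-T.
        B[start:stop] = sla.solve_triangular(
            L_mm, K_nm[start:stop].T, lower=True, check_finite=False
        ).T
    # Stage 2: Cholesky factor of the small m x m matrix B^T B + lam*I.
    L = sla.cholesky(B.T @ B + lam * np.eye(m), lower=True)
    # Final triangular solve; R has shape (m, n).
    return sla.solve_triangular(L, B.T, lower=True, check_finite=False)


rng = np.random.default_rng(0)
X = 2.0 * rng.normal(size=(40, 3))
K = gaussian_kernel(X, sig=1.0)
idx = np.sort(rng.choice(len(X), size=8, replace=False))  # inducing columns
lam = 1e-2

R = nystroem_factor(K[:, idx], K[np.ix_(idx, idx)], lam)
# Squared column norms, computed like the einsum in _lev_scores.
lev = np.einsum('i...,i...->...', R, R)
```

Processing the rows in blocks only bounds the size of each solve, and lets progress be reported between blocks. The result is the same as a single solve over all rows.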
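`inducing_pts_from_lev_scores` draws `N` distinct columns, with probabilities proportional to their leverage scores, and returns them sorted. The variant below does the same with an explicit `numpy.random.Generator`. That is an assumption made for reproducibility; the original uses the global `np.random` state. Columns with a zero score are never selected, and NumPy raises a `ValueError` if fewer than `N` scores are non-zero.

```python
import numpy as np


def sample_inducing_columns(lev_scores, N, rng=None):
    # Draw N distinct column indices with probabilities proportional to lev_scores.
    rng = np.random.default_rng() if rng is None else rng
    p = lev_scores / lev_scores.sum()
    idxs = rng.choice(lev_scores.size, size=N, replace=False, p=p)
    return np.sort(idxs)


lev_scores = np.array([0.9, 0.0, 0.8, 0.0, 0.7, 0.3])
picked = sample_inducing_columns(lev_scores, 3, rng=np.random.default_rng(1))
# Columns 1 and 3 have zero score and can never be picked.
```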
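`_cho_factor_stable` retries the Cholesky factorization while adding increasingly large multiples of the identity to the diagonal. For float64 these range from 10^-16 up to 10^eps_mag_max, starting one order higher when `pre_reg=True`. The function below is a hypothetical stand-alone variant of that idea, not a drop-in replacement. It differs from the original in four ways:

- it works on a copy of the input;
- it catches every `LinAlgError` instead of matching the error message;
- it also tries the factorization after the largest step;
- it returns `None` instead of calling `os._exit(1)`.

Returning `None` matches the `if L_lower is not None` check in `_nystroem_cholesky_factor`.

```python
import numpy as np
import scipy.linalg as sla


def cho_factor_with_jitter(M, pre_reg=False, eps_mag_max=1):
    """Hypothetical variant of escalating diagonal regularization.

    Returns (c, lower) like scipy.linalg.cho_factor, or None if M cannot be
    factorized even after the largest step. The input is not modified."""
    M = np.array(M, dtype=float)  # work on a copy
    eps = np.finfo(float).eps
    eps_mag = int(np.floor(np.log10(eps)))  # -16 for float64
    if pre_reg:
        M[np.diag_indices_from(M)] += eps
        eps_mag += 1  # start one order of magnitude higher
    # First try M as given, then after each additional (cumulative) step.
    steps = [0.0] + list(10.0 ** np.arange(eps_mag, eps_mag_max + 1))
    for reg in steps:
        M[np.diag_indices_from(M)] += reg
        try:
            return sla.cho_factor(M, check_finite=False)
        except np.linalg.LinAlgError:
            continue
    return None
```

A rank-deficient positive semidefinite matrix such as `np.ones((3, 3))` is factorized after a small shift. A strongly indefinite matrix such as `-100.0 * np.eye(3)` still fails, because with the default `eps_mag_max=1` the cumulative shift stays near 11.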
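When the second factorization fails, `_nystroem_cholesky_factor` does the following:

1. It writes sqrt(lam) * I below the already transformed K_nm block.
2. It takes the R factor of a QR decomposition of the stacked matrix.

Because R^T R equals K_nm^T K_nm + lam * I, R can stand in for the upper Cholesky factor (hence `lower = False`). R may differ from that Cholesky factor in the signs of its rows. This flips signs in the subsequent `trans='T'` triangular solve, but leaves squared column norms, and therefore leverage scores, unchanged. The small check below uses random data in place of the transformed K_nm; the sizes and `lam` are arbitrary.

```python
import numpy as np
from scipy.linalg import cholesky, solve_triangular

rng = np.random.default_rng(0)
n, m, lam = 50, 6, 1e-3
B = rng.normal(size=(n, m))  # plays the role of the transformed K_nm block

# Stack sqrt(lam) * I below B, as the fallback does in the tail of K_nmm.
stacked = np.vstack([B, np.sqrt(lam) * np.eye(m)])
R = np.linalg.qr(stacked, mode='r')  # upper triangular, shape (m, m)

# Reference: upper Cholesky factor U with U^T U = B^T B + lam * I.
U = cholesky(B.T @ B + lam * np.eye(m), lower=False)

# Same kind of solve as the second batch of triangular solves (trans='T').
X_qr = solve_triangular(R, B.T, lower=False, trans='T')
X_chol = solve_triangular(U, B.T, lower=False, trans='T')

# Rows may differ in sign; squared column norms (leverage scores) do not.
lev_qr = (X_qr**2).sum(axis=0)
lev_chol = (X_chol**2).sum(axis=0)
```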
SYMBOL INDEX (169 symbols across 17 files)

FILE: scripts/sgdml_dataset_from_aims.py
  function read_reference_data (line 36) | def read_reference_data(f):  # noqa C901

FILE: scripts/sgdml_dataset_from_extxyz.py
  function read_nonstd_ext_xyz (line 46) | def read_nonstd_ext_xyz(f):
  function extract_info_from_extxyz (line 95) | def extract_info_from_extxyz(file_path):

FILE: scripts/sgdml_dataset_from_ipi.py
  function raw_input_float (line 36) | def raw_input_float(prompt):
  function read_concat_xyz (line 45) | def read_concat_xyz(f):
  function read_out_file (line 79) | def read_out_file(f, col):

FILE: setup.py
  function get_property (line 7) | def get_property(property, package):

FILE: sgdml/__init__.py
  class ColoredFormatter (line 45) | class ColoredFormatter(logging.Formatter):
    method __init__ (line 65) | def __init__(self, msg, use_color=True):
    method format (line 70) | def format(self, record):
  class ColoredLogger (line 95) | class ColoredLogger(logging.Logger):
    method __init__ (line 96) | def __init__(self, name):
    method done (line 117) | def done(self, msg, *args, **kwargs):

FILE: sgdml/cli.py
  class AssistantError (line 77) | class AssistantError(Exception):
  function _print_splash (line 81) | def _print_splash(max_memory, max_processes, use_torch):
  function _check_update (line 146) | def _check_update():
  function _print_billboard (line 168) | def _print_billboard():
  function _print_dataset_properties (line 223) | def _print_dataset_properties(dataset, title_str='Dataset properties'):
  function _print_task_properties_reduced (line 325) | def _print_task_properties_reduced(
  function _print_task_properties (line 349) | def _print_task_properties(task, title_str='Task properties'):
  function _print_model_properties (line 427) | def _print_model_properties(model, title_str='Model properties'):
  function _print_next_step (line 558) | def _print_next_step(
  function all (line 612) | def all(
  function create (line 745) | def create(  # noqa: C901
  function train (line 946) | def train(
  function _batch (line 1164) | def _batch(iterable, n=1):
  function _online_err (line 1170) | def _online_err(err, size, n, mae_n_sum, rmse_n_sum):
  function resume (line 1183) | def resume(
  function validate (line 1288) | def validate(
  function test (line 1327) | def test(
  function select (line 1797) | def select(model_dir, overwrite, model_file=None, command=None, **kwargs...
  function show (line 1940) | def show(file, command=None, **kwargs):
  function reset (line 1955) | def reset(command=None, **kwargs):
  function main (line 1979) | def main():

FILE: sgdml/get.py
  function download (line 44) | def download(command, file_name):
  function main (line 71) | def main():

FILE: sgdml/intf/ase_calc.py
  class SGDMLCalculator (line 37) | class SGDMLCalculator(Calculator):
    method __init__ (line 41) | def __init__(
    method calculate (line 93) | def calculate(self, atoms=None, *args, **kwargs):

FILE: sgdml/predict.py
  function share_array (line 65) | def share_array(arr_np):
  function _predict_wkr (line 84) | def _predict_wkr(
  class GDMLPredict (line 248) | class GDMLPredict(object):
    method __init__ (line 249) | def __init__(
    method __del__ (line 465) | def __del__(self):
    method set_R_desc (line 510) | def set_R_desc(self, R_desc):
    method set_R_d_desc (line 526) | def set_R_d_desc(self, R_d_desc):
    method set_alphas (line 551) | def set_alphas(self, alphas_F, alphas_E=None):
    method _set_num_workers (line 603) | def _set_num_workers(
    method _set_chunk_size (line 658) | def _set_chunk_size(self, chunk_size=None):
    method _set_batch_size (line 689) | def _set_batch_size(self, batch_size=None):  # deprecated
    method _set_bulk_mp (line 717) | def _set_bulk_mp(self, bulk_mp=False):
    method set_opt_num_workers_and_batch_size_fast (line 745) | def set_opt_num_workers_and_batch_size_fast(self, n_bulk=1, n_reps=1):...
    method prepare_parallel (line 770) | def prepare_parallel(
    method _save_cached_bmark_result (line 1044) | def _save_cached_bmark_result(self, n_bulk, num_workers, chunk_size, b...
    method _load_cached_bmark_result (line 1076) | def _load_cached_bmark_result(self, n_bulk):
    method get_GPU_batch (line 1129) | def get_GPU_batch(self):
    method predict (line 1146) | def predict(self, R=None, return_E=True):

FILE: sgdml/solvers/analytic.py
  class Analytic (line 37) | class Analytic(object):
    method __init__ (line 38) | def __init__(self, gdml_train, desc, callback=None):
    method solve (line 49) | def solve(self, task, R_desc, R_d_desc, tril_perms_lin, y):
    method est_memory_requirement (line 154) | def est_memory_requirement(n_train, n_atoms):

FILE: sgdml/solvers/iterative.py
  class CGRestartException (line 56) | class CGRestartException(Exception):
  class Iterative (line 60) | class Iterative(object):
    method __init__ (line 61) | def __init__(
    method _init_precon_operator (line 83) | def _init_precon_operator(
    method _init_kernel_operator (line 144) | def _init_kernel_operator(
    method _nystroem_cholesky_factor (line 208) | def _nystroem_cholesky_factor(
    method _lev_scores (line 353) | def _lev_scores(
    method inducing_pts_from_lev_scores (line 401) | def inducing_pts_from_lev_scores(self, lev_scores, N):
    method _cho_factor_stable (line 414) | def _cho_factor_stable(self, M, pre_reg=False, eps_mag_max=1):
    method solve (line 473) | def solve(
    method max_n_inducing_pts (line 828) | def max_n_inducing_pts(n_train, n_atoms, max_memory_bytes):
    method est_memory_requirement (line 847) | def est_memory_requirement(n_train, n_inducing_pts, n_atoms):

FILE: sgdml/torchtools.py
  function _next_batch_size (line 52) | def _next_batch_size(n_total, batch_size):
  class GDMLTorchAssemble (line 61) | class GDMLTorchAssemble(nn.Module):
    method __init__ (line 67) | def __init__(
    method _forward (line 110) | def _forward(
    method forward (line 343) | def forward(self, J_indx):
  class GDMLTorchPredict (line 395) | class GDMLTorchPredict(nn.Module):
    method __init__ (line 401) | def __init__(
    method get_n_perm_batches (line 595) | def get_n_perm_batches(self):
    method set_n_perm_batches (line 600) | def set_n_perm_batches(self, n_perm_batches):
    method apply_perms_to_obj (line 618) | def apply_perms_to_obj(self, xs, perm_idxs=None):
    method remove_perms_from_obj (line 631) | def remove_perms_from_obj(self, xs):
    method uncache_perms (line 637) | def uncache_perms(self):
    method cache_perms (line 651) | def cache_perms(self):
    method est_mem_requirement (line 667) | def est_mem_requirement(self, return_min=False):
    method set_R_d_desc (line 728) | def set_R_d_desc(self, R_d_desc):
    method set_alphas (line 760) | def set_alphas(self, alphas, alphas_E=None):
    method _forward (line 877) | def _forward(self, Rs_or_train_idxs, return_E=True):
    method forward (line 1048) | def forward(self, Rs_or_train_idxs, return_E=True):

FILE: sgdml/train.py
  function _share_array (line 75) | def _share_array(arr_np, typecode_or_type):
  function _assemble_kernel_mat_wkr (line 97) | def _assemble_kernel_mat_wkr(
  class GDMLTrain (line 305) | class GDMLTrain(object):
    method __init__ (line 306) | def __init__(self, max_memory=None, max_processes=None, use_torch=False):
    method __del__ (line 363) | def __del__(self):
    method create_task (line 370) | def create_task(
    method create_task_from_model (line 649) | def create_task_from_model(self, model, dataset):
    method create_model (line 727) | def create_model(
    method train (line 836) | def train(  # noqa: C901
    method _recov_int_const (line 1090) | def _recov_int_const(
    method _assemble_kernel_mat (line 1260) | def _assemble_kernel_mat(
    method draw_strat_sample (line 1537) | def draw_strat_sample(self, T, n, excl_idxs=None):

FILE: sgdml/utils/desc.py
  function _pbc_diff (line 44) | def _pbc_diff(diffs, lat_and_inv, use_torch=False):
  function _pdist (line 80) | def _pdist(r, lat_and_inv=None):
  function _squareform (line 113) | def _squareform(vec_or_mat):
  function _r_to_desc (line 139) | def _r_to_desc(r, pdist):
  function _r_to_d_desc (line 166) | def _r_to_d_desc(r, pdist, lat_and_inv=None):
  function _from_r (line 208) | def _from_r(r, lat_and_inv=None):
  class Desc (line 242) | class Desc(object):
    method __init__ (line 244) | def __init__(self, n_atoms, max_processes=None):
    method from_R (line 288) | def from_R(self, R, lat_and_inv=None, max_processes=None, callback=None):
    method d_desc_dot_vec (line 368) | def d_desc_dot_vec(self, R_d_desc, vecs, overwrite_vecs=False):
    method vec_dot_d_desc (line 388) | def vec_dot_d_desc(self, R_d_desc, vecs, out=None):
    method d_desc_from_comp (line 422) | def d_desc_from_comp(self, R_d_desc, out=None):
    method d_desc_to_comp (line 473) | def d_desc_to_comp(self, R_d_desc):
    method perm (line 510) | def perm(perm):

FILE: sgdml/utils/io.py
  function z_str_to_z (line 154) | def z_str_to_z(z_str):
  function z_to_z_str (line 158) | def z_to_z_str(z):
  function train_dir_name (line 162) | def train_dir_name(dataset, n_train, use_sym, use_E, use_E_cstr):
  function task_file_name (line 183) | def task_file_name(task):
  function model_file_name (line 192) | def model_file_name(task_or_model, is_extended=False):
  function dataset_md5 (line 208) | def dataset_md5(dataset):
  function read_xyz (line 238) | def read_xyz(file_path):
  function write_geometry (line 264) | def write_geometry(filename, r, z, comment_str=''):
  function generate_xyz_str (line 278) | def generate_xyz_str(r, z, e=None, f=None, lattice=None):
  function lattice_vec_to_par (line 303) | def lattice_vec_to_par(lat):
  function is_file_type (line 327) | def is_file_type(arg, type):
  function filter_file_type (line 414) | def filter_file_type(dir, type, md5_match=None):
  function is_valid_file_type (line 464) | def is_valid_file_type(arg_in):
  function is_dir_with_file_type (line 514) | def is_dir_with_file_type(arg, type, or_file=False):
  function is_task_dir_resumeable (line 572) | def is_task_dir_resumeable(
  function is_strict_pos_int (line 642) | def is_strict_pos_int(arg):
  function parse_list_or_range (line 667) | def parse_list_or_range(arg):

FILE: sgdml/utils/perm.py
  function share_array (line 48) | def share_array(arr_np, typecode):
  function _bipartite_match_wkr (line 53) | def _bipartite_match_wkr(i, n_train, same_z_cost):
  function bipartite_match (line 90) | def bipartite_match(R, z, lat_and_inv=None, max_processes=None, callback...
  function sync_perm_mat (line 238) | def sync_perm_mat(match_perms_all, match_cost, n_atoms, callback=None):
  function to_cycles (line 263) | def to_cycles(perm):
  function salvage_subgroup (line 289) | def salvage_subgroup(perms):
  function complete_sym_group (line 344) | def complete_sym_group(
  function find_perms (line 384) | def find_perms(R, z, lat_and_inv=None, callback=None, max_processes=None):
  function find_extra_perms (line 415) | def find_extra_perms(R, z, lat_and_inv=None, callback=None, max_processe...
  function find_frags (line 527) | def find_frags(r, z, lat_and_inv=None):
  function find_frag_perms (line 564) | def find_frag_perms(R, z, lat_and_inv=None, callback=None, max_processes...
  function _frag_perm_to_perm (line 759) | def _frag_perm_to_perm(n_atoms, frag_idxs, frag_perms):
  function find_perms_in_frag (line 774) | def find_perms_in_frag(R, z, frag_idxs, lat_and_inv=None, max_processes=...
  function find_perms_via_alignment (line 790) | def find_perms_via_alignment(
  function find_perms_via_reflection (line 917) | def find_perms_via_reflection(
  function print_perm_colors (line 967) | def print_perm_colors(perm, pts, plane_3idxs=None):
  function inv_perm (line 1035) | def inv_perm(perm):

FILE: sgdml/utils/ui.py
  function yes_or_no (line 39) | def yes_or_no(question):
  function callback (line 61) | def callback(
  function sec_callback (line 150) | def sec_callback(
  function color_str (line 189) | def color_str(str, fore_color=WHITE, back_color=BLACK, bold=False):
  function blink_str (line 203) | def blink_str(str):
  function unicode_str (line 208) | def unicode_str(s):
  function gen_memory_str (line 218) | def gen_memory_str(bytes):
  function gen_lattice_str (line 232) | def gen_lattice_str(lat):
  function str_plen (line 244) | def str_plen(str):
  function wrap_str (line 265) | def wrap_str(str, width=MAX_PRINT_WIDTH - LOG_LEVELNAME_WIDTH):
  function indent_str (line 297) | def indent_str(str, indent):
  function wrap_indent_str (line 318) | def wrap_indent_str(label, str, width=MAX_PRINT_WIDTH - LOG_LEVELNAME_WI...
  function merge_col_str (line 346) | def merge_col_str(
  function gen_mat_str (line 378) | def gen_mat_str(mat):
  function gen_range_str (line 434) | def gen_range_str(min, max):
  function print_step_title (line 457) | def print_step_title(title_str, sec_title_str='', underscore=True):
  function print_two_column_str (line 474) | def print_two_column_str(str, sec_str=''):
  function print_lattice (line 489) | def print_lattice(lat=None, inset=False):