Repository: stefanch/sGDML
Branch: master
Commit: a6ae5e86f88c
Files: 28
Total size: 380.5 KB
Directory structure:
gitextract_7dj7aa45/
├── .gitignore
├── LICENSE.txt
├── README.md
├── pyproject.toml
├── scripts/
│   ├── sgdml_dataset_from_aims.py
│   ├── sgdml_dataset_from_extxyz.py
│   ├── sgdml_dataset_from_ipi.py
│   ├── sgdml_dataset_to_extxyz.py
│   ├── sgdml_dataset_via_ase.py
│   └── sgdml_datasets_from_model.py
├── setup.cfg
├── setup.py
└── sgdml/
    ├── __init__.py
    ├── cli.py
    ├── get.py
    ├── intf/
    │   ├── __init__.py
    │   └── ase_calc.py
    ├── predict.py
    ├── solvers/
    │   ├── __init__.py
    │   ├── analytic.py
    │   └── iterative.py
    ├── torchtools.py
    ├── train.py
    └── utils/
        ├── __init__.py
        ├── desc.py
        ├── io.py
        ├── perm.py
        └── ui.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .gitignore
================================================
.DS_Store
# Compiled python modules.
*.pyc
# Setuptools distribution folder.
/dist/
# Python egg metadata, regenerated from source files by setuptools.
/*.egg-info
/*.egg
sgdml/_bmark_cache.npz
================================================
FILE: LICENSE.txt
================================================
MIT License
Copyright (c) 2018-2022 Stefan Chmiela
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# Symmetric Gradient Domain Machine Learning (sGDML)
For more details visit: [sgdml.org](http://sgdml.org/)
Documentation can be found here: [docs.sgdml.org](http://docs.sgdml.org/)
#### Requirements:
- Python 3.7+
- PyTorch (>=1.8)
- NumPy (>=1.19)
- SciPy (>=1.1)
#### Optional:
- ASE (>=3.16.2) (to run atomistic simulations)
## Getting started
### Stable release
Most systems come with ``pip``, the default package manager for Python, preinstalled. Install ``sgdml`` by calling:
```
$ pip install sgdml
```
The ``sgdml`` command-line interface and the corresponding Python API can now be used from anywhere on the system.
### Development version
#### (1) Clone the repository
```
$ git clone https://github.com/stefanch/sGDML.git
$ cd sGDML
```
...or update your existing local copy with
```
$ git pull origin master
```
#### (2) Install
```
$ pip install -e .
```
Using the flag ``--user``, you can tell ``pip`` to install the package to the current user's home directory instead of system-wide. This option might require you to update your system's ``PATH`` variable accordingly.
### Optional dependencies
Some functionality of this package relies on third-party libraries that are not installed by default. These optional dependencies (or "package extras") are specified during installation using the "square bracket syntax":
```
$ pip install sgdml[<optional1>]
```
#### Atomic Simulation Environment (ASE)
If you are interested in interfacing with [ASE](https://wiki.fysik.dtu.dk/ase/) to perform atomistic simulations (see [here](http://docs.sgdml.org/applications.html) for examples), use the ``ase`` keyword:
```
$ pip install sgdml[ase]
```
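Whether the optional ASE dependency is present in the current environment can be checked before relying on the ASE interface. The following is a minimal sketch that uses only the Python standard library; the helper name ``has_optional_dependency`` is made up for this example and is not part of the sGDML API.

```python
import importlib.util


def has_optional_dependency(module_name):
    """Return True if the top-level module `module_name` can be imported."""
    # find_spec returns None when a top-level module is not installed.
    return importlib.util.find_spec(module_name) is not None


if has_optional_dependency('ase'):
    print('ASE found: the ASE interface can be used.')
else:
    print("ASE not found: install it with 'pip install sgdml[ase]'.")
```

The check only tests importability; it does not verify that the installed ASE version satisfies the minimum listed under the optional requirements.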
## Reconstruct your first force field
Download one of the example datasets:
```
$ sgdml-get dataset ethanol_dft
```
Train a force field model:
```
$ sgdml all ethanol_dft.npz 200 1000 5000
```
## Query a force field
```python
import numpy as np
from sgdml.predict import GDMLPredict
from sgdml.utils import io
r,_ = io.read_xyz('geometries/ethanol.xyz') # 9 atoms
print(r.shape) # (1,27)
model = np.load('models/ethanol.npz')
gdml = GDMLPredict(model)
e,f = gdml.predict(r)
print(e.shape) # (1,)
print(f.shape) # (1,27)
```
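The shapes in the comments above reflect a flattened layout: each geometry occupies one row with three Cartesian components per atom, so 9 atoms give 27 columns. As an illustration only, the sketch below regroups such a row into one triple per atom. It uses plain Python lists rather than the NumPy arrays returned by ``predict``, assumes the components are stored atom by atom (x, y, z of the first atom, then the second, and so on), and the helper name ``per_atom`` is hypothetical.

```python
def per_atom(flat_row):
    """Regroup a flat [x1, y1, z1, x2, ...] row into (x, y, z) triples."""
    if len(flat_row) % 3 != 0:
        raise ValueError('expected three Cartesian components per atom')
    return [tuple(flat_row[i:i + 3]) for i in range(0, len(flat_row), 3)]


flat_forces = [float(i) for i in range(27)]  # stand-in for one row of f
forces = per_atom(flat_forces)
print(len(forces))  # 9
print(forces[0])    # (0.0, 1.0, 2.0)
```

For a NumPy array holding a single geometry, ``f.reshape(-1, 3)`` performs the same regrouping under the same ordering assumption.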
## Authors
* Stefan Chmiela
* Jan Hermann
We appreciate and welcome contributions and would like to thank the following people for participating in this project:
* Huziel Sauceda
* Igor Poltavsky
* Luis Gálvez
* Danny Panknin
* Grégory Fonseca
* Anton Charkin-Gorbulin
## References
* [1] Chmiela, S., Tkatchenko, A., Sauceda, H. E., Poltavsky, I., Schütt, K. T., Müller, K.-R.,
*Machine Learning of Accurate Energy-conserving Molecular Force Fields.*
Science Advances, 3(5), e1603015 (2017)
[10.1126/sciadv.1603015](http://dx.doi.org/10.1126/sciadv.1603015)
* [2] Chmiela, S., Sauceda, H. E., Müller, K.-R., Tkatchenko, A.,
*Towards Exact Molecular Dynamics Simulations with Machine-Learned Force Fields.*
Nature Communications, 9(1), 3887 (2018)
[10.1038/s41467-018-06169-2](https://doi.org/10.1038/s41467-018-06169-2)
* [3] Chmiela, S., Sauceda, H. E., Poltavsky, I., Müller, K.-R., Tkatchenko, A.,
*sGDML: Constructing Accurate and Data Efficient Molecular Force Fields Using Machine Learning.*
Computer Physics Communications, 240, 38-45 (2019)
[10.1016/j.cpc.2019.02.007](https://doi.org/10.1016/j.cpc.2019.02.007)
* [4] Chmiela, S., Vassilev-Galindo, V., Unke, O. T., Kabylda, A., Sauceda, H. E., Tkatchenko, A., Müller, K.-R.,
*Accurate Global Machine Learning Force Fields for Molecules With Hundreds of Atoms.*
Science Advances, 9(2), eadf0873 (2023)
[10.1126/sciadv.adf0873](https://doi.org/10.1126/sciadv.adf0873)
================================================
FILE: pyproject.toml
================================================
[tool.black]
skip-string-normalization = true
skip-numeric-underscore-normalization = true
================================================
FILE: scripts/sgdml_dataset_from_aims.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2018-2022 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function
import argparse
import os
import sys
import numpy as np
from sgdml.utils import io, ui
def read_reference_data(f):  # noqa C901
    eV_to_kcalmol = 0.036749326 / 0.0015946679
    e_next, f_next, geo_next = False, False, False
    n_atoms = None
    R, z, E, F = [], [], [], []
    geo_idx = 0
    for line in f:
        if n_atoms:
            cols = line.split()
            if e_next:
                E.append(float(cols[5]))
                e_next = False
            elif f_next:
                a = int(cols[1]) - 1
                F.append(list(map(float, cols[2:5])))
                if a == n_atoms - 1:
                    f_next = False
            elif geo_next:
                if 'atom' in cols:
                    a_count += 1  # noqa: F821
                    R.append(list(map(float, cols[1:4])))
                    if geo_idx == 0:
                        z.append(io._z_str_to_z_dict[cols[4]])
                    if a_count == n_atoms:
                        geo_next = False
                        geo_idx += 1
            elif 'Energy and forces in a compact form:' in line:
                e_next = True
            elif 'Total atomic forces (unitary forces cleaned) [eV/Ang]:' in line:
                f_next = True
            elif (
                'Atomic structure (and velocities) as used in the preceding time step:'
                in line
            ):
                geo_next = True
                a_count = 0
        elif 'The structure contains' in line and 'atoms, and a total of' in line:
            n_atoms = int(line.split()[3])
            print('Number atoms per geometry: {:>7d}'.format(n_atoms))
            continue
        if geo_idx > 0 and geo_idx % 1000 == 0:
            sys.stdout.write("\rNumber geometries found so far: {:>7d}".format(geo_idx))
            sys.stdout.flush()
    sys.stdout.write("\rNumber geometries found so far: {:>7d}".format(geo_idx))
    sys.stdout.flush()
    print(
        '\n'
        + ui.color_str('[INFO]', bold=True)
        + ' Energies and forces have been converted from eV to kcal/mol(/Ang)'
    )
    R = np.array(R).reshape(-1, n_atoms, 3)
    z = np.array(z)
    E = np.array(E) * eV_to_kcalmol
    F = np.array(F).reshape(-1, n_atoms, 3) * eV_to_kcalmol
    f.close()
    return (R, z, E, F)
parser = argparse.ArgumentParser(description='Creates a dataset from FHI-aims format.')
parser.add_argument(
    'dataset',
    metavar='<dataset>',
    type=argparse.FileType('r'),
    help='path to FHI-aims output file',
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing dataset file',
)
args = parser.parse_args()
dataset = args.dataset
name = os.path.splitext(os.path.basename(dataset.name))[0]
dataset_file_name = name + '.npz'
dataset_exists = os.path.isfile(dataset_file_name)
if dataset_exists and args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')
if not dataset_exists or args.overwrite:
    print('Writing dataset to \'%s\'...' % dataset_file_name)
else:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \'%s\' already exists.' % dataset_file_name
    )
R, z, E, F = read_reference_data(dataset)
# Prune all arrays to same length.
n_mols = min(min(R.shape[0], F.shape[0]), E.shape[0])
if n_mols != R.shape[0] or n_mols != F.shape[0] or n_mols != E.shape[0]:
    print(
        ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True)
        + ' Incomplete output detected: Final dataset was pruned to %d points.' % n_mols
    )
R = R[:n_mols, :, :]
F = F[:n_mols, :, :]
E = E[:n_mols]
# Base variables contained in every dataset file.
base_vars = {
    'type': 'd',
    'R': R,
    'z': z,
    'E': E[:, None],
    'F': F,
    'e_unit': 'kcal/mol',
    'r_unit': 'Ang',
    'name': name,
    'theory': 'unknown',
}
base_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel())
base_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel())
base_vars['E_min'], base_vars['E_max'] = np.min(E), np.max(E)
base_vars['E_mean'], base_vars['E_var'] = np.mean(E), np.var(E)
base_vars['md5'] = io.dataset_md5(base_vars)
np.savez_compressed(dataset_file_name, **base_vars)
print(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))
================================================
FILE: scripts/sgdml_dataset_from_extxyz.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2018-2022 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function
import argparse
import os
import sys
try:
    from ase.io import read
except ImportError:
    raise ImportError('Optional ASE dependency not found! Please run \'pip install sgdml[ase]\' to install it.')
import numpy as np
from sgdml import __version__
from sgdml.utils import io, ui
if sys.version[0] == '3':
    raw_input = input
# Note: assumes that the atoms in each molecule are in the same order.
def read_nonstd_ext_xyz(f):
    n_atoms = None
    R, z, E, F = [], [], [], []
    for i, line in enumerate(f):
        line = line.strip()
        if not n_atoms:
            n_atoms = int(line)
            print('Number atoms per geometry: {:,}'.format(n_atoms))
        file_i, line_i = divmod(i, n_atoms + 2)
        if line_i == 1:
            try:
                e = float(line)
            except ValueError:
                pass
            else:
                E.append(e)
        cols = line.split()
        if line_i >= 2:
            R.append(list(map(float, cols[1:4])))
            if file_i == 0:  # first molecule
                z.append(io._z_str_to_z_dict[cols[0]])
            F.append(list(map(float, cols[4:7])))
        if file_i % 1000 == 0:
            sys.stdout.write('\rNumber geometries found so far: {:,}'.format(file_i))
            sys.stdout.flush()
    sys.stdout.write('\rNumber geometries found so far: {:,}'.format(file_i))
    sys.stdout.flush()
    print()
    R = np.array(R).reshape(-1, n_atoms, 3)
    z = np.array(z)
    E = None if not E else np.array(E)
    F = np.array(F).reshape(-1, n_atoms, 3)
    if F.shape[0] != R.shape[0]:
        sys.exit(
            ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
            + ' Force labels are missing from dataset or are incomplete!'
        )
    f.close()
    return (R, z, E, F)
# Extracts info string for each frame.
def extract_info_from_extxyz(file_path):
    infos = []
    with open(file_path) as f:
        lines = f.readlines()
    i = 0
    while i < len(lines):
        try:
            n_atoms = int(lines[i])
        except ValueError:
            raise ValueError(f"Invalid atom count at line {i + 1}")
        if i + 1 >= len(lines):
            break
        comment_line = lines[i + 1].strip()
        info = {}
        for token in comment_line.split():
            if "=" in token:
                key, val = token.split("=", 1)
                val = val.strip('"')
                try:
                    val = float(val)
                except ValueError:
                    pass
                info[key] = val
        infos.append(info)
        i += 2 + n_atoms
    return infos
parser = argparse.ArgumentParser(
    description='Creates a dataset from extended XYZ format.'
)
parser.add_argument(
    'dataset',
    metavar='<dataset>',
    type=argparse.FileType('r'),
    help='path to extended xyz dataset file',
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing dataset file',
)
args = parser.parse_args()
dataset = args.dataset
name = os.path.splitext(os.path.basename(dataset.name))[0]
dataset_file_name = name + '.npz'
dataset_exists = os.path.isfile(dataset_file_name)
if dataset_exists and args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')
if not dataset_exists or args.overwrite:
    print('Writing dataset to \'{}\'...'.format(dataset_file_name))
else:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
        + ' Dataset \'{}\' already exists.'.format(dataset_file_name)
    )
lattice, R, z, E, F = None, None, None, None, None
mols = read(dataset.name, format='extxyz', index=':')
# calc = mols[0].get_calculator()  # deprecated
calc = mols[0].calc
is_extxyz = calc is not None
if is_extxyz:
    print("\rNumber geometries found: {:,}\n".format(len(mols)))
    if 'forces' not in calc.results:
        sys.exit(
            ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
            + ' Forces are missing in the input file!'
        )
    lattice = np.array(mols[0].get_cell().T)
    if not np.any(lattice):  # all zeros
        print(
            ui.color_str('[INFO]', bold=True)
            + ' No lattice vectors specified in extended XYZ file.'
        )
        lattice = None
    Z = np.array([mol.get_atomic_numbers() for mol in mols])
    all_z_the_same = (Z == Z[0]).all()
    if not all_z_the_same:
        sys.exit(
            ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
            + ' Order of atoms changes across dataset.'
        )
    R = np.array([mol.get_positions() for mol in mols])
    z = Z[0]
    # ASE did not parse the info string. Try parsing it manually.
    if not mols[0].info:
        print(
            ui.color_str('[INFO]', bold=True)
            + ' ASE did not parse the info string completely. Trying to parse it manually.'
        )
        infos = extract_info_from_extxyz(dataset.name)
        for mol, info in zip(mols, infos):
            mol.info.update(info)
    if 'Energy' in mols[0].info:
        E = np.array([mol.info['Energy'] for mol in mols])
    if 'energy' in mols[0].info:
        E = np.array([mol.info['energy'] for mol in mols])
    F = np.array([mol.get_forces() for mol in mols])
else:  # legacy non-standard XYZ format
    with open(dataset.name) as f:
        R, z, E, F = read_nonstd_ext_xyz(f)
# Base variables contained in every dataset file.
base_vars = {
    'type': 'd',
    'code_version': __version__,
    'name': name,
    'theory': 'unknown',
    'R': R,
    'z': z,
    'F': F,
}
base_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel())
base_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel())
print('Please provide a description of the length unit used in your input file, e.g. \'Ang\' or \'au\': ')
print('Note: This string will be stored in the dataset file and passed on to model files for later reference.')
r_unit = raw_input('> ').strip()
if r_unit != '':
    base_vars['r_unit'] = r_unit
print('Please provide a description of the energy unit used in your input file, e.g. \'kcal/mol\' or \'eV\': ')
print('Note: This string will be stored in the dataset file and passed on to model files for later reference.')
e_unit = raw_input('> ').strip()
if e_unit != '':
    base_vars['e_unit'] = e_unit
if E is not None:
    base_vars['E'] = E
    base_vars['E_min'], base_vars['E_max'] = np.min(E), np.max(E)
    base_vars['E_mean'], base_vars['E_var'] = np.mean(E), np.var(E)
else:
    print(ui.color_str('[INFO]', bold=True) + ' No energy labels found in dataset.')
if lattice is not None:
    base_vars['lattice'] = lattice
base_vars['md5'] = io.dataset_md5(base_vars)
np.savez_compressed(dataset_file_name, **base_vars)
print(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))
================================================
FILE: scripts/sgdml_dataset_from_ipi.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2018 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function
import argparse
import os
import sys
import numpy as np
from sgdml.utils import io, ui
def raw_input_float(prompt):
    while True:
        try:
            return float(input(prompt))
        except ValueError:
            print(ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' That is not a valid float.')
# Assumes that the atoms in each molecule are in the same order.
def read_concat_xyz(f):
    n_atoms = None
    R, z = [], []
    for i, line in enumerate(f):
        line = line.strip()
        if not n_atoms:
            n_atoms = int(line)
            print('Number atoms per geometry: {:>7d}'.format(n_atoms))
        file_i, line_i = divmod(i, n_atoms + 2)
        cols = line.split()
        if line_i >= 2:
            if file_i == 0:  # first molecule
                z.append(io._z_str_to_z_dict[cols[0]])
            R.append(list(map(float, cols[1:4])))
        if file_i % 1000 == 0:
            sys.stdout.write("\rNumber geometries found so far: {:>7d}".format(file_i))
            sys.stdout.flush()
    sys.stdout.write("\rNumber geometries found so far: {:>7d}\n".format(file_i))
    sys.stdout.flush()
    # Only keep complete entries.
    R = R[: int(n_atoms * np.floor(len(R) / float(n_atoms)))]
    R = np.array(R).reshape(-1, n_atoms, 3)
    z = np.array(z)
    f.close()
    return (R, z)
def read_out_file(f, col):
    E = []
    for i, line in enumerate(f):
        line = line.strip()
        # Skip blank lines (indexing an empty string would fail) and comments.
        if line and not line.startswith('#'):
            E.append(float(line.split()[col]))
        if i % 1000 == 0:
            sys.stdout.write("\rNumber lines processed so far: {:>7d}".format(len(E)))
            sys.stdout.flush()
    sys.stdout.write("\rNumber lines processed so far: {:>7d}\n".format(len(E)))
    sys.stdout.flush()
    return np.array(E)
parser = argparse.ArgumentParser(
    description='Creates a dataset from separate geometry, force and energy files (e.g. i-PI output).'
)
parser.add_argument(
    'geometries',
    metavar='<geometries>',
    type=argparse.FileType('r'),
    help='path to XYZ geometry file',
)
parser.add_argument(
    'forces',
    metavar='<forces>',
    type=argparse.FileType('r'),
    help='path to XYZ force file',
)
parser.add_argument(
    'energies',
    metavar='<energies>',
    type=argparse.FileType('r'),
    help='path to CSV energy file',
)
parser.add_argument(
    'energy_col',
    metavar='<energy_col>',
    type=lambda x: io.is_strict_pos_int(x),
    help='which column to parse from energy file (zero based)',
    nargs='?',
    default=0,
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing dataset file',
)
args = parser.parse_args()
geometries = args.geometries
forces = args.forces
energies = args.energies
energy_col = args.energy_col
name = os.path.splitext(os.path.basename(geometries.name))[0]
dataset_file_name = name + '.npz'
dataset_exists = os.path.isfile(dataset_file_name)
if dataset_exists and args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')
if not dataset_exists or args.overwrite:
    print('Writing dataset to \'%s\'...' % dataset_file_name)
else:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \'%s\' already exists.' % dataset_file_name
    )
print('Reading geometries...')
R, z = read_concat_xyz(geometries)
print('Reading forces...')
F, _ = read_concat_xyz(forces)
print('Reading energies from column %d...' % energy_col)
E = read_out_file(energies, energy_col)
# Prune all arrays to same length.
n_mols = min(min(R.shape[0], F.shape[0]), E.shape[0])
if n_mols != R.shape[0] or n_mols != F.shape[0] or n_mols != E.shape[0]:
    print(
        ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True)
        + ' Incomplete output detected: Final dataset was pruned to %d points.' % n_mols
    )
R = R[:n_mols, :, :]
F = F[:n_mols, :, :]
E = E[:n_mols]
print(
    ui.color_str('[INFO]', bold=True)
    + ' Geometries, forces and energies must have consistent units.'
)
R_conv_fact = raw_input_float('Unit conversion factor for geometries: ')
R = R * R_conv_fact
F_conv_fact = raw_input_float('Unit conversion factor for forces: ')
F = F * F_conv_fact
E_conv_fact = raw_input_float('Unit conversion factor for energies: ')
E = E * E_conv_fact
# Base variables contained in every dataset file.
base_vars = {
    'type': 'd',
    'R': R,
    'z': z,
    'E': E[:, None],
    'F': F,
    'name': name,
    'theory': 'unknown',
}
base_vars['md5'] = io.dataset_md5(base_vars)
np.savez_compressed(dataset_file_name, **base_vars)
print(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))
================================================
FILE: scripts/sgdml_dataset_to_extxyz.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2018-2019 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function
import argparse
import os
import sys
import numpy as np
from sgdml.utils import io, ui
parser = argparse.ArgumentParser(
    description='Converts a native dataset file to extended XYZ format.'
)
parser.add_argument(
    'dataset',
    metavar='<dataset>',
    type=lambda x: io.is_file_type(x, 'dataset'),
    help='path to dataset file',
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing xyz dataset file',
)
args = parser.parse_args()
dataset_path, dataset = args.dataset
name = os.path.splitext(os.path.basename(dataset_path))[0]
dataset_file_name = name + '.xyz'
xyz_exists = os.path.isfile(dataset_file_name)
if xyz_exists and args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing xyz dataset file.')
if not xyz_exists or args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Writing dataset to \'{}\'...'.format(dataset_file_name))
else:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True) + ' Dataset \'{}\' already exists.'.format(dataset_file_name)
    )
R = dataset['R']
z = dataset['z']
F = dataset['F']
lattice = dataset['lattice'] if 'lattice' in dataset else None
try:
    with open(dataset_file_name, 'w') as file:
        n = R.shape[0]
        for i, r in enumerate(R):
            # Energies are optional in a dataset; forces are always present.
            e = np.squeeze(dataset['E'][i]) if 'E' in dataset else None
            f = F[i, :, :]
            ext_xyz_str = io.generate_xyz_str(r, z, e=e, f=f, lattice=lattice) + '\n'
            file.write(ext_xyz_str)
            ui.callback(i, n - 1, disp_str='Exporting %d data points...' % n)
except IOError:
    sys.exit("ERROR: Writing xyz file failed.")
print()
================================================
FILE: scripts/sgdml_dataset_via_ase.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2018-2022 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function
import argparse
import os
import sys
try:
    from ase.io import read
except ImportError:
    raise ImportError('Optional ASE dependency not found! Please run \'pip install sgdml[ase]\' to install it.')
import numpy as np
from sgdml import __version__
from sgdml.utils import io, ui
if sys.version[0] == '3':
    raw_input = input
parser = argparse.ArgumentParser(
    description='Creates a dataset from any input format supported by ASE.'
)
parser.add_argument(
    'dataset',
    metavar='<dataset>',
    type=argparse.FileType('r'),
    help='path to input dataset file',
)
parser.add_argument(
    '-o',
    '--overwrite',
    dest='overwrite',
    action='store_true',
    help='overwrite existing dataset file',
)
args = parser.parse_args()
dataset = args.dataset
name = os.path.splitext(os.path.basename(dataset.name))[0]
dataset_file_name = name + '.npz'
dataset_exists = os.path.isfile(dataset_file_name)
if dataset_exists and args.overwrite:
    print(ui.color_str('[INFO]', bold=True) + ' Overwriting existing dataset file.')
if not dataset_exists or args.overwrite:
    print('Writing dataset to \'{}\'...'.format(dataset_file_name))
else:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
        + ' Dataset \'{}\' already exists.'.format(dataset_file_name)
    )
mols = read(dataset.name, index=':')
# filter incomplete outputs from trajectory
mols = [mol for mol in mols if mol.get_calculator() is not None]
lattice, R, z, E, F = None, None, None, None, None
calc = mols[0].get_calculator()
print("\rNumber geometries: {:,}".format(len(mols)))
#print("\rAvailable properties: " + ', '.join(calc.results))
print()
if 'forces' not in calc.results:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
        + ' Forces are missing in the input file!'
    )
lattice = np.array(mols[0].get_cell().T)
if not np.any(lattice):
    print(
        ui.color_str('[INFO]', bold=True)
        + ' No lattice vectors specified.'
    )
    lattice = None
Z = np.array([mol.get_atomic_numbers() for mol in mols])
all_z_the_same = (Z == Z[0]).all()
if not all_z_the_same:
    sys.exit(
        ui.color_str('[FAIL]', fore_color=ui.RED, bold=True)
        + ' Order of atoms changes across dataset.'
    )
R = np.array([mol.get_positions() for mol in mols])
z = Z[0]
if 'Energy' in mols[0].info:
    E = np.array([float(mol.info['Energy']) for mol in mols])
else:
    E = np.array([mol.get_potential_energy() for mol in mols])
F = np.array([mol.get_forces() for mol in mols])
print('Please provide a name for this dataset. Otherwise the original filename will be reused.')
custom_name = raw_input('> ').strip()
if custom_name != '':
    name = custom_name
print('Please provide a descriptor for the level of theory used to create this dataset.')
theory = raw_input('> ').strip()
if theory == '':
    theory = 'unknown'
# Base variables contained in every model file.
base_vars = {
'type': 'd',
'code_version': __version__,
'name': name,
'theory': theory,
'R': R,
'z': z,
'F': F,
}
base_vars['F_min'], base_vars['F_max'] = np.min(F.ravel()), np.max(F.ravel())
base_vars['F_mean'], base_vars['F_var'] = np.mean(F.ravel()), np.var(F.ravel())
print('If you want to convert your original length unit, please provide a conversion factor (default: 1.0): ')
R_to_new_unit = raw_input('> ').strip()
if R_to_new_unit != '':
R_to_new_unit = float(R_to_new_unit)
else:
R_to_new_unit = 1.0
print('If you want to convert your original energy unit, please provide a conversion factor (default: 1.0): ')
E_to_new_unit = raw_input('> ').strip()
if E_to_new_unit != '':
E_to_new_unit = float(E_to_new_unit)
else:
E_to_new_unit = 1.0
print('Please provide a description of the length unit, e.g. \'Ang\' or \'au\': ')
print('Note: This string will be stored in the dataset file and passed on to models files for later reference.')
r_unit = raw_input('> ').strip()
if r_unit != '':
base_vars['r_unit'] = r_unit
print('Please provide a description of the energy unit, e.g. \'kcal/mol\' or \'eV\': ')
print('Note: This string will be stored in the dataset file and passed on to models files for later reference.')
e_unit = raw_input('> ').strip()
if e_unit != '':
base_vars['e_unit'] = e_unit
if E is not None:
base_vars['E'] = E * E_to_new_unit
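# Compute the statistics on the converted energies, so they match the stored labels.
base_vars['E_min'], base_vars['E_max'] = np.min(base_vars['E']), np.max(base_vars['E'])
base_vars['E_mean'], base_vars['E_var'] = np.mean(base_vars['E']), np.var(base_vars['E'])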
else:
print(ui.color_str('[INFO]', bold=True) + ' No energy labels found in dataset.')
base_vars['R'] *= R_to_new_unit
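base_vars['F'] *= E_to_new_unit / R_to_new_unit
# Recompute the force statistics after the unit conversion, so they match the stored labels.
base_vars['F_min'], base_vars['F_max'] = np.min(base_vars['F'].ravel()), np.max(base_vars['F'].ravel())
base_vars['F_mean'], base_vars['F_var'] = np.mean(base_vars['F'].ravel()), np.var(base_vars['F'].ravel())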
if lattice is not None:
base_vars['lattice'] = lattice
base_vars['md5'] = io.dataset_md5(base_vars)
np.savez_compressed(dataset_file_name, **base_vars)
print(ui.color_str('[DONE]', fore_color=ui.GREEN, bold=True))
================================================
FILE: scripts/sgdml_datasets_from_model.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2018 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function
import argparse
import os
import sys
import numpy as np
from sgdml.utils import io, ui
parser = argparse.ArgumentParser(
description='Extracts the training and validation data subsets from a dataset that were used to construct a model.'
)
parser.add_argument(
'model',
metavar='<model_file>',
type=lambda x: io.is_file_type(x, 'model'),
help='path to model file',
)
parser.add_argument(
'dataset',
metavar='<dataset_file>',
type=lambda x: io.is_file_type(x, 'dataset'),
help='path to dataset file referenced in model',
)
parser.add_argument(
'-o',
'--overwrite',
dest='overwrite',
action='store_true',
help='overwrite existing files',
)
args = parser.parse_args()
model_path, model = args.model
dataset_path, dataset = args.dataset
for s in ['train', 'valid']:
if dataset['md5'] != model['md5_' + s]:
sys.exit(
ui.fail_str('[FAIL]')
+ ' Dataset fingerprint does not match the one referenced in model for \'%s\'.'
% s
)
idxs = model['idxs_' + s]
R = dataset['R'][idxs, :, :]
E = dataset['E'][idxs]
F = dataset['F'][idxs, :, :]
base_vars = {
'type': 'd',
'name': dataset['name'].astype(str),
'theory': dataset['theory'].astype(str),
'z': dataset['z'],
'R': R,
'E': E,
'F': F,
}
base_vars['md5'] = io.dataset_md5(base_vars)
subset_file_name = '%s_%s.npz' % (
os.path.splitext(os.path.basename(dataset_path))[0],
s,
)
file_exists = os.path.isfile(subset_file_name)
if file_exists and args.overwrite:
print(ui.info_str('[INFO]') + ' Overwriting existing dataset file.')
if not file_exists or args.overwrite:
np.savez_compressed(subset_file_name, **base_vars)
ui.callback(1, disp_str='Extracted %s dataset saved to \'%s\'' % (s, subset_file_name)) # DONE
else:
print(
ui.warn_str('[WARN]')
+ ' %s dataset \'%s\' already exists.' % (s.capitalize(), subset_file_name)
+ '\n Run \'python %s -o %s %s\' to overwrite.\n'
% (os.path.basename(__file__), model_path, dataset_path)
)
sys.exit()
================================================
FILE: setup.cfg
================================================
[flake8]
max-complexity = 12
ignore = E501,W503,E741
select = C,E,F,W
[isort]
multi_line_output = 3
include_trailing_comma = 1
line_length = 85
sections = FUTURE,STDLIB,TYPING,THIRDPARTY,FIRSTPARTY,LOCALFOLDER
known_typing = typing, typing_extensions
no_lines_before = TYPING
================================================
FILE: setup.py
================================================
import os
import re
from io import open
from setuptools import setup, find_packages
def get_property(property, package):
result = re.search(
r'{}\s*=\s*[\'"]([^\'"]*)[\'"]'.format(property),
open(package + '/__init__.py').read(),
)
return result.group(1)
from os import path
this_dir = path.abspath(path.dirname(__file__))
with open(path.join(this_dir, 'README.md'), encoding='utf8') as f:
long_description = f.read()
# Scripts
scripts = []
for dirname, dirnames, filenames in os.walk('scripts'):
for filename in filenames:
if filename.endswith('.py'):
scripts.append(os.path.join(dirname, filename))
setup(
name='sgdml',
version=get_property('__version__', 'sgdml'),
description='Reference implementation of the GDML and sGDML force field models.',
long_description=long_description,
long_description_content_type='text/markdown',
classifiers=[
'Development Status :: 4 - Beta',
'Environment :: Console',
'Intended Audience :: Science/Research',
'Intended Audience :: Education',
'Intended Audience :: Developers',
'License :: OSI Approved :: MIT License',
'Operating System :: MacOS :: MacOS X',
'Operating System :: POSIX :: Linux',
'Programming Language :: Python :: 3.7',
'Topic :: Scientific/Engineering :: Chemistry',
'Topic :: Scientific/Engineering :: Physics',
'Topic :: Software Development :: Libraries :: Python Modules',
],
url='http://www.sgdml.org',
author='Stefan Chmiela',
author_email='sgdml@chmiela.com',
license='MIT',
packages=find_packages(),
install_requires=['torch >= 1.8', 'numpy >= 1.19.0', 'scipy >= 1.1.0', 'psutil', 'future'],
entry_points={
'console_scripts': ['sgdml=sgdml.cli:main', 'sgdml-get=sgdml.get:main']
},
extras_require={'ase': ['ase >= 3.16.2']},
scripts=scripts,
include_package_data=True,
zip_safe=False,
)
================================================
FILE: sgdml/__init__.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2019-2025 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
__version__ = '1.0.3'
MAX_PRINT_WIDTH = 100
LOG_LEVELNAME_WIDTH = 7 # do not modify
# more descriptive callback status
DONE = 1
NOT_DONE = 0
# Logging
import copy
import logging
import re
import textwrap
from .utils import ui
class ColoredFormatter(logging.Formatter):
LEVEL_COLORS = {
'DEBUG': (ui.CYAN, ui.BLACK),
'INFO': (ui.WHITE, ui.BLACK),
'DONE': (ui.GREEN, ui.BLACK),
'WARNING': (ui.YELLOW, ui.BLACK),
'ERROR': (ui.RED, ui.BLACK),
'CRITICAL': (ui.BLACK, ui.RED),
}
LEVEL_NAMES = {
'DEBUG': '[DEBG]',
'INFO': '[INFO]',
'DONE': '[DONE]',
'WARNING': '[WARN]',
'ERROR': '[FAIL]',
'CRITICAL': '[CRIT]',
}
def __init__(self, msg, use_color=True):
logging.Formatter.__init__(self, msg)
self.use_color = use_color
def format(self, record):
_record = copy.copy(record)
levelname = _record.levelname
msg = _record.msg
levelname = ui.color_str(
self.LEVEL_NAMES[levelname],
self.LEVEL_COLORS[levelname][0],
self.LEVEL_COLORS[levelname][1],
bold=True,
)
if _record.levelname != 'CRITICAL':
# wrap long messages (except for critical [i.e. exceptions, since they print a formatted traceback string])
msg = ui.wrap_str(msg)
# indent multiline strings after the first line
msg = ui.indent_str(msg, LOG_LEVELNAME_WIDTH)[LOG_LEVELNAME_WIDTH:]
_record.levelname = levelname
_record.msg = msg
return logging.Formatter.format(self, _record)
class ColoredLogger(logging.Logger):
def __init__(self, name):
logging.Logger.__init__(self, name, logging.DEBUG)
# add 'DONE' logging level
logging.DONE = logging.INFO + 1
logging.addLevelName(logging.DONE, 'DONE')
# only display levelname and message
formatter = ColoredFormatter('%(levelname)s %(message)s')
# this handler will write to sys.stderr by default
hd = logging.StreamHandler()
hd.setFormatter(formatter)
hd.setLevel(
logging.INFO
) # control logging level here
self.addHandler(hd)
return
def done(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.DONE):
self._log(logging.DONE, msg, args, **kwargs)
logging.setLoggerClass(ColoredLogger)
================================================
FILE: sgdml/cli.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2018-2022 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function
import logging
import multiprocessing as mp
import argparse
import os
import shutil
import psutil
import sys
import traceback
import time
from functools import partial
import numpy as np
import scipy as sp
try:
import torch
except ImportError:
_has_torch = False
else:
_has_torch = True
try:
_torch_mps_is_available = torch.backends.mps.is_available()
except AttributeError:
_torch_mps_is_available = False
_torch_mps_is_available = False
try:
_torch_cuda_is_available = torch.cuda.is_available()
except AttributeError:
_torch_cuda_is_available = False
try:
import ase
except ImportError:
_has_ase = False
else:
_has_ase = True
from . import __version__, DONE, NOT_DONE, MAX_PRINT_WIDTH
from .predict import GDMLPredict
from .train import GDMLTrain
from .utils import io, ui
# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
PACKAGE_NAME = 'sgdml'
log = logging.getLogger(__name__)
class AssistantError(Exception):
pass
def _print_splash(max_memory, max_processes, use_torch):
logo_str = r""" __________ __ _____
_____/ ____/ __ \/ |/ / /
/ ___/ / __/ / / / /|_/ / /
(__ ) /_/ / /_/ / / / / /___
/____/\____/_____/_/ /_/_____/"""
can_update, latest_version = _check_update()
version_str = __version__
version_str += (
' '
+ ui.color_str(
' Latest: ' + latest_version + ' ',
fore_color=ui.BLACK,
back_color=ui.YELLOW,
bold=True,
)
if can_update
else ''
)
max_memory_str = '{:d} GB(s) memory'.format(max_memory)
max_processes_str = '{:d} CPU(s)'.format(max_processes)
hardware_str = 'using {}, {}'.format(max_memory_str, max_processes_str)
if use_torch and _has_torch:
if _torch_cuda_is_available:
num_gpu = torch.cuda.device_count()
if num_gpu > 0:
hardware_str += ', {:d} GPU(s)'.format(num_gpu)
elif _torch_mps_is_available:
hardware_str += ', MPS enabled'
logo_str_split = logo_str.splitlines()
print('\n'.join(logo_str_split[:-1]))
ui.print_two_column_str(logo_str_split[-1] + ' ' + version_str, hardware_str)
# Print update notice.
if can_update:
print(
'\n'
+ ui.color_str(
' UPDATE AVAILABLE ',
fore_color=ui.BLACK,
back_color=ui.YELLOW,
bold=True,
)
+ '\n'
+ '-' * MAX_PRINT_WIDTH
)
print(
'A new stable release version {} of this software is available.'.format(
latest_version
)
)
print(
'You can update your installation by running \'pip install sgdml --upgrade\'.'
)
_print_billboard()
def _check_update():
try:
from urllib.request import urlopen
except ImportError:
from urllib2 import urlopen
base_url = 'http://api.sgdml.org/'
url = '{}update.php?v={}'.format(base_url, __version__)
can_update, must_update = '0', '0'
latest_version = ''
try:
response = urlopen(url, timeout=1)
can_update, must_update, latest_version = response.read().decode().split(',')
response.close()
except Exception: # The update check is best-effort: ignore network and parsing errors.
pass
return can_update == '1', latest_version
def _print_billboard():
try:
from urllib.request import urlopen
except ImportError:
from urllib2 import urlopen
base_url = 'http://api.sgdml.org/'
url = '{}billboard.php'.format(base_url)
resp_str = ''
try:
response = urlopen(url, timeout=1)
resp_str = response.read().decode()
response.close()
except Exception: # The billboard is optional: ignore network errors.
pass
bbs = None
try:
import json
bbs = json.loads(resp_str)
except Exception: # Ignore empty or malformed billboard responses.
pass
if bbs is not None:
for bb in bbs:
back_color = ui.WHITE
if bb['color'] == 'YELLOW':
back_color = ui.YELLOW
elif bb['color'] == 'GREEN':
back_color = ui.GREEN
elif bb['color'] == 'RED':
back_color = ui.RED
elif bb['color'] == 'CYAN':
back_color = ui.CYAN
print(
'\n'
+ ui.color_str(
' {} '.format(bb['title']),
fore_color=ui.BLACK,
back_color=back_color,
bold=True,
)
+ '\n'
+ '-' * MAX_PRINT_WIDTH
)
print(ui.wrap_str(bb['text'], width=MAX_PRINT_WIDTH - 2))
def _print_dataset_properties(dataset, title_str='Dataset properties'):
print(ui.color_str(title_str, bold=True))
n_mols, n_atoms, _ = dataset['R'].shape
print(' {:<18} \'{}\''.format('Name:', ui.unicode_str(dataset['name'])))
print(' {:<18} \'{}\''.format('Theory level:', ui.unicode_str(dataset['theory'])))
print(' {:<18} {:<d}'.format('Atoms:', n_atoms))
print(' {:<18} {:,} data points'.format('Size:', n_mols))
ui.print_lattice(dataset['lattice'] if 'lattice' in dataset else None)
if 'perms' in dataset:
ui.print_two_column_str(
' {:<18} {}'.format('Symmetries:', len(dataset['perms'])),
'This dataset contains precomputed permutations.',
)
if 'E' in dataset:
e_unit = 'unknown unit'
if 'e_unit' in dataset:
e_unit = ui.unicode_str(dataset['e_unit'])
print(' Energies [{}]'.format(e_unit))
if 'E_min' in dataset and 'E_max' in dataset:
E_min, E_max = dataset['E_min'], dataset['E_max']
else:
E_min, E_max = np.min(dataset['E']), np.max(dataset['E'])
E_range_str = ui.gen_range_str(E_min, E_max)
ui.print_two_column_str(
' {:<16} {}'.format('Range:', E_range_str), 'min |-- range --| max'
)
E_mean = dataset['E_mean'] if 'E_mean' in dataset else np.mean(dataset['E'])
print(' {:<16} {:<.3f}'.format('Mean:', E_mean))
E_var = dataset['E_var'] if 'E_var' in dataset else np.var(dataset['E'])
print(' {:<16} {:<.3f}'.format('Variance:', E_var))
else:
print(' {:<18} {}'.format('Energies:', 'n/a'))
f_unit = 'unknown unit'
if 'r_unit' in dataset and 'e_unit' in dataset:
f_unit = (
ui.unicode_str(dataset['e_unit']) + '/' + ui.unicode_str(dataset['r_unit'])
)
print(' Forces [{}]'.format(f_unit))
if 'F_min' in dataset and 'F_max' in dataset:
F_min, F_max = dataset['F_min'], dataset['F_max']
else:
F_min, F_max = np.min(dataset['F'].ravel()), np.max(dataset['F'].ravel())
F_range_str = ui.gen_range_str(F_min, F_max)
ui.print_two_column_str(
' {:<16} {}'.format('Range:', F_range_str), 'min |-- range --| max'
)
F_mean = dataset['F_mean'] if 'F_mean' in dataset else np.mean(dataset['F'].ravel())
print(' {:<16} {:<.3f}'.format('Mean:', F_mean))
F_var = dataset['F_var'] if 'F_var' in dataset else np.var(dataset['F'].ravel())
print(' {:<16} {:<.3f}'.format('Variance:', F_var))
print(' {:<18} {}'.format('Fingerprint:', ui.unicode_str(dataset['md5'])))
# if 'code_version' in dataset:
# print(' {:<18} sGDML {}'.format('Created with:', ui.unicode_str(dataset['code_version'])))
idx = np.random.choice(n_mols, 1)[0]
r = dataset['R'][idx, :, :]
e = np.squeeze(dataset['E'][idx]) if 'E' in dataset else None
f = dataset['F'][idx, :, :]
lattice = dataset['lattice'] if 'lattice' in dataset else None
print(
'\n'
+ ui.color_str('Example geometry', fore_color=ui.WHITE, bold=True)
+ ' (point no. {:,}, chosen randomly)'.format(idx + 1)
)
xyz_info_str = 'Copy & paste the string below into Jmol (www.jmol.org), Avogadro (www.avogadro.cc), etc. to visualize one of the geometries from this dataset. A new example will be drawn on each run.'
xyz_info_str = ui.wrap_str(xyz_info_str, width=MAX_PRINT_WIDTH - 2)
xyz_info_str = ui.indent_str(xyz_info_str, 2)
print(xyz_info_str + '\n')
xyz_str = io.generate_xyz_str(r, dataset['z'], e=e, f=f, lattice=lattice)
xyz_str = ui.indent_str(xyz_str, 2)
cut_str = '---- COPY HERE '
cut_str_reps = int(np.floor((MAX_PRINT_WIDTH - 6) / len(cut_str)))
cutline_str = ui.color_str(
' -' + cut_str * cut_str_reps + '-----', fore_color=ui.GRAY
)
print(cutline_str)
print(xyz_str)
print(cutline_str)
def _print_task_properties_reduced(
use_sym, use_E, use_E_cstr, title_str='Task properties'
):
print(ui.color_str(title_str, bold=True))
energy_fix_str = (
(
'pointwise energy constraints'
if use_E_cstr
else 'global integration constant'
)
if use_E
else 'none'
)
print(' {:<16} {}'.format('Energy offset:', energy_fix_str))
print(
' {:<16} {}'.format(
'Symmetries:', 'include (sGDML)' if use_sym else 'ignore (GDML)'
)
)
def _print_task_properties(task, title_str='Task properties'):
print(ui.color_str(title_str, bold=True))
print(' {:<18}'.format('Dataset'))
print(' {:<16} \'{}\''.format('Name:', ui.unicode_str(task['dataset_name'])))
print(
' {:<16} \'{}\''.format(
'Theory level:', ui.unicode_str(task['dataset_theory'])
)
)
n_atoms = len(task['z'])
print(' {:<16} {:<d}'.format('Atoms:', n_atoms))
ui.print_lattice(task['lattice'] if 'lattice' in task else None, inset=True)
print(' {:<18} {:<d}'.format('Symmetries:', len(task['perms'])))
print(' {:<18}'.format('Hyper-parameters'))
print(' {:<16} {:<d}'.format('Length scale:', task['sig']))
if 'lam' in task:
print(' {:<16} {:<.0e}'.format('Regularization:', task['lam']))
# if 'solver_name' in task:
# print(' {:<18}'.format('Solver configuration'))
# print(' {:<16} \'{}\''.format('Type:', task['solver_name']))
# if task['solver_name'] == 'cg':
# if 'solver_tol' in task:
# print(' {:<16} {:<.0e}'.format('Tolerance:', task['solver_tol']))
# if 'n_inducing_pts_init' in task:
# print(
# ' {:<16} {:<d}'.format(
# 'Inducing points:', task['n_inducing_pts_init']
# )
# )
# else:
# print(' {:<18} {}'.format('Solver:', 'unknown'))
n_train = len(task['idxs_train'])
ui.print_two_column_str(
' {:<18} {:,} points'.format('Train on:', n_train),
'from \'' + ui.unicode_str(task['md5_train']) + '\'',
)
n_valid = len(task['idxs_valid'])
ui.print_two_column_str(
' {:<18} {:,} points'.format('Validate on:', n_valid),
'from \'' + ui.unicode_str(task['md5_valid']) + '\'',
)
# print(' {:<18}'.format('Estimated memory requirement (min.)'))
# mem_kernel_mat_const = 0
# mem_precond_const = 0
# print(
# ' {:<16} {}'.format(
# 'CPU:', ui.gen_memory_str(mem_kernel_mat_const + mem_precond_const)
# )
# )
# print(' {:<14} {}'.format('Kernel matrix:', ui.gen_memory_str(mem_kernel_mat_const))
# print(' {:<14} {}'.format('Precond. factor:', ui.gen_memory_str(mem_precond_const)))
# mem_torch_assemble = 0
# mem_torch_eval = 0
# print(
# ' {:<16} {}'.format(
# 'GPU:', ui.gen_memory_str(mem_torch_assemble + mem_torch_eval)
# )
# )
# print(' {:<14} {}'.format('Kernel matrix assembly:', ui.gen_memory_str(mem_torch_assemble)))
# print(' {:<14} {}'.format('Model evaluation:', ui.gen_memory_str(mem_torch_eval)))
def _print_model_properties(model, title_str='Model properties'):
print(ui.color_str(title_str, bold=True))
print(' {:<18}'.format('Dataset'))
print(' {:<16} \'{}\''.format('Name:', ui.unicode_str(model['dataset_name'])))
print(
' {:<16} \'{}\''.format(
'Theory level:', ui.unicode_str(model['dataset_theory'])
)
)
n_atoms = len(model['z'])
print(' {:<16} {:<d}'.format('Atoms:', n_atoms))
ui.print_lattice(model['lattice'] if 'lattice' in model else None, inset=True)
print(' {:<18} {:<d}'.format('Symmetries:', len(model['perms'])))
print(' {:<18}'.format('Hyper-parameters'))
print(' {:<16} {:<d}'.format('Length scale:', model['sig']))
if 'lam' in model:
print(' {:<16} {:<.0e}'.format('Regularization:', model['lam']))
if 'solver_name' in model:
print(' {:<18}'.format('Solver'))
print(' {:<16} \'{}\''.format('Type:', model['solver_name']))
if model['solver_name'] == 'cg':
if 'solver_tol' in model:
ui.print_two_column_str(
' {:<16} {:<.0e}'.format('Tolerance:', model['solver_tol']),
'iterate until: norm(K*alpha - y) <= tol*norm(y) = {:<.0e}'.format(
model['solver_tol'] * model['norm_y_train']
),
)
if 'solver_resid' in model:
is_conv = (
model['solver_resid']
<= model['solver_tol'] * model['norm_y_train']
)
print(
' {:<16} {:<.0e}{}'.format(
'Converged to:',
model['solver_resid'],
'' if is_conv else ' (NOT CONVERGED)',
)
)
if 'solver_iters' in model:
print(' {:<16} {:<d}'.format('Iterations:', model['solver_iters']))
if 'inducing_pts_idxs' in model:
n_inducing_pts = len(model['inducing_pts_idxs']) // (3 * n_atoms)
ui.print_two_column_str(
' {:<16} {:<d}'.format('Inducing points:', n_inducing_pts),
'inducing columns: {:<d} (multiplied by DOF)'.format(
n_inducing_pts * n_atoms * 3
),
)
else:
print(' {:<18} {}'.format('Solver:', 'unknown'))
n_train = len(model['idxs_train'])
ui.print_two_column_str(
' {:<18} {:,} points'.format('Trained on:', n_train),
'from \'' + ui.unicode_str(model['md5_train']) + '\'',
)
use_E_cstr = 'alphas_E' in model
print(
' {:<16} {}'.format(
'Energy offset',
'[{}] global integration constant'.format('x' if not use_E_cstr else ' '),
)
)
ui.print_two_column_str(
' {:<16}'.format(
'[{}] pointwise energy constraints'.format('x' if use_E_cstr else ' ')
),
'using \'--E_cstr\'',
)
if model['use_E']:
e_err = model['e_err'].item()
f_err = model['f_err'].item()
n_valid = len(model['idxs_valid'])
is_valid = not np.isnan(f_err['mae']) and not np.isnan(f_err['rmse'])
ui.print_two_column_str(
' {:<18} {}{:,} points'.format(
'Validated on:', '' if is_valid else '[pending] ', n_valid
),
'from \'' + ui.unicode_str(model['md5_valid']) + '\'',
)
n_test = int(model['n_test'])
is_test = n_test > 0
if is_test:
ui.print_two_column_str(
' {:<18} {:,} points'.format('Tested on:', n_test),
'from \'' + ui.unicode_str(model['md5_test']) + '\'',
)
else:
print(' {:<18} {}'.format('Test:', '[pending]'))
e_unit = 'unknown unit'
f_unit = 'unknown unit'
if 'r_unit' in model and 'e_unit' in model:
e_unit = ui.unicode_str(model['e_unit'])
f_unit = ui.unicode_str(model['e_unit']) + '/' + ui.unicode_str(model['r_unit'])
if is_valid:
action_str = 'Validation' if not is_test else 'Expected test'
print(' {:<18}'.format('{} errors (MAE/RMSE)'.format(action_str)))
if model['use_E']:
print(
' {:<16} {:>.4f}/{:>.4f} [{}]'.format(
'Energy:', e_err['mae'], e_err['rmse'], e_unit
)
)
print(
' {:<16} {:>.4f}/{:>.4f} [{}]'.format(
'Forces:', f_err['mae'], f_err['rmse'], f_unit
)
)
def _print_next_step(
prev_step, task_dir=None, model_dir=None, model_files=None, dataset_path=None
):
if prev_step == 'create':
assert task_dir is not None
ui.print_step_title(
'NEXT STEP',
'{} train {} <valid_dataset_file>'.format(PACKAGE_NAME, task_dir),
underscore=False,
)
elif prev_step == 'train' or prev_step == 'validate' or prev_step == 'resume':
assert model_dir is not None and model_files is not None
if dataset_path is None:
dataset_path = '<test_dataset_file>'
n_models = len(model_files)
if n_models == 1:
model_file_path = os.path.join(model_dir, model_files[0])
ui.print_step_title(
'NEXT STEP',
'{} test {} {} [<n_test>]'.format(
PACKAGE_NAME, model_file_path, dataset_path
),
underscore=False,
)
else:
ui.print_step_title(
'NEXT STEP',
'{} select {}'.format(PACKAGE_NAME, model_dir),
underscore=False,
)
elif prev_step == 'select':
assert model_files is not None
ui.print_step_title(
'NEXT STEP',
'{} test {} <test_dataset_file> [<n_test>]'.format(
PACKAGE_NAME, model_files[0]
),
underscore=False,
)
else:
raise AssistantError('Unexpected previous step string.')
def all(
dataset,
valid_dataset,
test_dataset,
n_train,
n_valid,
n_test,
sigs,
gdml,
use_E,
use_E_cstr,
lazy_training,
overwrite,
max_memory,
max_processes,
use_torch,
task_dir=None,
model_file=None,
perms_from_arg=None,
**kwargs
):
print(
'\n'
+ ui.color_str(' STEP 0 ', fore_color=ui.BLACK, back_color=ui.WHITE, bold=True)
+ ' Dataset(s)\n'
+ '-' * MAX_PRINT_WIDTH
)
_, dataset_extracted = dataset
_print_dataset_properties(dataset_extracted, title_str='Properties')
if valid_dataset is None:
valid_dataset = dataset
else:
_, valid_dataset_extracted = valid_dataset
print()
_print_dataset_properties(
valid_dataset_extracted, title_str='Properties (validation dataset)'
)
if not np.array_equal(dataset_extracted['z'], valid_dataset_extracted['z']):
raise AssistantError(
'Atom composition or order in validation dataset does not match the one in bulk dataset.'
)
if test_dataset is None:
test_dataset = dataset
else:
_, test_dataset_extracted = test_dataset
_print_dataset_properties(
test_dataset_extracted, title_str='Properties (test dataset)'
)
if not np.array_equal(dataset_extracted['z'], test_dataset_extracted['z']):
raise AssistantError(
'Atom composition or order in test dataset does not match the one in bulk dataset.'
)
ui.print_step_title('STEP 1', 'Cross-validation task creation')
task_dir = create(
dataset,
valid_dataset,
n_train,
n_valid,
sigs,
gdml,
use_E,
use_E_cstr,
overwrite,
task_dir,
perms_from_arg=perms_from_arg,
**kwargs
)
ui.print_step_title('STEP 2', 'Training and validation')
task_dir_arg = io.is_dir_with_file_type(task_dir, 'task')
model_dir_or_file_path = train(
task_dir_arg,
valid_dataset,
lazy_training,
overwrite,
max_memory,
max_processes,
use_torch,
**kwargs
)
model_dir_arg = io.is_dir_with_file_type(
model_dir_or_file_path, 'model', or_file=True
)
_, model_file_names = model_dir_arg
if len(model_file_names) == 0:
raise AssistantError(
'No trained models found!'
+ ('\nTry turning off \'--lazy\'-mode.' if lazy_training else '')
)
ui.print_step_title('STEP 3', 'Hyper-parameter selection')
model_file_name = select(model_dir_arg, overwrite, model_file, **kwargs)
# Have all tasks been trained?
_, task_file_names = task_dir_arg
if len(task_file_names) > len(model_file_names):
log.warning(
'Not all training tasks have been completed! The model selected here might not be optimal.'
+ ('\nTry turning off \'--lazy\'-mode.' if lazy_training else '')
)
ui.print_step_title('STEP 4', 'Testing')
model_dir_arg = io.is_dir_with_file_type(model_file_name, 'model', or_file=True)
test(
model_dir_arg,
test_dataset,
n_test,
overwrite=False,
max_memory=max_memory,
max_processes=max_processes,
use_torch=use_torch,
**kwargs
)
print(
'\n'
+ ui.color_str(' DONE ', fore_color=ui.BLACK, back_color=ui.GREEN, bold=True)
+ ' Training assistant finished successfully.'
)
print(' This is your model file: \'{}\''.format(model_file_name))
# if training job exists and is a subset of the requested cv range, add new tasks
# otherwise, if new range is different or smaller, fail
def create( # noqa: C901
dataset,
valid_dataset,
n_train,
n_valid,
sigs,
gdml,
use_E,
use_E_cstr,
overwrite,
task_dir=None,
perms_from_arg=None,
command=None,
**kwargs
):
has_valid_dataset = not (valid_dataset is None or valid_dataset == dataset)
dataset_path, dataset = dataset
n_data = dataset['F'].shape[0]
func_called_directly = (
command == 'create'
) # has this function been called from command line or from 'all'?
if func_called_directly:
ui.print_step_title('TASK CREATION')
_print_dataset_properties(dataset)
print()
_print_task_properties_reduced(use_sym=not gdml, use_E=use_E, use_E_cstr=use_E_cstr)
print()
if n_data < n_train:
raise AssistantError(
'Dataset only contains {} points, can not train on {}.'.format(
n_data, n_train
)
)
if not has_valid_dataset:
valid_dataset_path, valid_dataset = dataset_path, dataset
if n_data - n_train < n_valid:
raise AssistantError(
'Dataset only contains {} points, can not train on {} and validate on {}.'.format(
n_data, n_train, n_valid
)
)
else:
valid_dataset_path, valid_dataset = valid_dataset
n_valid_data = valid_dataset['R'].shape[0]
if n_valid_data < n_valid:
raise AssistantError(
'Validation dataset only contains {} points, can not validate on {}.'.format(
n_valid_data, n_valid
)
)
if sigs is None:
log.info(
'Kernel hyper-parameter sigma (length scale) was automatically set to range \'10:10:100\'.'
)
sigs = list(range(10, 100, 10)) # default range
if task_dir is None:
task_dir = io.train_dir_name(
dataset,
n_train,
use_sym=not gdml,
use_E=use_E,
use_E_cstr=use_E_cstr,
)
task_file_names = []
if os.path.exists(task_dir):
if overwrite:
log.info('Overwriting existing training directory')
shutil.rmtree(task_dir, ignore_errors=True)
os.makedirs(task_dir)
else:
if io.is_task_dir_resumeable(
task_dir, dataset, valid_dataset, n_train, n_valid, sigs, gdml
):
log.info(
'Resuming existing hyper-parameter search in \'{}\'.'.format(
task_dir
)
)
# Get all task file names.
try:
_, task_file_names = io.is_dir_with_file_type(task_dir, 'task')
except Exception:
pass
else:
raise AssistantError(
'Unfinished hyper-parameter search found in \'{}\'.\n'.format(
task_dir
)
+ 'Run \'%s %s -o %s %d %d -s %s\' to overwrite.'
% (
PACKAGE_NAME,
command,
dataset_path,
n_train,
n_valid,
' '.join(str(s) for s in sigs),
)
)
else:
os.makedirs(task_dir)
if task_file_names:
with np.load(
os.path.join(task_dir, task_file_names[0]), allow_pickle=True
) as task:
tmpl_task = dict(task)
else:
if not use_E:
log.info(
'Energy labels will be ignored for training.\n'
+ 'Note: If available in the dataset file, the energy labels will however still be used to generate stratified training, test and validation datasets. Otherwise a random sampling is used.'
)
if 'E' not in dataset:
log.warning(
'Training dataset will be sampled with no guidance from energy labels (i.e. randomly)!'
)
if 'E' not in valid_dataset:
log.warning(
'Validation dataset will be sampled with no guidance from energy labels (i.e. randomly)!\n'
+ 'Note: Larger validation datasets are recommended due to slower convergence of the error.'
)
if ('lattice' in dataset) ^ ('lattice' in valid_dataset):
log.error('One of the datasets specifies lattice vectors and one does not!')
# TODO: stop program?
if 'lattice' in dataset or 'lattice' in valid_dataset:
log.info(
'Lattice vectors found in dataset: applying periodic boundary conditions.'
)
perms = None
if perms_from_arg is not None:
_, perms_from = perms_from_arg
if 'perms' in perms_from:
perms = perms_from['perms']
else:
raise AssistantError(
'Provided permutation file does not contain any (looking for \'perms\'-key).'
)
gdml_train = (
GDMLTrain()
) # No process number or memory restrictions necessary here.
try:
tmpl_task = gdml_train.create_task(
dataset,
n_train,
valid_dataset,
n_valid,
sig=1,
perms=perms,
use_sym=not gdml,
use_E=use_E,
use_E_cstr=use_E_cstr,
callback=ui.callback,
) # template task
except:
print()
log.critical(traceback.format_exc())
print()
os._exit(1)
n_written = 0
for sig in sigs:
tmpl_task['sig'] = sig
task_file_name = io.task_file_name(tmpl_task)
task_path = os.path.join(task_dir, task_file_name)
if os.path.isfile(task_path):
log.info('Skipping existing task \'{}\'.'.format(task_file_name))
else:
np.savez_compressed(task_path, **tmpl_task)
n_written += 1
if n_written > 0:
log.done(
'Writing {:d}/{:d} task(s) with m={} training points each'.format(
n_written, len(sigs), tmpl_task['R_train'].shape[0]
)
)
if func_called_directly:
_print_next_step('create', task_dir=task_dir)
return task_dir
def train(
task_dir,
valid_dataset,
lazy_training,
overwrite,
max_memory,
max_processes,
use_torch,
command=None,
**kwargs
):
task_dir, task_file_names = task_dir
n_tasks = len(task_file_names)
func_called_directly = (
command == 'train'
) # Has this function been called from command line or from 'all'?
if func_called_directly:
ui.print_step_title('MODEL TRAINING')
def save_progr_callback(
unconv_model, unconv_model_path=None
): # Saves current (unconverged) model during iterative training
if unconv_model_path is None:
log.critical(
'Path for unconverged model not set in \'save_progr_callback\'.'
)
print()
os._exit(1)
np.savez_compressed(unconv_model_path, **unconv_model)
try:
gdml_train = GDMLTrain(
max_memory=max_memory, max_processes=max_processes, use_torch=use_torch
)
except:
print()
log.critical(traceback.format_exc())
print()
os._exit(1)
prev_valid_err = -1
has_converged_once = False
for i, task_file_name in enumerate(task_file_names):
task_file_path = os.path.join(task_dir, task_file_name)
with np.load(task_file_path, allow_pickle=True) as task:
if n_tasks > 1:
if i > 0:
print()
n_train = len(task['idxs_train'])
n_valid = len(task['idxs_valid'])
ui.print_two_column_str(
ui.color_str('Task {:d} of {:d}'.format(i + 1, n_tasks), bold=True),
'{:,} + {:,} points (training + validation), sigma (length scale): {}'.format(
n_train, n_valid, task['sig']
),
)
model_file_name = io.model_file_name(task, is_extended=False)
model_file_path = os.path.join(task_dir, model_file_name)
# is_conv = True
# valid_errs = None
# is_model_validated = False
if not overwrite and os.path.isfile(
model_file_path
): # Trained model found, validate if necessary
log.info(
'Model \'{}\' already exists.'.format(model_file_name)
+ (
'\nRun \'{} train -o {}\' to overwrite.'.format(
PACKAGE_NAME, task_file_path
)
if func_called_directly
else ''
)
)
model_path = os.path.join(task_dir, model_file_name)
_, model = io.is_file_type(model_path, 'model')
e_err = {'mae': 0.0, 'rmse': 0.0}
if model['use_E']:
e_err = model['e_err'].item()
f_err = model['f_err'].item()
is_conv = True
if 'solver_resid' in model:
is_conv = (
model['solver_resid']
<= model['solver_tol'] * model['norm_y_train']
)
is_model_validated = not (
np.isnan(f_err['mae']) or np.isnan(f_err['rmse'])
)
if is_model_validated:
disp_str = (
'energy %.3f/%.3f, ' % (e_err['mae'], e_err['rmse'])
if model['use_E']
else ''
)
disp_str += 'forces %.3f/%.3f' % (f_err['mae'], f_err['rmse'])
disp_str = 'Validation errors (MAE/RMSE): ' + disp_str
ui.callback(1, 1, disp_str=disp_str)
valid_errs = [f_err['rmse']]
else: # Train and validate model
# Check if training this task has been attempted before.
if lazy_training and n_tasks > 1:
if 'tried_training' in task and task['tried_training']:
log.warning(
'Skipping task, because it has been tried before (without success).'
)
continue
# Record in task file that there was a training attempt.
task = dict(task)
task['tried_training'] = True
np.savez_compressed(task_file_path, **task)
n_train, n_atoms = task['R_train'].shape[:2]
unconv_model_file = '_unconv_{}'.format(model_file_name)
unconv_model_path = os.path.join(task_dir, unconv_model_file)
try:
model = gdml_train.train(
task,
partial(
save_progr_callback, unconv_model_path=unconv_model_path
),
ui.callback,
)
except:
print()
log.critical(traceback.format_exc())
print()
os._exit(1)
else:
if func_called_directly:
log.done('Writing model to file \'{}\''.format(model_file_path))
np.savez_compressed(model_file_path, **model)
# Delete temporary model, if one exists.
unconv_model_exists = os.path.isfile(unconv_model_path)
if unconv_model_exists:
os.remove(unconv_model_path)
is_model_validated = False
if not is_model_validated:
if (
n_tasks == 1
): # Only validate if there is more than one training task.
log.info(
'Skipping validation step as there is only one model to validate.'
)
break
# Validate model.
model_dir = (task_dir, [model_file_name])
valid_errs = test(
model_dir,
valid_dataset,
-1, # n_test = -1 -> validation mode
overwrite,
max_memory,
max_processes,
use_torch,
command,
**kwargs
)
is_conv = True
if 'solver_resid' in model:
is_conv = (
model['solver_resid']
<= model['solver_tol'] * model['norm_y_train']
)
has_converged_once = has_converged_once or is_conv
if (
has_converged_once
and prev_valid_err != -1
and prev_valid_err < valid_errs[0]
):
print()
log.info(
'Skipping remaining training tasks, as validation error is rising again.'
)
break
prev_valid_err = valid_errs[0]
model_dir_or_file_path = model_file_path if n_tasks == 1 else task_dir
if func_called_directly:
model_dir_arg = io.is_dir_with_file_type(
model_dir_or_file_path, 'model', or_file=True
)
model_dir, model_files = model_dir_arg
_print_next_step('train', model_dir=model_dir, model_files=model_files)
return model_dir_or_file_path # model directory or file
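# Illustrative sketch (not part of sGDML): the loop in `train` stops working
# through the remaining tasks once at least one model has converged and the
# validation error rises compared to the previous task. The helper name
# `sweep_until_error_rises` and its list-based inputs are assumptions made for
# this example; the real function reads these values from model files.

```python
def sweep_until_error_rises(valid_errs, converged):
    """Return how many tasks are processed before the sweep stops early.

    valid_errs: validation force RMSE per task, in training order.
    converged:  per-task flags, True if the solver residual met its tolerance.
    """
    prev_err = -1  # sentinel: no previous validation error yet
    has_converged_once = False
    n_processed = 0
    for err, is_conv in zip(valid_errs, converged):
        n_processed += 1
        has_converged_once = has_converged_once or is_conv
        if has_converged_once and prev_err != -1 and prev_err < err:
            break  # validation error is rising again
        prev_err = err
    return n_processed
```

# With errors [3.0, 2.0, 2.5, 1.0] and all tasks converged, the sweep stops at
# the third task, so the fourth one is never trained.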
def _batch(iterable, n=1):
l = len(iterable)
for ndx in range(0, l, n):
yield iterable[ndx : min(ndx + n, l)]
def _online_err(err, size, n, mae_n_sum, rmse_n_sum):
err = np.abs(err)
mae_n_sum += np.sum(err) / size
mae = mae_n_sum / n
rmse_n_sum += np.sum(err**2) / size
rmse = np.sqrt(rmse_n_sum / n)
return mae, mae_n_sum, rmse, rmse_n_sum
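# Illustrative sketch (not part of sGDML): `_online_err` keeps running sums so
# that MAE and RMSE can be reported after every batch without storing all
# errors. The names `online_err_sketch` and `running_errors` are hypothetical;
# the arithmetic mirrors the function above.

```python
import numpy as np


def online_err_sketch(err, size, n, mae_n_sum, rmse_n_sum):
    # err: signed errors of the current batch, `size` values per sample.
    # n: total number of samples seen so far, including this batch.
    err = np.abs(err)
    mae_n_sum += np.sum(err) / size
    rmse_n_sum += np.sum(err**2) / size
    return mae_n_sum / n, mae_n_sum, np.sqrt(rmse_n_sum / n), rmse_n_sum


def running_errors(errs, batch_size):
    # errs: array of shape (n_samples, size) holding signed errors.
    size = errs.shape[1]
    mae_sum, rmse_sum, n_done = 0.0, 0.0, 0
    mae = rmse = float('nan')
    for start in range(0, len(errs), batch_size):
        batch = errs[start : start + batch_size]
        n_done += len(batch)
        mae, mae_sum, rmse, rmse_sum = online_err_sketch(
            batch, size, n_done, mae_sum, rmse_sum
        )
    return mae, rmse
```

# After the last batch, the running values agree with the MAE and RMSE
# computed over the full error array in one pass.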
def resume(
model,
dataset,
valid_dataset,
overwrite,
max_memory,
max_processes,
use_torch,
command=None,
**kwargs
):
model_path, model = model
dataset_path, dataset = dataset
valid_dataset_arg = valid_dataset
valid_dataset_path, valid_dataset = valid_dataset
ui.print_step_title('RESUME TRAINING')
_print_model_properties(model, title_str='Model properties (initial)')
print()
if dataset['md5'] != model['md5_train']:
raise AssistantError(
'Fingerprint of provided training dataset does not match the one specified in model file.'
)
if valid_dataset['md5'] != model['md5_valid']:
raise AssistantError(
'Fingerprint of provided validation dataset does not match the one specified in model file.'
)
if model['solver_name'] == 'analytic':
raise AssistantError(
'This model was trained using a matrix decomposition method and thus already converged to the highest possible accuracy! It does not make sense to resume training in this case.'
)
elif 'solver_resid' in model and 'solver_tol' in model:
if model['solver_resid'] > model['solver_tol'] * model['norm_y_train']:
gdml_train = GDMLTrain(
max_memory=max_memory, max_processes=max_processes, use_torch=use_torch
)
try:
task = gdml_train.create_task_from_model(
model,
dataset,
)
except:
print()
log.critical(traceback.format_exc())
print()
os._exit(1)
del gdml_train
def save_progr_callback(
unconv_model,
): # saves current (unconverged) model during iterative training
np.savez_compressed(model_path, **unconv_model)
try:
gdml_train = GDMLTrain(
max_memory=max_memory,
max_processes=max_processes,
use_torch=use_torch,
)
except:
print()
log.critical(traceback.format_exc())
print()
os._exit(1)
try:
model = gdml_train.train(
task, save_progr_callback=save_progr_callback, callback=ui.callback
)
except:
print()
log.critical(traceback.format_exc())
print()
os._exit(1)
else:
log.done('Model parameters have been updated.')
np.savez_compressed(model_path, **model)
else:
log.warning('Model is already converged to the specified tolerance.')
# Validate model.
model_dir, model_file_name = os.path.split(model_path)
model_dir_arg = (model_dir, [model_file_name])
valid_errs = test(
model_dir_arg,
valid_dataset_arg,
-1, # n_test = -1 -> validation mode
overwrite,
max_memory,
max_processes,
use_torch,
command,
**kwargs
)
_print_next_step('resume', model_dir=model_dir, model_files=[model_file_name])
def validate(
model_dir,
valid_dataset,
overwrite,
max_memory,
max_processes,
use_torch,
command=None,
**kwargs
):
dataset_path_extracted, dataset_extracted = valid_dataset
func_called_directly = (
command == 'validate'
) # has this function been called from command line or from 'all'?
if func_called_directly:
ui.print_step_title('MODEL VALIDATION')
_print_dataset_properties(dataset_extracted)
test(
model_dir,
valid_dataset,
-1, # n_test = -1 -> validation mode
overwrite,
max_memory,
max_processes,
use_torch,
command,
**kwargs
)
if func_called_directly:
model_dir, model_files = model_dir
n_models = len(model_files)
_print_next_step('validate', model_dir=model_dir, model_files=model_files)
def test(
model_dir,
test_dataset,
n_test,
overwrite,
max_memory,
max_processes,
use_torch,
command=None,
**kwargs
): # noqa: C901
# NOTE: this function runs a validation if n_test < 0 and test with all points if n_test == 0
model_dir, model_file_names = model_dir
n_models = len(model_file_names)
n_test = 0 if n_test is None else n_test
is_validation = n_test < 0
is_test = n_test >= 0
dataset_path, dataset = test_dataset
func_called_directly = (
command == 'test'
) # has this function been called from command line or from 'all'?
if func_called_directly:
ui.print_step_title('MODEL TEST')
_print_dataset_properties(dataset)
F_rmse = []
# NEW
DEBUG_WRITE = False
if DEBUG_WRITE:
if os.path.exists('test_pred.xyz'):
os.remove('test_pred.xyz')
if os.path.exists('test_ref.xyz'):
os.remove('test_ref.xyz')
if os.path.exists('test_diff.xyz'):
os.remove('test_diff.xyz')
# NEW
num_workers, batch_size = -1, -1
gdml_train = None
for i, model_file_name in enumerate(model_file_names):
model_path = os.path.join(model_dir, model_file_name)
_, model = io.is_file_type(model_path, 'model')
if i == 0 and command != 'all':
print()
_print_model_properties(model)
print()
if not np.array_equal(model['z'], dataset['z']):
raise AssistantError(
'Atom composition or order in dataset does not match the one in model.'
)
if ('lattice' in model) is not ('lattice' in dataset):
if 'lattice' in model:
raise AssistantError(
'Model contains lattice vectors, but dataset does not.'
)
elif 'lattice' in dataset:
raise AssistantError(
'Dataset contains lattice vectors, but model does not.'
)
if model['use_E']:
e_err = model['e_err'].item()
f_err = model['f_err'].item()
is_model_validated = not (np.isnan(f_err['mae']) or np.isnan(f_err['rmse']))
if n_models > 1:
if i > 0:
print()
print(
ui.color_str(
'%s model %d of %d'
% ('Testing' if is_test else 'Validating', i + 1, n_models),
bold=True,
)
)
if is_validation:
if is_model_validated and not overwrite:
log.info(
'Skipping already validated model \'{}\'.'.format(model_file_name)
+ (
'\nRun \'{} validate -o {} {}\' to overwrite.'.format(
PACKAGE_NAME, model_path, dataset_path
)
if command == 'validate'
else ''
)
)
continue
if dataset['md5'] != model['md5_valid']:
raise AssistantError(
'Fingerprint of provided validation dataset does not match the one specified in model file.'
)
test_idxs = model['idxs_valid']
if is_test:
# exclude training and/or validation sets from test set if necessary
excl_idxs = np.empty((0,), dtype=np.uint)
if dataset['md5'] == model['md5_train']:
excl_idxs = np.concatenate([excl_idxs, model['idxs_train']]).astype(
np.uint
)
if dataset['md5'] == model['md5_valid']:
excl_idxs = np.concatenate([excl_idxs, model['idxs_valid']]).astype(
np.uint
)
n_data = dataset['F'].shape[0]
n_data_eff = n_data - len(excl_idxs)
if (
n_test == 0 and n_data_eff != 0
): # test on all data points that have not been used for training or validation
n_test = n_data_eff
log.info(
'Test set size was automatically set to {:,} points.'.format(n_test)
)
if n_test == 0 or n_data_eff == 0:
log.warning('Skipping! No unused points for test in provided dataset.')
return
elif n_data_eff < n_test:
n_test = n_data_eff
log.warning(
'Test size reduced to {:d}. Not enough unused points in provided dataset.'.format(
n_test
)
)
if 'E' in dataset:
if gdml_train is None:
gdml_train = GDMLTrain(
max_memory=max_memory, max_processes=max_processes
)
test_idxs = gdml_train.draw_strat_sample(
dataset['E'], n_test, excl_idxs=excl_idxs
)
else:
test_idxs = np.delete(np.arange(n_data), excl_idxs)
log.warning(
'Test dataset will be sampled with no guidance from energy labels (randomly)!\n'
+ 'Note: Larger test datasets are recommended due to slower convergence of the error.'
)
# shuffle to improve convergence of online error
np.random.shuffle(test_idxs)
# NEW
if DEBUG_WRITE:
test_idxs = np.sort(test_idxs)
z = dataset['z']
R = dataset['R'][test_idxs, :, :]
F = dataset['F'][test_idxs, :, :]
if model['use_E']:
E = dataset['E'][test_idxs]
try:
gdml_predict = GDMLPredict(
model,
max_memory=max_memory,
max_processes=max_processes,
use_torch=use_torch,
)
except:
print()
log.critical(traceback.format_exc())
print()
os._exit(1)
b_size = min(1000, len(test_idxs))
if not use_torch:
if num_workers == -1 or batch_size == -1:
ui.callback(NOT_DONE, disp_str='Optimizing parallelism')
gps, is_from_cache = gdml_predict.prepare_parallel(
n_bulk=b_size, return_is_from_cache=True
)
num_workers, chunk_size, bulk_mp = (
gdml_predict.num_workers,
gdml_predict.chunk_size,
gdml_predict.bulk_mp,
)
sec_disp_str = 'no chunking'
if chunk_size != gdml_predict.n_train:
sec_disp_str = 'chunks of {:d}'.format(chunk_size)
if num_workers == 0:
sec_disp_str = 'no workers / ' + sec_disp_str
else:
sec_disp_str = (
'{:d} workers {}/ '.format(
num_workers, '[MP] ' if bulk_mp else ''
)
+ sec_disp_str
)
ui.callback(
DONE,
disp_str='Optimizing parallelism'
+ (' (from cache)' if is_from_cache else ''),
sec_disp_str=sec_disp_str,
)
else:
gdml_predict._set_num_workers(num_workers)
gdml_predict._set_chunk_size(chunk_size)
gdml_predict._set_bulk_mp(bulk_mp)
n_atoms = z.shape[0]
if model['use_E']:
e_mae_sum, e_rmse_sum = 0, 0
f_mae_sum, f_rmse_sum = 0, 0
cos_mae_sum, cos_rmse_sum = 0, 0
mag_mae_sum, mag_rmse_sum = 0, 0
n_done = 0
t = time.time()
for b_range in _batch(list(range(len(test_idxs))), b_size):
n_done_step = len(b_range)
n_done += n_done_step
r = R[b_range].reshape(n_done_step, -1)
e_pred, f_pred = gdml_predict.predict(r)
# energy error
if model['use_E']:
e = E[b_range]
e_mae, e_mae_sum, e_rmse, e_rmse_sum = _online_err(
np.squeeze(e) - e_pred, 1, n_done, e_mae_sum, e_rmse_sum
)
# import matplotlib.pyplot as plt
# plt.hist(np.squeeze(e) - e_pred)
# plt.show()
# force component error
f = F[b_range].reshape(n_done_step, -1)
f_mae, f_mae_sum, f_rmse, f_rmse_sum = _online_err(
f - f_pred, 3 * n_atoms, n_done, f_mae_sum, f_rmse_sum
)
# magnitude error
f_pred_mags = np.linalg.norm(f_pred.reshape(-1, 3), axis=1)
f_mags = np.linalg.norm(f.reshape(-1, 3), axis=1)
mag_mae, mag_mae_sum, mag_rmse, mag_rmse_sum = _online_err(
f_pred_mags - f_mags, n_atoms, n_done, mag_mae_sum, mag_rmse_sum
)
# normalized cosine error
f_pred_norm = f_pred.reshape(-1, 3) / f_pred_mags[:, None]
f_norm = f.reshape(-1, 3) / f_mags[:, None]
cos_err = (
np.arccos(np.clip(np.einsum('ij,ij->i', f_pred_norm, f_norm), -1, 1))
/ np.pi
)
cos_mae, cos_mae_sum, cos_rmse, cos_rmse_sum = _online_err(
cos_err, n_atoms, n_done, cos_mae_sum, cos_rmse_sum
)
# NEW
if is_test and DEBUG_WRITE:
try:
with open('test_pred.xyz', 'a') as file:
n = r.shape[0]
for i, ri in enumerate(r):
r_out = ri.reshape(-1, 3)
e_out = e_pred[i]
f_out = f_pred[i].reshape(-1, 3)
ext_xyz_str = (
io.generate_xyz_str(r_out, model['z'], e=e_out, f=f_out)
+ '\n'
)
file.write(ext_xyz_str)
except IOError:
sys.exit("ERROR: Writing xyz file failed.")
try:
with open('test_ref.xyz', 'a') as file:
n = r.shape[0]
for i, ri in enumerate(r):
r_out = ri.reshape(-1, 3)
e_out = (
None
if not model['use_E']
else np.squeeze(E[b_range][i])
)
f_out = f[i].reshape(-1, 3)
ext_xyz_str = (
io.generate_xyz_str(r_out, model['z'], e=e_out, f=f_out)
+ '\n'
)
file.write(ext_xyz_str)
except IOError:
sys.exit("ERROR: Writing xyz file failed.")
try:
with open('test_diff.xyz', 'a') as file:
n = r.shape[0]
for i, ri in enumerate(r):
r_out = ri.reshape(-1, 3)
e_out = (
None
if not model['use_E']
else (np.squeeze(E[b_range][i]) - e_pred[i])
)
f_out = (f[i] - f_pred[i]).reshape(-1, 3)
ext_xyz_str = (
io.generate_xyz_str(r_out, model['z'], e=e_out, f=f_out)
+ '\n'
)
file.write(ext_xyz_str)
except IOError:
sys.exit("ERROR: Writing xyz file failed.")
# NEW
sps = n_done / (time.time() - t) # examples per second
disp_str = 'energy %.3f/%.3f, ' % (e_mae, e_rmse) if model['use_E'] else ''
disp_str += 'forces %.3f/%.3f' % (f_mae, f_rmse)
disp_str = (
'{} errors (MAE/RMSE): '.format('Test' if is_test else 'Validation')
+ disp_str
)
sec_disp_str = '@ %.1f geo/s' % sps if b_range is not None else ''
ui.callback(
n_done,
len(test_idxs),
disp_str=disp_str,
sec_disp_str=sec_disp_str,
newline_when_done=False,
)
if is_test:
ui.callback(
DONE,
disp_str='Testing on {:,} points'.format(n_test),
sec_disp_str=sec_disp_str,
)
else:
ui.callback(DONE, disp_str=disp_str, sec_disp_str=sec_disp_str)
if model['use_E']:
e_rmse_pct = (e_rmse / e_err['rmse'] - 1.0) * 100
f_rmse_pct = (f_rmse / f_err['rmse'] - 1.0) * 100
if is_test and n_models == 1:
n_train = len(model['idxs_train'])
n_valid = len(model['idxs_valid'])
print()
ui.print_two_column_str(
ui.color_str('Test errors (MAE/RMSE)', bold=True),
'{:,} + {:,} points (training + validation), sigma (length scale): {}'.format(
n_train, n_valid, model['sig']
),
)
r_unit = 'unknown unit'
e_unit = 'unknown unit'
f_unit = 'unknown unit'
if 'r_unit' in dataset and 'e_unit' in dataset:
r_unit = dataset['r_unit']
e_unit = dataset['e_unit']
f_unit = str(dataset['e_unit']) + '/' + str(dataset['r_unit'])
format_str = ' {:<18} {:>.4f}/{:>.4f} [{}]'
if model['use_E']:
ui.print_two_column_str(
format_str.format('Energy:', e_mae, e_rmse, e_unit),
'relative to expected: {:+.1f}%'.format(e_rmse_pct),
)
ui.print_two_column_str(
format_str.format('Forces:', f_mae, f_rmse, f_unit),
'relative to expected: {:+.1f}%'.format(f_rmse_pct),
)
print(format_str.format(' Magnitude:', mag_mae, mag_rmse, f_unit))
ui.print_two_column_str(
format_str.format(' Angle:', cos_mae, cos_rmse, '0-1'),
'lower is better',
)
print()
model_mutable = dict(model)
model.close()
model = model_mutable
model_needs_update = (
overwrite
or (is_test and model['n_test'] < len(test_idxs))
or (is_validation and not is_model_validated)
)
if model_needs_update:
if is_validation and overwrite:
model['n_test'] = 0 # flag the model as not tested
if is_test:
model['n_test'] = len(test_idxs)
model['md5_test'] = dataset['md5']
if model['use_E']:
model['e_err'] = {
'mae': e_mae.item(),
'rmse': e_rmse.item(),
}
model['f_err'] = {'mae': f_mae.item(), 'rmse': f_rmse.item()}
np.savez_compressed(model_path, **model)
if is_test and model['n_test'] > 0:
log.info('Expected errors were updated in model file.')
else:
add_info_str = (
'the same number of'
if model['n_test'] == len(test_idxs)
else 'only {:,}'.format(len(test_idxs))
)
log.warning(
'This model has previously been tested on {:,} points, which is why the errors for the current test run with {} points have NOT been used to update the model file.\n'.format(
model['n_test'], add_info_str
)
+ 'Run \'{} test -o {} {} {}\' to overwrite.'.format(
PACKAGE_NAME, os.path.relpath(model_path), dataset_path, n_test
)
)
F_rmse.append(f_rmse)
return F_rmse
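# Illustrative sketch (not part of sGDML): the "Angle" metric reported by
# `test` is the angle between predicted and reference force vectors, scaled to
# the range 0-1 by dividing by pi. The helper name `angle_err_sketch` is
# hypothetical, and the sketch assumes that no force vector has zero length.

```python
import numpy as np


def angle_err_sketch(f_pred, f_ref):
    # f_pred, f_ref: arrays of shape (n_vectors, 3) with nonzero rows.
    f_pred_norm = f_pred / np.linalg.norm(f_pred, axis=1)[:, None]
    f_ref_norm = f_ref / np.linalg.norm(f_ref, axis=1)[:, None]
    # Row-wise dot product; clip guards against round-off outside [-1, 1].
    cos = np.clip(np.einsum('ij,ij->i', f_pred_norm, f_ref_norm), -1, 1)
    return np.arccos(cos) / np.pi
```

# Parallel vectors give 0, perpendicular vectors 0.5 and antiparallel
# vectors 1, which is why lower values are better.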
def select(model_dir, overwrite, model_file=None, command=None, **kwargs): # noqa: C901
func_called_directly = (
command == 'select'
) # has this function been called from command line or from 'all'?
if func_called_directly:
ui.print_step_title('MODEL SELECTION')
any_model_not_validated = False
any_model_is_tested = False
model_dir, model_file_names = model_dir
if len(model_file_names) > 1:
use_E = True
rows = []
data_names = ['sig', 'MAE', 'RMSE', 'MAE', 'RMSE']
for i, model_file_name in enumerate(model_file_names):
model_path = os.path.join(model_dir, model_file_name)
_, model = io.is_file_type(model_path, 'model')
use_E = model['use_E']
if i == 0:
idxs_train = set(model['idxs_train'])
md5_train = model['md5_train']
idxs_valid = set(model['idxs_valid'])
md5_valid = model['md5_valid']
else:
if (
md5_train != model['md5_train']
or md5_valid != model['md5_valid']
or idxs_train != set(model['idxs_train'])
or idxs_valid != set(model['idxs_valid'])
):
raise AssistantError(
'{} contains models trained or validated on different datasets.'.format(
model_dir
)
)
e_err = {'mae': 0.0, 'rmse': 0.0}
if model['use_E']:
e_err = model['e_err'].item()
f_err = model['f_err'].item()
is_model_validated = not (np.isnan(f_err['mae']) or np.isnan(f_err['rmse']))
if not is_model_validated:
any_model_not_validated = True
is_model_tested = model['n_test'] > 0
if is_model_tested:
any_model_is_tested = True
rows.append(
[model['sig'], e_err['mae'], e_err['rmse'], f_err['mae'], f_err['rmse']]
)
model.close()
if any_model_not_validated:
log.warning(
'One or more models in the given directory have not been validated.'
)
print()
if any_model_is_tested:
log.error(
'One or more models in the given directory have already been tested. This means that their recorded expected errors are test errors, not validation errors. However, one should never perform model selection based on the test error!\n'
+ 'Please run the validation command (again) with the overwrite option \'-o\', then this selection command.'
)
return
f_rmse_col = [row[4] for row in rows]
best_idx = f_rmse_col.index(min(f_rmse_col)) # idx of row with lowest f_rmse
best_sig = rows[best_idx][0]
rows = sorted(rows, key=lambda col: col[0]) # sort according to sigma
print(ui.color_str('Cross-validation errors', bold=True))
print(' ' * 7 + 'Energy' + ' ' * 6 + 'Forces')
print((' {:>3} ' + '{:>5} ' * 4).format(*data_names))
print(' ' + '-' * 27)
format_str = ' {:>3} ' + '{:5.2f} ' * 4
format_str_no_E = ' {:>3} - - ' + '{:5.2f} ' * 2
for row in rows:
if use_E:
row_str = format_str.format(*row)
else:
row_str = format_str_no_E.format(*[row[0], row[3], row[4]])
if row[0] != best_sig:
row_str = ui.color_str(row_str, fore_color=ui.GRAY)
print(row_str)
print()
sig_col = [row[0] for row in rows]
if best_sig == min(sig_col) or best_sig == max(sig_col):
log.warning(
'The optimal sigma (length scale) lies on the boundary of the search grid.\n'
+ 'Model performance might improve if the search grid is extended in direction sigma {} {:d}.'.format(
'<' if best_sig == min(sig_col) else '>', best_sig
)
)
else: # only one model available
log.info('Skipping model selection step as there is only one model to select.')
best_idx = 0
best_model_path = os.path.join(model_dir, model_file_names[best_idx])
if model_file is None:
# generate model file name based on model properties
best_model = np.load(best_model_path, allow_pickle=True)
model_file = io.model_file_name(best_model, is_extended=True)
best_model.close()
model_exists = os.path.isfile(model_file)
if model_exists and overwrite:
log.info('Overwriting existing model file.')
if not model_exists or overwrite:
if func_called_directly:
log.done('Writing model file \'{}\''.format(model_file))
shutil.copy(best_model_path, model_file)
shutil.rmtree(model_dir, ignore_errors=True)
else:
log.warning(
'Model \'{}\' already exists.\n'.format(model_file)
+ 'Run \'{} select -o {}\' to overwrite.'.format(
PACKAGE_NAME, os.path.relpath(model_dir)
)
)
if func_called_directly:
_print_next_step('select', model_files=[model_file])
return model_file
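# Illustrative sketch (not part of sGDML): `select` picks the model with the
# lowest validation force RMSE and warns if the winning sigma sits on the edge
# of the search grid. The helper name `pick_best_sigma` is hypothetical; the
# row layout follows the `rows` list built above.

```python
def pick_best_sigma(rows):
    # rows: one [sig, e_mae, e_rmse, f_mae, f_rmse] entry per model.
    best_sig = min(rows, key=lambda row: row[4])[0]
    sigs = [row[0] for row in rows]
    # A best value at either end of the grid suggests extending the grid.
    on_boundary = best_sig in (min(sigs), max(sigs))
    return best_sig, on_boundary
```

# A result with `on_boundary` set to True corresponds to the warning that the
# optimal sigma lies on the boundary of the search grid.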
def show(file, command=None, **kwargs):
ui.print_step_title('SHOW DETAILS')
file_path, file = file
if file['type'].astype(str) == 'd':
_print_dataset_properties(file)
if file['type'].astype(str) == 't':
_print_task_properties(file)
if file['type'].astype(str) == 'm':
_print_model_properties(file)
def reset(command=None, **kwargs):
if ui.yes_or_no('\nDo you really want to purge all caches and temporary files?'):
pkg_dir = os.path.dirname(os.path.abspath(__file__))
bmark_file = '_bmark_cache.npz'
bmark_path = os.path.join(pkg_dir, bmark_file)
if os.path.exists(bmark_path):
try:
os.remove(bmark_path)
except OSError:
print()
log.critical('Exception: unable to delete benchmark cache.')
print()
os._exit(1)
log.done('Benchmark cache deleted.')
else:
log.info('Benchmark cache was already empty.')
else:
print(' Cancelled.')
def main():
def _add_argument_sample_size(parser, subset_str):
parser.add_argument(
'n_%s' % subset_str,
metavar='<n_%s>' % subset_str,
type=io.is_strict_pos_int,
help='%s sample size' % subset_str,
)
def _add_argument_dir_with_file_type(parser, type, or_file=False):
parser.add_argument(
'%s_dir' % type,
metavar='<%s_dir%s>' % (type, '_or_file' if or_file else ''),
type=lambda x: io.is_dir_with_file_type(x, type, or_file=or_file),
help='path to %s directory%s' % (type, ' or file' if or_file else ''),
)
# Available resources
total_memory = psutil.virtual_memory().total // 2**30
total_cpus = mp.cpu_count()
parser = argparse.ArgumentParser()
parser.add_argument(
'--version',
action='version',
version='%(prog)s '
+ __version__
+ ' [Python {}, NumPy {}, SciPy {}'.format(
'.'.join(map(str, sys.version_info[:3])), np.__version__, sp.__version__
)
+ ', PyTorch {}'.format(torch.__version__ if _has_torch else 'N/A')
+ ', ASE {}'.format(ase.__version__ if _has_ase else 'N/A')
+ ']',
)
parent_parser = argparse.ArgumentParser(add_help=False)
subparsers = parser.add_subparsers(title='commands', dest='command')
subparsers.required = True
parser_all = subparsers.add_parser(
'all',
help='reconstruct a force field from beginning to end',
parents=[parent_parser],
)
parser_create = subparsers.add_parser(
'create', help='create training task(s)', parents=[parent_parser]
)
parser_train = subparsers.add_parser(
'train', help='train model(s) from task(s)', parents=[parent_parser]
)
parser_resume = subparsers.add_parser(
'resume', help='resume training of a model', parents=[parent_parser]
)
parser_valid = subparsers.add_parser(
'validate', help='validate model(s)', parents=[parent_parser]
)
parser_select = subparsers.add_parser(
'select', help='select best performing model', parents=[parent_parser]
)
parser_test = subparsers.add_parser(
'test', help='test a model', parents=[parent_parser]
)
parser_show = subparsers.add_parser(
'show',
help='print details about a dataset, task or model file',
parents=[parent_parser],
)
subparsers.add_parser(
'reset', help='delete all caches and temporary files', parents=[parent_parser]
)
for subparser in [parser_all, parser_create]:
subparser.add_argument(
'dataset',
metavar='<dataset_file>',
type=lambda x: io.is_file_type(x, 'dataset'),
help='path to dataset file (train/validation/test subsets are sampled from here if no separate datasets are specified)',
)
_add_argument_sample_size(subparser, 'train')
_add_argument_sample_size(subparser, 'valid')
subparser.add_argument(
'-v',
'--validation_dataset',
metavar='<valid_dataset_file>',
dest='valid_dataset',
type=lambda x: io.is_file_type(x, 'dataset'),
help='path to separate validation dataset file',
)
subparser.add_argument(
'-t',
'--test_dataset',
metavar='<test_dataset_file>',
dest='test_dataset',
type=lambda x: io.is_file_type(x, 'dataset'),
help='path to separate test dataset file',
)
subparser.add_argument(
'-s',
'--sig',
metavar=('<s1>', '<s2>'),
dest='sigs',
type=io.parse_list_or_range,
help='integer list and/or range <start>:[<step>:]<stop> for the kernel hyper-parameter sigma (length scale)',
nargs='+',
)
group = subparser.add_mutually_exclusive_group()
group.add_argument(
'--gdml',
action='store_true',
help='don\'t include symmetries in the model (GDML)',
)
group.add_argument(
'--perms_from',
metavar='<file>',
dest='perms_from_arg',
type=lambda x: io.is_valid_file_type(x),
help='path to file to take permutations from (key: \'perms\')',
)
group = subparser.add_mutually_exclusive_group()
group.add_argument(
'--no_E',
dest='use_E',
action='store_false',
help='only reconstruct force field w/o potential energy surface',
)
group.add_argument(
'--E_cstr',
dest='use_E_cstr',
action='store_true',
help='include pointwise energy constraints',
)
subparser.add_argument(
'--task_dir',
metavar='<task_dir>',
dest='task_dir',
help='user-defined task output dir name',
)
for subparser in [parser_all, parser_select]:
subparser.add_argument(
'--model_file',
metavar='<model_file>',
dest='model_file',
help='user-defined model output file name',
)
for subparser in [parser_all, parser_train]:
subparser.add_argument(
'--lazy',
dest='lazy_training',
action='store_true',
help='give up on unfinished tasks (if more than one)',
)
for subparser in [parser_valid, parser_test]:
_add_argument_dir_with_file_type(subparser, 'model', or_file=True)
parser_valid.add_argument(
'valid_dataset',
metavar='<valid_dataset_file>',
type=lambda x: io.is_file_type(x, 'dataset'),
help='path to validation dataset file',
)
parser_test.add_argument(
'test_dataset',
metavar='<test_dataset_file>',
type=lambda x: io.is_file_type(x, 'dataset'),
help='path to test dataset file',
)
for subparser in [parser_all, parser_test]:
subparser.add_argument(
'n_test',
metavar='<n_test>',
type=io.is_strict_pos_int,
help='test sample size',
nargs='?',
default=None,
)
parser_resume.add_argument(
'model',
metavar='<model_file>',
type=lambda x: io.is_file_type(x, 'model'),
help='path to model file to complete training for',
)
parser_resume.add_argument(
'dataset',
metavar='<train_dataset_file>',
type=lambda x: io.is_file_type(x, 'dataset'),
help='path to original training dataset file',
)
_add_argument_dir_with_file_type(parser_train, 'task', or_file=True)
for subparser in [parser_train, parser_resume]:
subparser.add_argument(
'valid_dataset',
metavar='<valid_dataset_file>',
type=lambda x: io.is_file_type(x, 'dataset'),
help='path to validation dataset file',
)
_add_argument_dir_with_file_type(parser_select, 'model')
parser_show.add_argument(
'file',
metavar='<file>',
type=lambda x: io.is_valid_file_type(x),
help='path to dataset, task or model file',
)
for subparser in [
parser_all,
parser_train,
parser_resume,
parser_valid,
parser_test,
]:
subparser.add_argument(
'-m',
'--max_memory',
metavar='<max_memory>',
type=int,
help='limit memory usage (whenever possible) [GB]',
choices=range(1, total_memory + 1),
default=total_memory,
)
subparser.add_argument(
'-p',
'--max_processes',
metavar='<max_processes>',
type=int,
help='limit number of processes',
choices=range(1, total_cpus + 1),
default=total_cpus,
)
subparser.add_argument(
'--cpu',
dest='use_torch',
action='store_false',
help='use CPU implementation (no PyTorch dependency)',
)
for subparser in [
parser_all,
parser_create,
parser_train,
parser_resume,
parser_valid,
parser_select,
parser_test,
]:
subparser.add_argument(
'-o',
'--overwrite',
dest='overwrite',
action='store_true',
help='overwrite existing files',
)
args = parser.parse_args()
# Post-processing for optional sig argument
if 'sigs' in args and args.sigs is not None:
args.sigs = np.hstack(
args.sigs
).tolist() # Flatten list, if (part of it) was generated using the range syntax
args.sigs = sorted(list(set(args.sigs))) # remove potential duplicates
# Post-processing for optional model output file argument
if 'model_file' in args and args.model_file is not None:
if not args.model_file.endswith('.npz'):
args.model_file += '.npz'
# Check PyTorch GPU support.
if ('use_torch' in args and args.use_torch) or 'use_torch' not in args:
if _has_torch:
if not (_torch_cuda_is_available or _torch_mps_is_available):
print() # TODO: print only if log level includes warning
log.warning(
'Your PyTorch installation does not see any GPU(s) on your system and will thus run all calculations on the CPU! If this is what you want, we recommend bypassing PyTorch using \'--cpu\' for improved performance.'
)
else:
print()
log.critical(
'PyTorch dependency not found! Please install or use \'--cpu\' to bypass PyTorch and run everything on the CPU.'
)
print()
os._exit(1)
args = vars(args)
_print_splash(
args['max_memory'] if 'max_memory' in args else total_memory,
args['max_processes'] if 'max_processes' in args else total_cpus,
args['use_torch'] if 'use_torch' in args else True,
)
try:
getattr(sys.modules[__name__], args['command'])(**args)
except AssistantError as err:
log.error(str(err))
print()
os._exit(1)
except:
log.critical(traceback.format_exc())
print()
os._exit(1)
print()
if __name__ == "__main__":
main()
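# Illustrative sketch (not part of sGDML): the post-processing of the optional
# sigma argument in `main` flattens the parsed values, removes duplicates and
# sorts them. The helper name `flatten_sigs` is hypothetical, and the sketch
# assumes that each command-line token was parsed into a list of integers;
# the actual return type of `io.parse_list_or_range` is not shown here.

```python
import numpy as np


def flatten_sigs(sig_groups):
    # sig_groups: list of lists, e.g. one list per command-line token.
    flat = np.hstack(sig_groups).tolist()  # flatten into plain Python ints
    return sorted(set(flat))  # drop duplicates, ascending order
```

# For example, the groups [[10], [20, 30, 40], [20]] collapse to the grid
# 10, 20, 30, 40.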
================================================
FILE: sgdml/get.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2018-2023 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function
import argparse
import os
import re
import sys
from . import __version__
from .utils import ui
if sys.version[0] == '3':
raw_input = input
try:
from urllib.request import urlopen
except ImportError:
from urllib2 import urlopen
def download(command, file_name):
base_url = 'http://www.quantum-machine.org/gdml/' + (
'data/npz/' if command == 'dataset' else 'models/'
)
request = urlopen(base_url + file_name)
file = open(file_name, 'wb')
filesize = int(request.headers['Content-Length'])
size = 0
block_sz = 1024
while True:
buffer = request.read(block_sz)
if not buffer:
break
size += len(buffer)
file.write(buffer)
ui.callback(
size,
filesize,
disp_str='Downloading: {}'.format(file_name),
sec_disp_str='{:,} bytes'.format(filesize),
)
file.close()
def main():
base_url = 'http://www.quantum-machine.org/gdml/'
parser = argparse.ArgumentParser()
parent_parser = argparse.ArgumentParser(add_help=False)
parent_parser.add_argument(
'-o',
'--overwrite',
dest='overwrite',
action='store_true',
help='overwrite existing files',
)
subparsers = parser.add_subparsers(title='commands', dest='command')
subparsers.required = True
parser_dataset = subparsers.add_parser(
'dataset', help='download benchmark dataset', parents=[parent_parser]
)
parser_model = subparsers.add_parser(
'model', help='download pre-trained model', parents=[parent_parser]
)
for subparser in [parser_dataset, parser_model]:
subparser.add_argument(
'name',
metavar='<name>',
type=str,
help='item name',
nargs='?',
default=None,
)
args = parser.parse_args()
print("Contacting server (%s)..." % base_url)
if args.name is not None:
url = '%sget.php?version=%s&%s=%s' % (
base_url,
__version__,
args.command,
args.name,
)
response = urlopen(url)
match, score = response.read().decode().split(',')
response.close()
if int(score) == 0 or ui.yes_or_no('Do you mean \'%s\'?' % match):
download(args.command, match + '.npz')
return
response = urlopen(
'%sget.php?version=%s&%s' % (base_url, __version__, args.command)
)
line = response.readlines()
response.close()
print()
print('Available %ss:' % args.command)
print('{:<2} {:<31} {:>4}'.format('ID', 'Name', 'Size'))
print('-' * 42)
items = line[0].split(b';')
for i, item in enumerate(items):
name, size = item.split(b',')
size = int(size) / 1024**2 # Bytes to MBytes
print('{:>2d} {:<30} {:>5.1f} MB'.format(i, name.decode("utf-8"), size))
print()
down_list = raw_input(
'Please list which %ss to download (e.g. 0 1 2 6) or type \'all\': '
% args.command
)
down_idxs = []
if 'all' in down_list.lower():
down_idxs = list(range(len(items)))
elif re.match(
"^ *[0-9][0-9 ]*$", down_list
): # only digits and spaces, at least one digit
down_idxs = [int(idx) for idx in re.split(r'\s+', down_list.strip())]
down_idxs = list(set(down_idxs))
else:
print(ui.color_str('ABORTED', fore_color=ui.RED, bold=True))
for idx in down_idxs:
if idx not in range(len(items)):
print(
ui.color_str('[WARN]', fore_color=ui.YELLOW, bold=True)
+ ' Index '
+ str(idx)
+ ' out of range, skipping.'
)
else:
name = items[idx].split(b',')[0].decode("utf-8")
if os.path.exists(name + '.npz'):
print("'%s' exists, skipping." % (name + '.npz'))
continue
download(args.command, name + '.npz')
if __name__ == "__main__":
main()
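# Illustrative sketch (not part of the sGDML sources): how the interactive prompt in
# `main()` interprets the user's selection. The function name `parse_selection` is
# hypothetical. Unlike the original, which uses `list(set(...))`, this sketch sorts
# the indices so the result is deterministic, and it returns None for invalid input.

```python
import re


def parse_selection(down_list, n_items):
    """Return the requested item indices, or None if the input is not understood."""
    if 'all' in down_list.lower():
        return list(range(n_items))
    # Only digits and spaces are accepted, with at least one digit.
    if re.match(r'^ *[0-9][0-9 ]*$', down_list):
        idxs = [int(idx) for idx in re.split(r'\s+', down_list.strip())]
        return sorted(set(idxs))
    return None


print(parse_selection('0 1 2 6', 8))
print(parse_selection('all', 3))
print(parse_selection('first two', 3))
```

# Out-of-range indices are not rejected here. As in the original loop, they would be
# skipped with a warning at download time.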
================================================
FILE: sgdml/intf/__init__.py
================================================
================================================
FILE: sgdml/intf/ase_calc.py
================================================
# MIT License
#
# Copyright (c) 2018-2020 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import logging
import numpy as np
try:
from ase.calculators.calculator import Calculator
from ase.units import kcal, mol
except ImportError:
raise ImportError(
'Optional ASE dependency not found! Please run \'pip install sgdml[ase]\' to install it.'
)
from ..predict import GDMLPredict
class SGDMLCalculator(Calculator):
implemented_properties = ['energy', 'forces']
def __init__(
self,
model_path,
E_to_eV=kcal / mol,
F_to_eV_Ang=kcal / mol,
use_torch=False,
*args,
**kwargs
):
"""
ASE calculator for the sGDML force field.
A calculator takes atomic numbers and atomic positions from an Atoms object and calculates the energy and forces.
Note
----
ASE uses eV and Angstrom as energy and length unit, respectively. Unless the parameters `E_to_eV` and `F_to_eV_Ang` are specified, the sGDML model is assumed to use kcal/mol and Angstrom and the appropriate conversion factors are set accordingly.
Here is how to find them: `ASE units <https://wiki.fysik.dtu.dk/ase/ase/units.html>`_.
Parameters
----------
model_path : :obj:`str`
Path to a sGDML model file
E_to_eV : float, optional
Conversion factor from whatever energy unit is used by the model to eV. By default this parameter is set to convert from kcal/mol.
F_to_eV_Ang : float, optional
Conversion factor from whatever force unit is used by the model to eV/Angstrom. By default this parameter is set to convert from kcal/mol/Angstrom, i.e. the length unit is not converted (assumed to be in Angstrom).
use_torch : boolean, optional
Use PyTorch to calculate predictions
"""
super(SGDMLCalculator, self).__init__(*args, **kwargs)
self.log = logging.getLogger(__name__)
model = np.load(model_path, allow_pickle=True)
self.gdml_predict = GDMLPredict(model, use_torch=use_torch)
self.gdml_predict.prepare_parallel(n_bulk=1)
self.log.warning(
'Please remember to specify the proper conversion factors, if your model does not use \'kcal/mol\' and \'Ang\' as units.'
)
# Converts energy from the unit used by the sGDML model to eV.
self.E_to_eV = E_to_eV
# Converts length from Angstrom to the unit used by the sGDML model.
self.Ang_to_R = F_to_eV_Ang / E_to_eV
# Converts force from the unit used by the sGDML model to eV/Ang.
self.F_to_eV_Ang = F_to_eV_Ang
def calculate(self, atoms=None, *args, **kwargs):
super(SGDMLCalculator, self).calculate(atoms, *args, **kwargs)
# convert ASE default length unit (Ang) to the length unit used by the model
r = np.array(atoms.get_positions()) * self.Ang_to_R
e, f = self.gdml_predict.predict(r.ravel())
# convert model units to ASE default units (eV and Ang)
e *= self.E_to_eV
f *= self.F_to_eV_Ang
self.results = {'energy': e, 'forces': f.reshape(-1, 3)}
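# Illustrative sketch (not part of the sGDML sources): how the conversion factors of
# `SGDMLCalculator` relate for a model that does not use kcal/mol and Angstrom. It
# assumes a hypothetical model trained in Hartree and Bohr. The constants are
# approximate values hard-coded here so the sketch has no ASE dependency; in practice
# they would come from `ase.units`.

```python
# Approximate physical constants (assumed values for illustration only).
HARTREE_IN_EV = 27.211386
BOHR_IN_ANG = 0.529177

# Energy: model unit (Hartree) -> eV.
E_to_eV = HARTREE_IN_EV
# Force: model unit (Hartree/Bohr) -> eV/Angstrom.
F_to_eV_Ang = HARTREE_IN_EV / BOHR_IN_ANG

# The calculator derives the length conversion (Angstrom -> model unit) the same way.
Ang_to_R = F_to_eV_Ang / E_to_eV

# A distance of 1 Angstrom expressed in the model's length unit (Bohr).
one_ang_in_model_units = 1.0 * Ang_to_R
print(one_ang_in_model_units)
```

# With the default arguments both factors equal kcal/mol, so `Ang_to_R` is 1 and the
# positions are passed to the model unchanged.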
================================================
FILE: sgdml/predict.py
================================================
"""
This module contains all routines for evaluating GDML and sGDML models.
"""
# MIT License
#
# Copyright (c) 2018-2022 Stefan Chmiela, Gregory Fonseca
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function
import sys
import logging
import os
import psutil
import multiprocessing as mp
Pool = mp.get_context('fork').Pool
import timeit
from functools import partial
try:
import torch
except ImportError:
_has_torch = False
else:
_has_torch = True
try:
_torch_mps_is_available = torch.backends.mps.is_available()
except AttributeError:
_torch_mps_is_available = False
_torch_mps_is_available = False
try:
_torch_cuda_is_available = torch.cuda.is_available()
except AttributeError:
_torch_cuda_is_available = False
import numpy as np
from . import __version__
from .utils.desc import Desc
def share_array(arr_np):
"""
Return a ctypes array allocated from shared memory with data from a
NumPy array of type `float`.
Parameters
----------
arr_np : :obj:`numpy.ndarray`
NumPy array.
Returns
-------
array of :obj:`ctype`
"""
arr = mp.RawArray('d', arr_np.ravel())
return arr, arr_np.shape
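# Illustrative sketch (not part of the sGDML sources): the round trip behind
# `share_array`. The data is copied once into a shared `RawArray`; afterwards
# `np.frombuffer` creates a NumPy view on that buffer without a further copy, which
# is how `_predict_wkr` accesses the training data. The sample array is made up.

```python
import multiprocessing as mp

import numpy as np

arr_np = np.arange(6, dtype=float).reshape(2, 3)

# Copy the data into shared memory ('d' = C double) and remember the shape.
shared = mp.RawArray('d', arr_np.ravel())
shape = arr_np.shape

# Rebuild a NumPy view on the shared buffer (default dtype is float64).
view = np.frombuffer(shared).reshape(shape)

# Writing through the view changes the shared buffer itself.
view[0, 0] = 42.0
print(shared[0])
```

# The original NumPy array is not affected by writes to the view, since the
# `RawArray` holds its own copy of the data.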
def _predict_wkr(
r, r_desc_d_desc, lat_and_inv, glob_id, wkr_start_stop=None, chunk_size=None
):
"""
Compute (part of) a prediction.
Every prediction is a linear combination involving the training points used for
this model. This function evaluates that combination for the range specified by
`wkr_start_stop`. This workload can optionally be processed in chunks,
which can be faster as it requires less memory to be allocated.
Note
----
It is sufficient to provide either the parameter `r` or `r_desc_d_desc`.
The other one can be set to `None`.
Parameters
----------
r : :obj:`numpy.ndarray`
An array of size 3N containing the Cartesian
coordinates of each atom in the molecule.
r_desc_d_desc : tuple of :obj:`numpy.ndarray`
A tuple made up of:
(1) An array of size D containing the descriptors
of dimension D for the molecule.
(2) An array of size D x 3N containing the
descriptor Jacobian for the molecules. It has dimension
D with 3N partial derivatives with respect to the 3N
Cartesian coordinates of each atom.
lat_and_inv : tuple of :obj:`numpy.ndarray`
Tuple of 3 x 3 matrix containing lattice vectors as columns and
its inverse.
glob_id : int
Identifier of the global namespace that this
function is supposed to be using (zero if only one
instance of this class exists at the same time).
wkr_start_stop : tuple of int, optional
Range defined by the indices of first and last (exclusive)
sum element. The full prediction is generated if this parameter
is not specified.
chunk_size : int, optional
Chunk size. The whole linear combination is evaluated in a large
vector operation instead of looping over smaller chunks if this
parameter is left unspecified.
Returns
-------
:obj:`numpy.ndarray`
Partial prediction of all force components and
energy (appended to array as last element).
"""
global globs
glob = globs[glob_id]
sig, n_perms = glob['sig'], glob['n_perms']
desc_func = glob['desc_func']
R_desc_perms = np.frombuffer(glob['R_desc_perms']).reshape(
glob['R_desc_perms_shape']
)
R_d_desc_alpha_perms = np.frombuffer(glob['R_d_desc_alpha_perms']).reshape(
glob['R_d_desc_alpha_perms_shape']
)
if 'alphas_E_lin' in glob:
alphas_E_lin = np.frombuffer(glob['alphas_E_lin']).reshape(
glob['alphas_E_lin_shape']
)
r_desc, r_d_desc = r_desc_d_desc or desc_func.from_R(
r, lat_and_inv, max_processes=1
) # no additional forking during parallelization
n_train = int(R_desc_perms.shape[0] / n_perms)
wkr_start, wkr_stop = (0, n_train) if wkr_start_stop is None else wkr_start_stop
if chunk_size is None:
chunk_size = n_train
dim_d = desc_func.dim
dim_i = desc_func.dim_i
dim_c = chunk_size * n_perms
# Pre-allocate memory.
diff_ab_perms = np.empty((dim_c, dim_d))
a_x2 = np.empty((dim_c,))
mat52_base = np.empty((dim_c,))
# avoid divisions (slower)
sig_inv = 1.0 / sig
mat52_base_fact = 5.0 / (3 * sig**3)
diag_scale_fact = 5.0 / sig
sqrt5 = np.sqrt(5.0)
E_F = np.zeros((dim_d + 1,))
F = E_F[1:]
wkr_start *= n_perms
wkr_stop *= n_perms
b_start = wkr_start
for b_stop in list(range(wkr_start + dim_c, wkr_stop, dim_c)) + [wkr_stop]:
rj_desc_perms = R_desc_perms[b_start:b_stop, :]
rj_d_desc_alpha_perms = R_d_desc_alpha_perms[b_start:b_stop, :]
# Resize pre-allocated memory for last iteration, if chunk_size is not a divisor of the training set size.
# Note: It's faster to process equally sized chunks.
c_size = b_stop - b_start
if c_size < dim_c:
diff_ab_perms = diff_ab_perms[:c_size, :]
a_x2 = a_x2[:c_size]
mat52_base = mat52_base[:c_size]
np.subtract(
np.broadcast_to(r_desc, rj_desc_perms.shape),
rj_desc_perms,
out=diff_ab_perms,
)
norm_ab_perms = sqrt5 * np.linalg.norm(diff_ab_perms, axis=1)
np.exp(-norm_ab_perms * sig_inv, out=mat52_base)
mat52_base *= mat52_base_fact
np.einsum(
'ji,ji->j', diff_ab_perms, rj_d_desc_alpha_perms, out=a_x2
) # row-wise dot product
F += (a_x2 * mat52_base).dot(diff_ab_perms) * diag_scale_fact
mat52_base *= norm_ab_perms + sig
F -= mat52_base.dot(rj_d_desc_alpha_perms)
# Note: Energies are automatically predicted with a flipped sign here (because -E are trained, instead of E)
E_F[0] += a_x2.dot(mat52_base)
# Note: Energies are automatically predicted with a flipped sign here (because -E are trained, instead of E)
if 'alphas_E_lin' in glob:
K_fe = diff_ab_perms * mat52_base[:, None]
F += alphas_E_lin[b_start:b_stop].dot(K_fe)
K_ee = (
1 + (norm_ab_perms * sig_inv) * (1 + norm_ab_perms / (3 * sig))
) * np.exp(-norm_ab_perms * sig_inv)
E_F[0] += K_ee.dot(alphas_E_lin[b_start:b_stop])
b_start = b_stop
out = E_F[: dim_i + 1]
# Descriptor has fewer entries than 3N, need to extend size of the 'E_F' array.
if dim_d < dim_i:
out = np.empty((dim_i + 1,))
out[0] = E_F[0]
out[1:] = desc_func.vec_dot_d_desc(
r_d_desc,
F,
) # 'r_d_desc.T.dot(F)' for our special representation of 'r_d_desc'
return out
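# Illustrative sketch (not part of the sGDML sources): the `K_ee` expression in
# `_predict_wkr` is the Matern 5/2 kernel, written with the scaled distance
# x = sqrt(5) * d / sig. The two helper names below are hypothetical. The second one
# mirrors the arithmetic of the source so that both forms can be compared.

```python
import numpy as np


def matern52(d, sig):
    """Textbook form: (1 + x + x^2 / 3) * exp(-x) with x = sqrt(5) * d / sig."""
    x = np.sqrt(5.0) * d / sig
    return (1.0 + x + x**2 / 3.0) * np.exp(-x)


def k_ee_as_in_source(d, sig):
    """Same kernel, arranged like the `K_ee` expression in `_predict_wkr`."""
    norm = np.sqrt(5.0) * d  # corresponds to `norm_ab_perms`
    sig_inv = 1.0 / sig
    return (1 + (norm * sig_inv) * (1 + norm / (3 * sig))) * np.exp(-norm * sig_inv)


d = np.linspace(0.0, 5.0, 11)
print(np.allclose(matern52(d, 10.0), k_ee_as_in_source(d, 10.0)))
```

# At zero distance the kernel equals 1, and it decays monotonically with growing
# descriptor distance.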
class GDMLPredict(object):
def __init__(
self,
model,
batch_size=None,
num_workers=None,
max_memory=None,
max_processes=None,
use_torch=False,
log_level=None,
):
"""
Query trained sGDML force fields.
This class is used to load a trained model and make energy and
force predictions for new geometries. GPU support is provided
through PyTorch (requires optional `torch` dependency to be
installed).
Note
----
The parameters `batch_size` and `num_workers` are only
relevant if this code runs on a CPU. Both can be set
automatically via the function `prepare_parallel`.
Note: Running calculations via PyTorch is only
recommended with available GPU hardware. CPU calculations
are faster with our NumPy implementation.
Parameters
----------
model : :obj:`dict`
Data structure that holds all parameters of the
trained model. This object is the output of
`GDMLTrain.train`
batch_size : int, optional
Chunk size for processing parallel tasks
num_workers : int, optional
Number of parallel workers (in addition to the main
process)
max_memory : int, optional
Limit the max. memory usage [GB]. This is only a
soft limit that can not always be enforced.
max_processes : int, optional
Limit the max. number of processes. Otherwise
all CPU cores are used. This parameter has no
effect if `use_torch=True`
use_torch : boolean, optional
Use PyTorch to calculate predictions
log_level : optional
Set custom logging level (e.g. `logging.CRITICAL`)
"""
global globs
if 'globs' not in globals():
globs = []
# Create a personal global space for this model at a new index
# Note: do not delete entries in this list, since 'self.glob_id' is
# static. Instead, setting them to None conserves positions while still
# freeing up memory.
globs.append({})
self.glob_id = len(globs) - 1
glob = globs[self.glob_id]
self.log = logging.getLogger(__name__)
if log_level is not None:
self.log.setLevel(log_level)
total_memory = psutil.virtual_memory().total // 2**30 # bytes to GB
self.max_memory = (
min(max_memory, total_memory) if max_memory is not None else total_memory
)
total_cpus = mp.cpu_count()
self.max_processes = (
min(max_processes, total_cpus) if max_processes is not None else total_cpus
)
if 'type' not in model or not (model['type'] == 'm' or model['type'] == b'm'):
self.log.critical('The provided data structure is not a valid model.')
sys.exit()
self.n_atoms = model['z'].shape[0]
self.desc = Desc(self.n_atoms, max_processes=max_processes)
glob['desc_func'] = self.desc
# Cache for iterative training mode.
self.R_desc = None
self.R_d_desc = None
self.lat_and_inv = (
(model['lattice'], np.linalg.inv(model['lattice']))
if 'lattice' in model
else None
)
self.n_train = model['R_desc'].shape[1]
glob['sig'] = model['sig']
self.std = model['std'] if 'std' in model else 1.0
self.c = model['c']
n_perms = model['perms'].shape[0]
glob['n_perms'] = n_perms
self.tril_perms_lin = model['tril_perms_lin']
self.torch_predict = None
self.use_torch = use_torch
if use_torch:
if not _has_torch:
raise ImportError(
'Optional PyTorch dependency not found! Please run \'pip install sgdml[torch]\' to install it or disable the PyTorch option.'
)
from .torchtools import GDMLTorchPredict
self.torch_predict = GDMLTorchPredict(
model,
self.lat_and_inv,
max_memory=max_memory,
max_processes=max_processes,
log_level=self.log.level,
)
# Enable data parallelism
n_gpu = torch.cuda.device_count()
if n_gpu > 1:
self.torch_predict = torch.nn.DataParallel(self.torch_predict)
# Send model to device
# self.torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
if _torch_cuda_is_available:
self.torch_device = 'cuda'
elif _torch_mps_is_available:
self.torch_device = 'mps'
else:
self.torch_device = 'cpu'
while True:
try:
self.torch_predict.to(self.torch_device)
except RuntimeError as e:
if 'out of memory' in str(e):
if _torch_cuda_is_available:
torch.cuda.empty_cache()
model = self.torch_predict
if isinstance(self.torch_predict, torch.nn.DataParallel):
model = model.module
if (
model.get_n_perm_batches() == 1
): # model caches the permutations, this could be why it is too large
model.set_n_perm_batches(
model.get_n_perm_batches() + 1
) # uncache
# self.torch_predict.to( # NOTE!
# self.torch_device
# ) # try sending to device again
pass
else:
self.log.critical(
'Not enough memory on device (RAM or GPU memory). There is no hope!'
)
print()
os._exit(1)
else:
raise e
else:
break
else:
# Precompute permuted training descriptors and their first derivatives multiplied with the coefficients.
R_desc_perms = (
np.tile(model['R_desc'].T, n_perms)[:, self.tril_perms_lin]
.reshape(self.n_train, n_perms, -1, order='F')
.reshape(self.n_train * n_perms, -1)
)
glob['R_desc_perms'], glob['R_desc_perms_shape'] = share_array(R_desc_perms)
R_d_desc_alpha_perms = (
np.tile(model['R_d_desc_alpha'], n_perms)[:, self.tril_perms_lin]
.reshape(self.n_train, n_perms, -1, order='F')
.reshape(self.n_train * n_perms, -1)
)
(
glob['R_d_desc_alpha_perms'],
glob['R_d_desc_alpha_perms_shape'],
) = share_array(R_d_desc_alpha_perms)
if 'alphas_E' in model:
alphas_E_lin = np.tile(model['alphas_E'][:, None], (1, n_perms)).ravel()
glob['alphas_E_lin'], glob['alphas_E_lin_shape'] = share_array(
alphas_E_lin
)
# Parallel processing configuration
self.bulk_mp = False # Bulk predictions with multiple processes?
self.pool = None
# How many workers in addition to main process?
num_workers = num_workers or (
self.max_processes - 1
) # exclude main process
self._set_num_workers(num_workers, force_reset=True)
# Size of chunks in which each parallel task will be processed (unit: number of training samples)
# This parameter should be as large as possible, but it depends on the size of available memory.
self._set_chunk_size(batch_size)
def __del__(self):
global globs
try:
self.pool.terminate()
self.pool.join()
self.pool = None
except:
pass
if 'globs' in globals() and globs is not None and self.glob_id < len(globs):
globs[self.glob_id] = None
## Public ##
# def set_R(self, R):
# """
# Store a reference to the training geometries.
# This function is used to avoid unnecessary copies of the
# training geometries when evaluating the training error
# (= gradient of the model's loss function).
# This routine is used during iterative model training.
# Parameters
# ----------
# R : :obj:`numpy.ndarray`
# Array containing the geometry for each training point.
# """
# # Add singleton dimension if input is (,3N).
# if R.ndim == 1:
# R = R[None, :]
# self.R = R
# # if self.use_torch:
# # model = self.torch_predict
# # if isinstance(self.torch_predict, torch.nn.DataParallel):
# # model = model.module
# # R_torch = torch.from_numpy(R.reshape(-1, self.n_atoms, 3)).to(self.torch_device)
# # model.set_R(R_torch)
def set_R_desc(self, R_desc):
"""
Store a reference to the training geometry descriptors.
This can accelerate iterative model training.
Parameters
----------
R_desc : :obj:`numpy.ndarray`, optional
A 2D array of size M x D containing the
descriptors of dimension D for M
molecules.
"""
self.R_desc = R_desc
def set_R_d_desc(self, R_d_desc):
"""
Store a reference to the training geometry descriptor Jacobians.
This function must be called before `set_alphas()` can be used.
This routine is used during iterative model training.
Parameters
----------
R_d_desc : :obj:`numpy.ndarray`, optional
An array of size M x D x 3N containing the
descriptor Jacobians for M molecules. The descriptor
has dimension D with 3N partial derivatives with
respect to the 3N Cartesian coordinates of each atom.
"""
self.R_d_desc = R_d_desc
if self.use_torch:
model = self.torch_predict
if isinstance(self.torch_predict, torch.nn.DataParallel):
model = model.module
model.set_R_d_desc(R_d_desc)
def set_alphas(self, alphas_F, alphas_E=None):
"""
Reconfigure the current model with a new set of regression parameters.
`R_d_desc` needs to be set for this function to work.
This routine is used during iterative model training.
Parameters
----------
alphas_F : :obj:`numpy.ndarray`
1D array containing the new model parameters.
alphas_E : :obj:`numpy.ndarray`, optional
1D array containing the additional new model parameters, if
energy constraints are used in the kernel (`use_E_cstr=True`)
"""
if self.use_torch:
model = self.torch_predict
if isinstance(self.torch_predict, torch.nn.DataParallel):
model = model.module
model.set_alphas(alphas_F, alphas_E=alphas_E)
else:
assert self.R_d_desc is not None
global globs
glob = globs[self.glob_id]
dim_i = self.desc.dim_i
R_d_desc_alpha = self.desc.d_desc_dot_vec(
self.R_d_desc, alphas_F.reshape(-1, dim_i)
)
R_d_desc_alpha_perms_new = np.tile(R_d_desc_alpha, glob['n_perms'])[
:, self.tril_perms_lin
].reshape(self.n_train, glob['n_perms'], -1, order='F')
R_d_desc_alpha_perms = np.frombuffer(glob['R_d_desc_alpha_perms'])
np.copyto(R_d_desc_alpha_perms, R_d_desc_alpha_perms_new.ravel())
if alphas_E is not None:
alphas_E_lin_new = np.tile(
alphas_E[:, None], (1, glob['n_perms'])
).ravel()
alphas_E_lin = np.frombuffer(glob['alphas_E_lin'])
np.copyto(alphas_E_lin, alphas_E_lin_new)
def _set_num_workers(
self, num_workers=None, force_reset=False
): # TODO: complain if chunk or worker parameters do not fit training data (this causes issues with the caching)!!
"""
Set number of processes to use during prediction.
If bulk_mp == True, each worker handles the whole generation of a single prediction (this is for querying multiple geometries at once).
If bulk_mp == False, each worker may handle only a part of a prediction (chunks are defined in 'wkr_starts_stops'). In that scenario multiple processes
are used to distribute the work of generating a single prediction.
This number should not exceed the number of available CPU cores.
Note
----
This parameter can be optimally determined using
`prepare_parallel`.
Parameters
----------
num_workers : int, optional
Number of processes (maximum value is set if `None`).
force_reset : bool, optional
Force applying the new setting.
"""
if force_reset or self.num_workers != num_workers:
if self.pool is not None:
self.pool.terminate()
self.pool.join()
self.pool = None
self.num_workers = 0
if num_workers is None or num_workers > 0:
self.pool = Pool(num_workers)
self.num_workers = (
self.pool._processes
) # number of actual workers (not max_processes)
# Data ranges for processes
if self.bulk_mp or self.num_workers < 2:
# wkr_starts = [self.n_train]
wkr_starts = [0]
else:
wkr_starts = list(
range(
0,
self.n_train,
int(np.ceil(float(self.n_train) / self.num_workers)),
)
)
wkr_stops = wkr_starts[1:] + [self.n_train]
self.wkr_starts_stops = list(zip(wkr_starts, wkr_stops))
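# Illustrative sketch (not part of the sGDML sources): how `_set_num_workers` splits
# the training points into per-worker index ranges. The standalone function
# `worker_ranges` is hypothetical and only reproduces the range arithmetic, without
# creating a process pool.

```python
import numpy as np


def worker_ranges(n_train, num_workers, bulk_mp=False):
    """Return (start, stop) index pairs, one per worker (stop is exclusive)."""
    if bulk_mp or num_workers < 2:
        # A single range covering all training points.
        starts = [0]
    else:
        step = int(np.ceil(float(n_train) / num_workers))
        starts = list(range(0, n_train, step))
    stops = starts[1:] + [n_train]
    return list(zip(starts, stops))


print(worker_ranges(10, 3))
print(worker_ranges(12, 4))
print(worker_ranges(10, 1))
```

# When the number of training points is not divisible by the number of workers, the
# last range is shorter than the others.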
def _set_chunk_size(self, chunk_size=None):
# TODO: complain if chunk or worker parameters do not fit training data (this causes issues with the caching)!!
"""
Set chunk size for each worker process.
Every prediction is generated as a linear combination of the training
points that the model is comprised of. If multiple workers are available
(and bulk mode is disabled), each one processes an (approximately equal)
part of those training points. Then, the chunk size determines how much of
a process's workload is passed to NumPy's underlying low-level routines at
once. If the chunk size is smaller than the number of points the worker is
supposed to process, it processes them in multiple steps using a loop. This
can sometimes be faster, depending on the available hardware.
Note
----
This parameter can be optimally determined using
`prepare_parallel`.
Parameters
----------
chunk_size : int
Chunk size (maximum value is set if `None`).
"""
if chunk_size is None:
chunk_size = self.n_train
self.chunk_size = chunk_size
def _set_batch_size(self, batch_size=None): # deprecated
"""
Warning
-------
Deprecated! Please use the function `_set_chunk_size` in future projects.
Set chunk size for each worker process. A chunk is a subset
of the training data points whose linear combination needs to
be evaluated in order to generate a prediction.
The chunk size determines how much of a process's workload will
be passed to Python's underlying low-level routines at once.
This parameter is highly hardware dependent.
Note
----
This parameter can be optimally determined using
`prepare_parallel`.
Parameters
----------
batch_size : int
Chunk size (maximum value is set if `None`).
"""
self._set_chunk_size(batch_size)
def _set_bulk_mp(self, bulk_mp=False):
"""
Toggles bulk prediction mode.
If bulk prediction is enabled, the prediction is parallelized across
input geometries, i.e. each worker generates the complete prediction for
one query. Otherwise (depending on the number of available CPU cores) the
input geometries are processed sequentially, but every one of them may be
processed by multiple workers at once (in chunks).
Note
----
This parameter can be optimally determined using
`prepare_parallel`.
Parameters
----------
bulk_mp : bool, optional
Enable or disable bulk prediction mode.
"""
bulk_mp = bool(bulk_mp)
if self.bulk_mp is not bulk_mp:
self.bulk_mp = bulk_mp
# Reset data ranges for processes stored in 'wkr_starts_stops'
self._set_num_workers(self.num_workers)
def set_opt_num_workers_and_batch_size_fast(self, n_bulk=1, n_reps=1): # deprecated
"""
Warning
-------
Deprecated! Please use the function `prepare_parallel` in future projects.
Parameters
----------
n_bulk : int, optional
Number of geometries that will be passed to the
`predict` function in each call (performance
will be optimized for that exact use case).
n_reps : int, optional
Number of repetitions (bigger value: more
accurate, but also slower).
Returns
-------
int
Force and energy prediction speed in geometries
per second.
"""
self.prepare_parallel(n_bulk, n_reps)
def prepare_parallel(
self, n_bulk=1, n_reps=1, return_is_from_cache=False
): # noqa: C901
"""
Find and set the optimal parallelization parameters for the
currently loaded model, running on a particular system. The result
also depends on the number of geometries `n_bulk` that will be
passed at once when calling the `predict` function.
This function runs a benchmark in which the prediction routine is
repeatedly called `n_reps`-times (default: 1) with varying parameter
configurations, while the runtime is measured for each one. The
optimal parameters are then cached for fast retrieval in future
calls of this function.
We recommend calling this function after initialization of this
class, as it will drastically increase the performance of the
`predict` function.
Note
----
Depending on the parameter `n_reps`, this routine may take
some seconds/minutes to complete. However, once a
statistically significant number of benchmark results has
been gathered for a particular configuration, it starts
returning almost instantly.
Parameters
----------
n_bulk : int, optional
Number of geometries that will be passed to the
`predict` function in each call (performance
will be optimized for that exact use case).
n_reps : int, optional
Number of repetitions (bigger value: more
accurate, but also slower).
return_is_from_cache : bool, optional
If enabled, this function returns a second value
indicating if the returned results were obtained
from cache.
Returns
-------
int
Force and energy prediction speed in geometries
per second.
boolean, optional
Return, whether this function obtained the results
from cache.
"""
# global globs
# glob = globs[self.glob_id]
# n_perms = glob['n_perms']
# No benchmarking necessary if prediction is running on GPUs.
if self.use_torch:
self.log.info(
'Skipping multi-CPU benchmark, since torch is enabled.'
) # TODO: clarity!
return
# Retrieve cached benchmark results, if available.
bmark_result = self._load_cached_bmark_result(n_bulk)
if bmark_result is not None:
num_workers, chunk_size, bulk_mp, gps = bmark_result
self._set_chunk_size(chunk_size)
self._set_num_workers(num_workers)
self._set_bulk_mp(bulk_mp)
if return_is_from_cache:
is_from_cache = True
return gps, is_from_cache
else:
return gps
warm_up_done = False
best_results = []
last_i = None
best_gps = 0
gps_min = 0.0
best_params = None
r_dummy = np.random.rand(n_bulk, self.n_atoms * 3)
def _dummy_predict():
self.predict(r_dummy)
bulk_mp_rng = [True, False] if n_bulk > 1 else [False]
for bulk_mp in bulk_mp_rng:
self._set_bulk_mp(bulk_mp)
if bulk_mp is False:
last_i = 0
num_workers_rng = list(range(0, self.max_processes))
if bulk_mp:
num_workers_rng.reverse() # benchmark converges faster this way
# num_workers_rng_sizes = [batch_size for batch_size in batch_size_rng if min_batch_size % batch_size == 0]
# for num_workers in range(min_num_workers,self.max_processes+1):
for num_workers in num_workers_rng:
if not bulk_mp and num_workers != 0 and self.n_train % num_workers != 0:
continue
self._set_num_workers(num_workers)
best_gps = 0
gps_rng = (np.inf, 0.0) # min and max per num_workers
min_chunk_size = (
min(self.n_train, n_bulk)
if bulk_mp or num_workers < 2
else int(np.ceil(self.n_train / num_workers))
)
chunk_size_rng = list(range(min_chunk_size, 0, -1))
chunk_size_rng_sizes = [
chunk_size
for chunk_size in chunk_size_rng
if min_chunk_size % chunk_size == 0
]
# print('batch_size_rng_sizes ' + str(bulk_mp))
# print(batch_size_rng_sizes)
i_done = 0
i_dir = 1
i = 0 if last_i is None else last_i
# i = 0
# print(batch_size_rng_sizes)
while i >= 0 and i < len(chunk_size_rng_sizes):
chunk_size = chunk_size_rng_sizes[i]
self._set_chunk_size(chunk_size)
i_done += 1
if not warm_up_done:
timeit.timeit(_dummy_predict, number=10)
warm_up_done = True
gps = n_bulk * n_reps / timeit.timeit(_dummy_predict, number=n_reps)
# print(
# '{:2d}@{:d} {:d} | {:7.2f} gps'.format(
# num_workers, chunk_size, bulk_mp, gps
# )
# )
gps_rng = (
min(gps_rng[0], gps),
max(gps_rng[1], gps),
) # min and max per num_workers
# gps_min_max = min(gps_min_max[0], gps), max(gps_min_max[1], gps)
# print(' best_gps ' + str(best_gps))
# NEW
# if gps > best_gps and gps > gps_min: # gps is still going up, everything is good
# best_gps = gps
# best_params = num_workers, batch_size, bulk_mp
# else:
# break
# if gps > best_gps: # gps is still going up, everything is good
# best_gps = gps
# best_params = num_workers, batch_size, bulk_mp
# else: # gps did not go up wrt. to previous step
# # can we switch the search direction?
# # did we already?
# # we checked two consecutive configurations
# # are bigger batch sizes possible?
# print(batch_size_rng_sizes)
# turn_search_dir = i_dir > 0 and i_done == 2 and batch_size != batch_size_rng_sizes[1]
# # only turn, if the current gps is not lower than the lowest overall
# if turn_search_dir and gps >= gps_min:
# i -= 2 * i_dir
# i_dir = -1
# print('><')
# continue
# else:
# print('>>break ' + str(i_done))
# break
# NEW
# gps still going up?
# AND: gps not lower than the lowest overall?
# if gps < best_gps and gps >= gps_min:
if gps < best_gps:
if (
i_dir > 0
and i_done == 2
and chunk_size
!= chunk_size_rng_sizes[
1
] # there is no point in turning if this is the second batch size in the range
): # do we turn?
i -= 2 * i_dir
i_dir = -1
# print('><')
continue
else:
if chunk_size == chunk_size_rng_sizes[1]:
i -= 1 * i_dir
# print('>>break ' + str(i_done))
break
else:
best_gps = gps
best_params = num_workers, chunk_size, bulk_mp
if (
not bulk_mp and n_bulk > 1
): # stop search early when multiple cpus are available and the 1 cpu case is tested
if (
gps < gps_min
): # if the batch size run is lower than the lowest overall, stop right here
# print('breaking here')
break
i += 1 * i_dir
last_i = i - 1 * i_dir
i_dir = 1
if len(best_results) > 0:
overall_best_gps = max(best_results, key=lambda x: x[1])[1]
if best_gps < overall_best_gps:
# print('breaking, because best of last test was worse than overall best so far')
break
# if best_gps < gps_min:
# print('breaking here3')
# break
gps_min = gps_rng[0] # FIX me: is this the overall min?
# print ('gps_min ' + str(gps_min))
# print ('best_gps')
# print (best_gps)
best_results.append(
(best_params, best_gps)
) # best results per num_workers
(num_workers, chunk_size, bulk_mp), gps = max(best_results, key=lambda x: x[1])
# Cache benchmark results.
self._save_cached_bmark_result(n_bulk, num_workers, chunk_size, bulk_mp, gps)
self._set_chunk_size(chunk_size)
self._set_num_workers(num_workers)
self._set_bulk_mp(bulk_mp)
if return_is_from_cache:
is_from_cache = False
return gps, is_from_cache
else:
return gps
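# The benchmark above only tries chunk sizes that evenly divide a starting
# size, from largest to smallest. The following is a minimal standalone
# sketch of that candidate list, not the library's API: the helper name
# `chunk_size_candidates` is hypothetical, and `math.ceil` stands in for
# `np.ceil`.

```python
import math


def chunk_size_candidates(n_train, n_bulk, num_workers, bulk_mp):
    """Return the chunk sizes the search would consider, largest first.

    Hypothetical helper that mirrors the range construction in
    `prepare_parallel`; it is not part of sGDML.
    """
    if bulk_mp or num_workers < 2:
        start = min(n_train, n_bulk)
    else:
        start = int(math.ceil(n_train / num_workers))
    # Keep only sizes that divide the starting size without remainder.
    return [c for c in range(start, 0, -1) if start % c == 0]


print(chunk_size_candidates(1000, 1, 4, False))
```

# Restricting the search to divisors keeps all chunks the same length, which
# presumably keeps the timing comparison between configurations fair.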
def _save_cached_bmark_result(self, n_bulk, num_workers, chunk_size, bulk_mp, gps):
pkg_dir = os.path.dirname(os.path.abspath(__file__))
bmark_file = '_bmark_cache.npz'
bmark_path = os.path.join(pkg_dir, bmark_file)
bkey = '{}-{}-{}-{}'.format(
self.n_atoms, self.n_train, n_bulk, self.max_processes
)
if os.path.exists(bmark_path):
with np.load(bmark_path, allow_pickle=True) as bmark:
bmark = dict(bmark)
bmark['runs'] = np.append(bmark['runs'], bkey)
bmark['num_workers'] = np.append(bmark['num_workers'], num_workers)
bmark['batch_size'] = np.append(bmark['batch_size'], chunk_size)
bmark['bulk_mp'] = np.append(bmark['bulk_mp'], bulk_mp)
bmark['gps'] = np.append(bmark['gps'], gps)
else:
bmark = {
'code_version': __version__,
'runs': [bkey],
'gps': [gps],
'num_workers': [num_workers],
'batch_size': [chunk_size],
'bulk_mp': [bulk_mp],
}
np.savez_compressed(bmark_path, **bmark)
def _load_cached_bmark_result(self, n_bulk):
pkg_dir = os.path.dirname(os.path.abspath(__file__))
bmark_file = '_bmark_cache.npz'
bmark_path = os.path.join(pkg_dir, bmark_file)
bkey = '{}-{}-{}-{}'.format(
self.n_atoms, self.n_train, n_bulk, self.max_processes
)
if not os.path.exists(bmark_path):
return None
with np.load(bmark_path, allow_pickle=True) as bmark:
# Keep collecting benchmark runs, until we have at least three.
run_idxs = np.where(bmark['runs'] == bkey)[0]
if len(run_idxs) >= 3:
config_keys = []
for run_idx in run_idxs:
config_keys.append(
'{}-{}-{}'.format(
bmark['num_workers'][run_idx],
bmark['batch_size'][run_idx],
bmark['bulk_mp'][run_idx],
)
)
values, uinverse = np.unique(config_keys, return_index=True)
best_mean = -1
best_gps = 0
for i, config_key in enumerate(zip(values, uinverse)):
mean_gps = np.mean(
bmark['gps'][
np.where(np.array(config_keys) == config_key[0])[0]
]
)
if best_gps == 0 or best_gps < mean_gps:
best_mean = i
best_gps = mean_gps
best_idx = run_idxs[uinverse[best_mean]]
num_workers = bmark['num_workers'][best_idx]
chunk_size = bmark['batch_size'][best_idx]
bulk_mp = bmark['bulk_mp'][best_idx]
return num_workers, chunk_size, bulk_mp, best_gps
return None
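# The cache lookup above waits for at least three runs with the same key and
# then picks the configuration with the highest mean throughput. The sketch
# below shows that selection rule on plain Python data. The function name
# `best_cached_config` and the tuple layout are assumptions made for
# illustration; the real implementation works on arrays stored in an `.npz`
# file.

```python
from collections import defaultdict


def best_cached_config(runs, min_runs=3):
    """Pick the configuration with the highest mean gps.

    `runs` is a list of ((num_workers, chunk_size, bulk_mp), gps) tuples that
    all belong to the same benchmark key (hypothetical layout).
    """
    if len(runs) < min_runs:
        return None  # not enough data yet, keep benchmarking
    gps_by_config = defaultdict(list)
    for config, gps in runs:
        gps_by_config[config].append(gps)
    means = {cfg: sum(v) / len(v) for cfg, v in gps_by_config.items()}
    best = max(means, key=means.get)
    return best + (means[best],)


runs = [((4, 50, True), 100.0), ((4, 50, True), 120.0), ((2, 25, False), 105.0)]
print(best_cached_config(runs))
```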
def get_GPU_batch(self):
"""
Get batch size used by the GPU implementation to process bulk
predictions (predictions for multiple input geometries at once).
This value is determined on-the-fly depending on the available GPU
memory.
"""
if self.use_torch:
model = self.torch_predict
if isinstance(model, torch.nn.DataParallel):
model = model.module
return model._batch_size()
def predict(self, R=None, return_E=True):
"""
Predict energy and forces for multiple geometries. This function
can run on the GPU, if the optional PyTorch dependency is
installed and `use_torch=True` was specified during
initialization of this class.
Optionally, the descriptors and descriptor Jacobians for the
same geometries can be provided, if already available from some
previous calculations.
Note
----
The order of the atoms in `R` is not arbitrary and must
be the same as used for training the model.
Parameters
----------
R : :obj:`numpy.ndarray`, optional
A 2D array of size M x 3N containing the
Cartesian coordinates of each atom of M
molecules. If this parameter is omitted, the training
error is returned. Note that the training geometries
need to be set right after initialization using
`set_R()` for this to work.
return_E : boolean, optional
If false (default: true), only the forces are returned.
Returns
-------
:obj:`numpy.ndarray`
Energies stored in a 1D array of size M (unless `return_E == False`)
:obj:`numpy.ndarray`
Forces stored in a 2D array of size M x 3N.
"""
# Add singleton dimension if input is (,3N).
if R is not None and R.ndim == 1:
R = R[None, :]
if self.use_torch: # multi-GPU (or CPU if no GPUs are available)
R_torch = torch.arange(self.n_train)
if R is None:
if self.R_d_desc is None:
self.log.critical(
'A reference to the training geometry descriptors needs to be set (using \'set_R_d_desc()\') for this function to work without arguments (using PyTorch).'
)
print()
os._exit(1)
else:
R_torch = (
torch.from_numpy(R.reshape(-1, self.n_atoms, 3))
.type(torch.float32)
.to(self.torch_device)
)
model = self.torch_predict
if R_torch.shape[0] < torch.cuda.device_count() and isinstance(
model, torch.nn.DataParallel
):
model = self.torch_predict.module
E_torch_F_torch = model.forward(R_torch, return_E=return_E)
if return_E:
E_torch, F_torch = E_torch_F_torch
E = E_torch.cpu().numpy()
else:
(F_torch,) = E_torch_F_torch
F = F_torch.cpu().numpy().reshape(-1, 3 * self.n_atoms)
else: # multi-CPU
# Use precomputed descriptors in training mode.
is_desc_in_cache = self.R_desc is not None and self.R_d_desc is not None
if R is None and not is_desc_in_cache:
self.log.critical(
'A reference to the training geometry descriptors and Jacobians needs to be set for this function to work without arguments.'
)
print()
os._exit(1)
assert is_desc_in_cache or R is not None
dim_i = 3 * self.n_atoms
n_pred = self.R_desc.shape[0] if R is None else R.shape[0]
E_F = np.empty((n_pred, dim_i + 1))
if (
self.bulk_mp and self.num_workers > 0
): # One whole prediction per worker (and multiple workers).
_predict_wo_r_or_desc = partial(
_predict_wkr,
lat_and_inv=self.lat_and_inv,
glob_id=self.glob_id,
wkr_start_stop=None,
chunk_size=self.chunk_size,
)
for i, e_f in enumerate(
self.pool.imap(
partial(_predict_wo_r_or_desc, None)
if is_desc_in_cache
else partial(_predict_wo_r_or_desc, r_desc_d_desc=None),
zip(self.R_desc, self.R_d_desc) if is_desc_in_cache else R,
)
):
E_F[i, :] = e_f
else: # Multiple workers per prediction (or just one worker).
for i in range(n_pred):
if is_desc_in_cache:
r_desc, r_d_desc = self.R_desc[i], self.R_d_desc[i]
else:
r_desc, r_d_desc = self.desc.from_R(R[i], self.lat_and_inv)
_predict_wo_wkr_starts_stops = partial(
_predict_wkr,
None,
(r_desc, r_d_desc),
self.lat_and_inv,
self.glob_id,
chunk_size=self.chunk_size,
)
if self.num_workers == 0:
E_F[i, :] = _predict_wo_wkr_starts_stops()
else:
E_F[i, :] = sum(
self.pool.imap_unordered(
_predict_wo_wkr_starts_stops, self.wkr_starts_stops
)
)
E_F *= self.std
F = E_F[:, 1:]
E = E_F[:, 0] + self.c
ret = (F,)
if return_E:
ret = (E,) + ret
return ret
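# On the CPU path, `predict` collects energy and forces in one row per
# geometry (energy in column 0, forces in the remaining columns), rescales
# the row by the label standard deviation and then shifts the energy by the
# integration constant. The sketch below repeats this for a single row using
# plain lists. The helper name `split_energy_forces` is hypothetical; `std`
# and `c` correspond to `self.std` and `self.c`.

```python
def split_energy_forces(e_f_row, std, c):
    """Split one combined [E, F_1, ..., F_3N] row into energy and forces.

    Hypothetical helper: both parts are scaled by `std`, and only the
    energy is shifted by `c`.
    """
    scaled = [x * std for x in e_f_row]
    energy = scaled[0] + c
    forces = scaled[1:]
    return energy, forces


print(split_energy_forces([1.0, 0.5, -0.5, 0.25], std=2.0, c=-10.0))
```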
================================================
FILE: sgdml/solvers/__init__.py
================================================
================================================
FILE: sgdml/solvers/analytic.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2020-2022 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import sys
import logging
import warnings
from functools import partial
import numpy as np
import scipy as sp
import timeit
from .. import DONE, NOT_DONE
class Analytic(object):
def __init__(self, gdml_train, desc, callback=None):
self.log = logging.getLogger(__name__)
self.gdml_train = gdml_train
self.desc = desc
self.callback = callback
# from memory_profiler import profile
# @profile
def solve(self, task, R_desc, R_d_desc, tril_perms_lin, y):
sig = task['sig']
lam = task['lam']
use_E_cstr = task['use_E_cstr']
n_train, dim_d = R_d_desc.shape[:2]
n_atoms = int((1 + np.sqrt(8 * dim_d + 1)) / 2)
dim_i = 3 * n_atoms
if self.callback is not None:
self.callback = partial(
self.callback,
disp_str='Assembling kernel matrix',
)
K = -self.gdml_train._assemble_kernel_mat(
R_desc,
R_d_desc,
tril_perms_lin,
sig,
self.desc,
use_E_cstr=use_E_cstr,
callback=self.callback,
) # Flip sign to make convex
start = timeit.default_timer()
with warnings.catch_warnings():
warnings.simplefilter('ignore')
if K.shape[0] == K.shape[1]:
K[np.diag_indices_from(K)] += lam # Regularize
if self.callback is not None:
self.callback = partial(
self.callback,
disp_str='Solving linear system (Cholesky factorization)',
)
self.callback(NOT_DONE)
try:
# Cholesky (do not overwrite K in case we need to retry)
L, lower = sp.linalg.cho_factor(
K, overwrite_a=False, check_finite=False
)
alphas = -sp.linalg.cho_solve(
(L, lower), y, overwrite_b=False, check_finite=False
)
except np.linalg.LinAlgError: # Try a solver that makes fewer assumptions
if self.callback is not None:
self.callback = partial(
self.callback,
disp_str='Solving linear system (LU factorization) ', # Keep whitespaces!
)
self.callback(NOT_DONE)
try:
# LU
alphas = -sp.linalg.solve(
K, y, overwrite_a=True, overwrite_b=True, check_finite=False
)
except MemoryError:
self.log.critical(
'Not enough memory to train this system using a closed form solver.'
)
print()
os._exit(1)
except MemoryError:
self.log.critical(
'Not enough memory to train this system using a closed form solver.'
)
print()
os._exit(1)
else:
if self.callback is not None:
self.callback = partial(
self.callback,
disp_str='Solving over-determined linear system (least squares approximation)',
)
self.callback(NOT_DONE)
# Least squares for non-square K
alphas = -np.linalg.lstsq(K, y, rcond=-1)[0]
stop = timeit.default_timer()
if self.callback is not None:
dur_s = stop - start
sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else ''
self.callback(
DONE,
disp_str='Training on {:,} points'.format(n_train),
sec_disp_str=sec_disp_str,
)
return alphas
@staticmethod
def est_memory_requirement(n_train, n_atoms):
est_bytes = 3 * (n_train * 3 * n_atoms) ** 2 * 8 # K + factor(s) of K
est_bytes += (n_train * 3 * n_atoms) * 8 # alpha
return est_bytes
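# Two formulas in this file can be checked by hand: the atom count recovered
# from the descriptor dimension (dim_d = n_atoms * (n_atoms - 1) / 2) and the
# memory estimate of the closed-form solver. The sketch below re-implements
# both without NumPy. The function names are hypothetical, and the example
# sizes (9 atoms, 1000 training points) are chosen only for illustration.

```python
import math


def n_atoms_from_dim_d(dim_d):
    """Invert dim_d = n_atoms * (n_atoms - 1) / 2 (same formula as in `solve`)."""
    return int((1 + math.sqrt(8 * dim_d + 1)) / 2)


def est_memory_bytes(n_train, n_atoms):
    """Same arithmetic as `Analytic.est_memory_requirement`."""
    n = n_train * 3 * n_atoms
    est_bytes = 3 * n**2 * 8  # K + factor(s) of K, 8 bytes per float64
    est_bytes += n * 8  # alpha
    return est_bytes


print(n_atoms_from_dim_d(36))
print(est_memory_bytes(1000, 9) / 1024**3, 'GiB')
```

# For 1000 training points of a 9-atom molecule the estimate already exceeds
# 16 GiB, since the dominant term grows quadratically with the training set
# size.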
================================================
FILE: sgdml/solvers/iterative.py
================================================
#!/usr/bin/python
# MIT License
#
# Copyright (c) 2020-2025 Stefan Chmiela
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import logging
from functools import partial
import inspect
import multiprocessing as mp
import numpy as np
import scipy as sp
import timeit
import collections
from .. import DONE, NOT_DONE
from ..utils import ui
from ..predict import GDMLPredict
try:
import torch
except ImportError:
_has_torch = False
else:
_has_torch = True
CG_STEPS_HIST_LEN = (
100 # number of past steps to consider when calculating solver effectiveness
)
EFF_RESTART_THRESH = 0 # if solver effectiveness is less than that percentage after 'CG_STEPS_HIST_LEN' steps, a solver restart is triggered (with a stronger preconditioner)
MAX_NUM_RESTARTS = 6
class CGRestartException(Exception):
pass
class Iterative(object):
def __init__(
self,
gdml_train,
desc,
max_memory,
max_processes,
use_torch,
callback=None,
):
self.log = logging.getLogger(__name__)
self.gdml_train = gdml_train
self.gdml_predict = None
self.desc = desc
self.callback = callback
self._max_memory = max_memory
self._max_processes = max_processes
self._use_torch = use_torch
def _init_precon_operator(
self, task, R_desc, R_d_desc, tril_perms_lin, inducing_pts_idxs, callback=None
):
lam = task['lam']
lam_inv = 1.0 / lam
sig = task['sig']
use_E_cstr = task['use_E_cstr']
L_inv_K_mn = self._nystroem_cholesky_factor(
R_desc,
R_d_desc,
tril_perms_lin,
sig,
lam,
use_E_cstr=use_E_cstr,
col_idxs=inducing_pts_idxs,
callback=callback,
)
L_inv_K_mn = np.ascontiguousarray(L_inv_K_mn)
lev_scores = np.einsum(
'i...,i...->...', L_inv_K_mn, L_inv_K_mn
) # compute leverage scores because it is basically free once we got the factor
m, n = L_inv_K_mn.shape
if self._use_torch and False: # TURNED OFF!
_torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
L_inv_K_mn_torch = torch.from_numpy(L_inv_K_mn).to(_torch_device)
global is_primed
is_primed = False
def _P_vec(v):
global is_primed
if not is_primed:
is_primed = True
return v
if self._use_torch and False: # TURNED OFF!
v_torch = torch.from_numpy(v).to(_torch_device)[:, None]
return (
L_inv_K_mn_torch.t().mm(L_inv_K_mn_torch.mm(v_torch)) - v_torch
).cpu().numpy() * lam_inv
else:
ret = L_inv_K_mn.T.dot(L_inv_K_mn.dot(v))
ret -= v
ret *= lam_inv
return ret
return sp.sparse.linalg.LinearOperator((n, n), matvec=_P_vec), lev_scores
def _init_kernel_operator(
self, task, R_desc, R_d_desc, tril_perms_lin, lam, n, callback=None
):
n_train = R_desc.shape[0]
# dummy alphas
v_F = np.zeros((n - n_train, 1)) if task['use_E_cstr'] else np.zeros((n, 1))
v_E = np.zeros((n_train, 1)) if task['use_E_cstr'] else None
# Note: The standard deviation is set to 1.0, because we are predicting normalized labels here.
model = self.gdml_train.create_model(
task, 'cg', R_desc, R_d_desc, tril_perms_lin, 1.0, v_F, alphas_E=v_E
)
self.gdml_predict = GDMLPredict(
model,
max_memory=self._max_memory,
max_processes=self._max_processes,
use_torch=self._use_torch,
)
self.gdml_predict.set_R_desc(R_desc) # only needed on CPU
self.gdml_predict.set_R_d_desc(R_d_desc)
if not self._use_torch:
if callback is not None:
callback = partial(callback, disp_str='Optimizing CPU parallelization')
callback(NOT_DONE)
self.gdml_predict.prepare_parallel(n_bulk=n_train)
if callback is not None:
callback(DONE)
global is_primed
is_primed = False
def _K_vec(v):
global is_primed
if not is_primed:
is_primed = True
return v
v_F, v_E = v, None
if task['use_E_cstr']:
v_F, v_E = v[:-n_train], v[-n_train:]
self.gdml_predict.set_alphas(v_F, alphas_E=v_E)
pred = self.gdml_predict.predict(return_E=task['use_E_cstr'])
if task['use_E_cstr']:
e_pred, f_pred = pred
pred = np.hstack((f_pred.ravel(), -e_pred))
else:
pred = pred[0].ravel()
pred -= lam * v
return pred
return sp.sparse.linalg.LinearOperator((n, n), matvec=_K_vec)
def _nystroem_cholesky_factor(
self,
R_desc,
R_d_desc,
tril_perms_lin,
sig,
lam,
use_E_cstr,
col_idxs,
callback_task_name='',
callback=None,
):
if callback_task_name != '':
callback_task_name = ' ({})'.format(callback_task_name)
if callback is not None:
callback = partial(
callback,
disp_str='Assembling kernel [m x k]{}'.format(callback_task_name),
)
dim_d = R_desc.shape[1]
n_atoms = int((1 + np.sqrt(8 * dim_d + 1)) / 2)
n = R_desc.shape[0] * n_atoms * 3 + (R_desc.shape[0] if use_E_cstr else 0)
m = len(
range(*col_idxs.indices(n)) if isinstance(col_idxs, slice) else col_idxs
)
K_nmm = self.gdml_train._assemble_kernel_mat(
R_desc,
R_d_desc,
tril_perms_lin,
sig,
self.desc,
use_E_cstr=use_E_cstr,
col_idxs=col_idxs,
alloc_extra_rows=m,
callback=callback,
)
# Store (psd) copy of K_mm in lower part of this oversized K_(n+m)m matrix.
K_nmm[-m:, :] = -K_nmm[col_idxs, :]
K_nm = K_nmm[:-m, :]
K_mm = K_nmm[-m:, :]
if callback is not None:
callback = partial(
callback,
disp_str='Cholesky fact. (1/2) [k x k]{}'.format(callback_task_name),
)
callback(NOT_DONE)
# Additional regularization is almost always necessary here (hence pre_reg=True).
K_mm, lower = self._cho_factor_stable(K_mm, pre_reg=True) # overwrites input!
L_mm = K_mm
# del K_mm
if callback is not None:
callback(DONE)
callback = partial(
callback,
disp_str='m tri. solves (1/2) [k x k]{}'.format(callback_task_name),
)
callback(0, n)
b_start, b_size = 0, int(n / 4) # update in percentage steps of 25
for b_stop in list(range(b_size, n, b_size)) + [n]:
K_nm[b_start:b_stop, :] = sp.linalg.solve_triangular(
L_mm,
K_nm[b_start:b_stop, :].T,
lower=lower,
trans='T',
overwrite_b=True,
check_finite=False,
).T
b_start = b_stop
if callback is not None:
callback(b_stop, n)
del L_mm
K_nmm[-m:, :] = K_nm.T.dot(K_nm)
K_nmm[-m:, :][np.diag_indices_from(K_nmm[-m:, :])] += lam
inner = K_nmm[-m:, :]
if callback is not None:
callback = partial(
callback,
disp_str='Cholesky fact. (2/2) [k x k]{}'.format(callback_task_name),
)
callback(NOT_DONE)
L_lower = self._cho_factor_stable(
inner, eps_mag_max=-14
) # Do not regularize more than 1e-14.
if L_lower is not None:
K_nmm[-m:, :], lower = L_lower
L = K_nmm[-m:, :]
del inner
else:
if callback is not None: # guard against a missing callback, as in the other branches
callback = partial(
callback,
disp_str='QR fact. (alt.) [k x k]{}'.format(callback_task_name),
)
callback(NOT_DONE)
K_nmm[-m:, :] = 0
K_nmm[-m:, :][np.diag_indices(m)] = np.sqrt(lam)
K_nmm[-m:, :] = np.linalg.qr(K_nmm, mode='r')
L = K_nmm[-m:, :]
lower = False
if callback is not None:
callback(DONE)
callback = partial(
callback,
disp_str='m tri. solves (2/2) [k x k]{}'.format(callback_task_name),
)
callback(0, n)
b_start, b_size = 0, int(n / 4) # update in percentage steps of 25
for b_stop in list(range(b_size, n, b_size)) + [n]:
K_nm[b_start:b_stop, :] = sp.linalg.solve_triangular(
L,
K_nm[b_start:b_stop, :].T,
lower=lower,
trans='T',
overwrite_b=True,
check_finite=False,
).T # Note: Overwrites K_nm to save memory
b_start = b_stop
if callback is not None:
callback(b_stop, n)
del L
return K_nm.T
def _lev_scores(
self,
R_desc,
R_d_desc,
tril_perms_lin,
sig,
lam,
use_E_cstr,
n_inducing_pts,
callback=None,
):
n_train, dim_d = R_d_desc.shape[:2]
dim_i = 3 * int((1 + np.sqrt(8 * dim_d + 1)) / 2)
# Convert from training points to actual columns.
# dim_m = (
# np.maximum(1, n_inducing_pts // 4) * dim_i
# ) # only use 1/4 of inducing points for leverage score estimate
dim_m = dim_i * min(n_inducing_pts, 10)
# Which columns to use for leverage score approximation?
lev_approx_idxs = np.sort(
np.random.choice(
n_train * dim_i + (n_train if use_E_cstr else 0), dim_m, replace=False
)
) # random subset of columns
# lev_approx_idxs = np.sort(np.random.choice(n_train*dim_i, dim_m, replace=False)) # random subset of columns
# lev_approx_idxs = np.s_[
# :dim_m
# ] # first 'dim_m' columns (faster kernel construction)
L_inv_K_mn = self._nystroem_cholesky_factor(
R_desc,
R_d_desc,
tril_perms_lin,
sig,
lam,
use_E_cstr=use_E_cstr,
col_idxs=lev_approx_idxs,
callback_task_name='lev. scores',
callback=callback,
)
lev_scores = np.einsum('i...,i...->...', L_inv_K_mn, L_inv_K_mn)
return lev_scores
def inducing_pts_from_lev_scores(self, lev_scores, N):
# Sample 'N' columns with probabilities proportional to the leverage scores.
inducing_pts_idxs = np.random.choice(
np.arange(lev_scores.size),
N,
replace=False,
p=lev_scores / lev_scores.sum(),
)
return np.sort(inducing_pts_idxs)
# Performs a Cholesky decomposition of a matrix, but regularizes the matrix (if needed) until it is positive definite.
def _cho_factor_stable(self, M, pre_reg=False, eps_mag_max=1):
"""
Performs a Cholesky decomposition of a matrix, but regularizes
as needed until it is positive definite.
Parameters
----------
M : :obj:`numpy.ndarray`
Matrix to factorize.
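pre_reg : boolean, optional
Regularize M right away (machine precision), before
trying to factorize it (default: False).
eps_mag_max : int, optional
Largest order of magnitude of the regularization terms
that are successively added to the diagonal, i.e. the
last term is `10**eps_mag_max` (default: 1).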
Returns
-------
:obj:`numpy.ndarray`
Matrix whose upper or lower triangle contains the Cholesky factor of M. Other parts of the matrix contain random data.
boolean
Flag indicating whether the factor is in the lower or upper triangle
"""
eps = np.finfo(float).eps
eps_mag = int(np.floor(np.log10(eps)))
if pre_reg:
M[np.diag_indices_from(M)] += eps
eps_mag += 1 # if additional regularization is necessary, start from the next order of magnitude
for reg in 10.0 ** np.arange(
eps_mag, eps_mag_max + 1
): # regularize more and more aggressively (strongest term: 10**eps_mag_max)
try:
L, lower = sp.linalg.cho_factor(
M, overwrite_a=False, check_finite=False
)
except np.linalg.LinAlgError as e:
if 'not positive definite' in str(e):
self.log.debug(
'Cholesky solver needs more aggressive regularization (adding {} to diagonal)'.format(
reg
)
)
M[np.diag_indices_from(M)] += reg
else:
raise e
else:
return L, lower
self.log.critical(
'Failed to factorize despite strong regularization (max: {})!\nYou could try a larger sigma.'.format(
10.0**eps_mag_max
)
)
print()
os._exit(1)
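# The regularization terms in `_cho_factor_stable` are successive powers of
# ten, starting at the order of magnitude of machine epsilon. The sketch below
# lists the exponents that would be tried. The helper name `reg_exponents` is
# hypothetical, and `sys.float_info.epsilon` stands in for
# `np.finfo(float).eps` (both are the float64 machine epsilon).

```python
import math
import sys


def reg_exponents(pre_reg=False, eps_mag_max=1):
    """Return the base-10 exponents of the regularization terms, in order."""
    eps = sys.float_info.epsilon
    eps_mag = int(math.floor(math.log10(eps)))  # -16 for float64
    if pre_reg:
        eps_mag += 1  # machine epsilon itself was already added up front
    return list(range(eps_mag, eps_mag_max + 1))


print(reg_exponents())
print(reg_exponents(pre_reg=True, eps_mag_max=-14))
```

# With the default `eps_mag_max=1`, the last term added to the diagonal is
# 10**1. The call with `eps_mag_max=-14` corresponds to the second
# factorization in `_nystroem_cholesky_factor`; setting `pre_reg=True` there
# is only for illustration.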
def solve(
self,
task,
R_desc,
R_d_desc,
tril_perms_lin,
y,
y_std,
tol=1e-4,
save_progr_callback=None,
):
global num_iters, start, resid, avg_tt, m # , P_t
n_train, n_atoms = task['R_train'].shape[:2]
dim_i = 3 * n_atoms
sig = task['sig']
lam = task['lam']
# these keys are only present if the task was created from an existing model
alphas0_F = task['alphas0_F'] if 'alphas0_F' in task else None
alphas0_E = task['alphas0_E'] if 'alphas0_E' in task else None
num_iters0 = task['solver_iters'] if 'solver_iters' in task else 0
# Number of inducing points to use for Nystrom approximation.
max_memory_bytes = self._max_memory * 1024**3
max_n_inducing_pts = Iterative.max_n_inducing_pts(
n_train, n_atoms, max_memory_bytes
)
n_inducing_pts = min(n_train, max_n_inducing_pts)
n_inducing_pts_init = (
len(task['inducing_pts_idxs']) // (3 * n_atoms)
if 'inducing_pts_idxs' in task
else None
)
if self.callback is not None:
self.callback = partial(
self.callback,
disp_str='Building preconditioner (k={} ind. point{})'.format(
n_inducing_pts, 's' if n_inducing_pts > 1 else ''
),
)
subtask_callback = (
partial(ui.sec_callback, main_callback=self.callback)
if self.callback is not None
else None
)
lev_scores = None
if n_inducing_pts_init is not None and n_inducing_pts_init == n_inducing_pts:
inducing_pts_idxs = task['inducing_pts_idxs'] # reuse old inducing points
else:
# Determine good inducing points.
lev_scores = self._lev_scores(
R_desc,
R_d_desc,
tril_perms_lin,
sig,
lam,
task['use_E_cstr'],
n_inducing_pts,
callback=subtask_callback,
)
dim_m = n_inducing_pts * dim_i
inducing_pts_idxs = self.inducing_pts_from_lev_scores(lev_scores, dim_m)
start = timeit.default_timer()
P_op, lev_scores = self._init_precon_operator(
task,
R_desc,
R_d_desc,
tril_perms_lin,
inducing_pts_idxs,
callback=subtask_callback,
)
stop = timeit.default_timer()
if self.callback is not None:
dur_s = stop - start
sec_disp_str = 'took {:.1f} s'.format(dur_s) if dur_s >= 0.1 else ''
self.callback(DONE, sec_disp_str=sec_disp_str)
self.callback = partial(
self.callback,
disp_str='Initializing solver',
)
subtask_callback = (
partial(ui.sec_callback, main_callback=self.callback)
if self.callback is not None
else None
)
n = P_op.shape[0]
K_op = self._init_kernel_operator(
task, R_desc, R_d_desc, tril_perms_lin, lam, n, callback=subtask_callback
)
num_iters = int(num_iters0)
if self.callback is not None:
num_devices = (
mp.cpu_count() if self._max_processes is None else self._max_processes
)
if self._use_torch:
SYMBOL INDEX (169 symbols across 17 files)
FILE: scripts/sgdml_dataset_from_aims.py
function read_reference_data (line 36) | def read_reference_data(f): # noqa C901
FILE: scripts/sgdml_dataset_from_extxyz.py
function read_nonstd_ext_xyz (line 46) | def read_nonstd_ext_xyz(f):
function extract_info_from_extxyz (line 95) | def extract_info_from_extxyz(file_path):
FILE: scripts/sgdml_dataset_from_ipi.py
function raw_input_float (line 36) | def raw_input_float(prompt):
function read_concat_xyz (line 45) | def read_concat_xyz(f):
function read_out_file (line 79) | def read_out_file(f, col):
FILE: setup.py
function get_property (line 7) | def get_property(property, package):
FILE: sgdml/__init__.py
class ColoredFormatter (line 45) | class ColoredFormatter(logging.Formatter):
method __init__ (line 65) | def __init__(self, msg, use_color=True):
method format (line 70) | def format(self, record):
class ColoredLogger (line 95) | class ColoredLogger(logging.Logger):
method __init__ (line 96) | def __init__(self, name):
method done (line 117) | def done(self, msg, *args, **kwargs):
FILE: sgdml/cli.py
class AssistantError (line 77) | class AssistantError(Exception):
function _print_splash (line 81) | def _print_splash(max_memory, max_processes, use_torch):
function _check_update (line 146) | def _check_update():
function _print_billboard (line 168) | def _print_billboard():
function _print_dataset_properties (line 223) | def _print_dataset_properties(dataset, title_str='Dataset properties'):
function _print_task_properties_reduced (line 325) | def _print_task_properties_reduced(
function _print_task_properties (line 349) | def _print_task_properties(task, title_str='Task properties'):
function _print_model_properties (line 427) | def _print_model_properties(model, title_str='Model properties'):
function _print_next_step (line 558) | def _print_next_step(
function all (line 612) | def all(
function create (line 745) | def create( # noqa: C901
function train (line 946) | def train(
function _batch (line 1164) | def _batch(iterable, n=1):
function _online_err (line 1170) | def _online_err(err, size, n, mae_n_sum, rmse_n_sum):
function resume (line 1183) | def resume(
function validate (line 1288) | def validate(
function test (line 1327) | def test(
function select (line 1797) | def select(model_dir, overwrite, model_file=None, command=None, **kwargs...
function show (line 1940) | def show(file, command=None, **kwargs):
function reset (line 1955) | def reset(command=None, **kwargs):
function main (line 1979) | def main():
FILE: sgdml/get.py
function download (line 44) | def download(command, file_name):
function main (line 71) | def main():
FILE: sgdml/intf/ase_calc.py
class SGDMLCalculator (line 37) | class SGDMLCalculator(Calculator):
method __init__ (line 41) | def __init__(
method calculate (line 93) | def calculate(self, atoms=None, *args, **kwargs):
FILE: sgdml/predict.py
function share_array (line 65) | def share_array(arr_np):
function _predict_wkr (line 84) | def _predict_wkr(
class GDMLPredict (line 248) | class GDMLPredict(object):
method __init__ (line 249) | def __init__(
method __del__ (line 465) | def __del__(self):
method set_R_desc (line 510) | def set_R_desc(self, R_desc):
method set_R_d_desc (line 526) | def set_R_d_desc(self, R_d_desc):
method set_alphas (line 551) | def set_alphas(self, alphas_F, alphas_E=None):
method _set_num_workers (line 603) | def _set_num_workers(
method _set_chunk_size (line 658) | def _set_chunk_size(self, chunk_size=None):
method _set_batch_size (line 689) | def _set_batch_size(self, batch_size=None): # deprecated
method _set_bulk_mp (line 717) | def _set_bulk_mp(self, bulk_mp=False):
method set_opt_num_workers_and_batch_size_fast (line 745) | def set_opt_num_workers_and_batch_size_fast(self, n_bulk=1, n_reps=1):...
method prepare_parallel (line 770) | def prepare_parallel(
method _save_cached_bmark_result (line 1044) | def _save_cached_bmark_result(self, n_bulk, num_workers, chunk_size, b...
method _load_cached_bmark_result (line 1076) | def _load_cached_bmark_result(self, n_bulk):
method get_GPU_batch (line 1129) | def get_GPU_batch(self):
method predict (line 1146) | def predict(self, R=None, return_E=True):
FILE: sgdml/solvers/analytic.py
class Analytic (line 37) | class Analytic(object):
method __init__ (line 38) | def __init__(self, gdml_train, desc, callback=None):
method solve (line 49) | def solve(self, task, R_desc, R_d_desc, tril_perms_lin, y):
method est_memory_requirement (line 154) | def est_memory_requirement(n_train, n_atoms):
FILE: sgdml/solvers/iterative.py
class CGRestartException (line 56) | class CGRestartException(Exception):
class Iterative (line 60) | class Iterative(object):
method __init__ (line 61) | def __init__(
method _init_precon_operator (line 83) | def _init_precon_operator(
method _init_kernel_operator (line 144) | def _init_kernel_operator(
method _nystroem_cholesky_factor (line 208) | def _nystroem_cholesky_factor(
method _lev_scores (line 353) | def _lev_scores(
method inducing_pts_from_lev_scores (line 401) | def inducing_pts_from_lev_scores(self, lev_scores, N):
method _cho_factor_stable (line 414) | def _cho_factor_stable(self, M, pre_reg=False, eps_mag_max=1):
method solve (line 473) | def solve(
method max_n_inducing_pts (line 828) | def max_n_inducing_pts(n_train, n_atoms, max_memory_bytes):
method est_memory_requirement (line 847) | def est_memory_requirement(n_train, n_inducing_pts, n_atoms):
FILE: sgdml/torchtools.py
function _next_batch_size (line 52) | def _next_batch_size(n_total, batch_size):
class GDMLTorchAssemble (line 61) | class GDMLTorchAssemble(nn.Module):
method __init__ (line 67) | def __init__(
method _forward (line 110) | def _forward(
method forward (line 343) | def forward(self, J_indx):
class GDMLTorchPredict (line 395) | class GDMLTorchPredict(nn.Module):
method __init__ (line 401) | def __init__(
method get_n_perm_batches (line 595) | def get_n_perm_batches(self):
method set_n_perm_batches (line 600) | def set_n_perm_batches(self, n_perm_batches):
method apply_perms_to_obj (line 618) | def apply_perms_to_obj(self, xs, perm_idxs=None):
method remove_perms_from_obj (line 631) | def remove_perms_from_obj(self, xs):
method uncache_perms (line 637) | def uncache_perms(self):
method cache_perms (line 651) | def cache_perms(self):
method est_mem_requirement (line 667) | def est_mem_requirement(self, return_min=False):
method set_R_d_desc (line 728) | def set_R_d_desc(self, R_d_desc):
method set_alphas (line 760) | def set_alphas(self, alphas, alphas_E=None):
method _forward (line 877) | def _forward(self, Rs_or_train_idxs, return_E=True):
method forward (line 1048) | def forward(self, Rs_or_train_idxs, return_E=True):
FILE: sgdml/train.py
function _share_array (line 75) | def _share_array(arr_np, typecode_or_type):
function _assemble_kernel_mat_wkr (line 97) | def _assemble_kernel_mat_wkr(
class GDMLTrain (line 305) | class GDMLTrain(object):
method __init__ (line 306) | def __init__(self, max_memory=None, max_processes=None, use_torch=False):
method __del__ (line 363) | def __del__(self):
method create_task (line 370) | def create_task(
method create_task_from_model (line 649) | def create_task_from_model(self, model, dataset):
method create_model (line 727) | def create_model(
method train (line 836) | def train( # noqa: C901
method _recov_int_const (line 1090) | def _recov_int_const(
method _assemble_kernel_mat (line 1260) | def _assemble_kernel_mat(
method draw_strat_sample (line 1537) | def draw_strat_sample(self, T, n, excl_idxs=None):
FILE: sgdml/utils/desc.py
function _pbc_diff (line 44) | def _pbc_diff(diffs, lat_and_inv, use_torch=False):
function _pdist (line 80) | def _pdist(r, lat_and_inv=None):
function _squareform (line 113) | def _squareform(vec_or_mat):
function _r_to_desc (line 139) | def _r_to_desc(r, pdist):
function _r_to_d_desc (line 166) | def _r_to_d_desc(r, pdist, lat_and_inv=None):
function _from_r (line 208) | def _from_r(r, lat_and_inv=None):
class Desc (line 242) | class Desc(object):
method __init__ (line 244) | def __init__(self, n_atoms, max_processes=None):
method from_R (line 288) | def from_R(self, R, lat_and_inv=None, max_processes=None, callback=None):
method d_desc_dot_vec (line 368) | def d_desc_dot_vec(self, R_d_desc, vecs, overwrite_vecs=False):
method vec_dot_d_desc (line 388) | def vec_dot_d_desc(self, R_d_desc, vecs, out=None):
method d_desc_from_comp (line 422) | def d_desc_from_comp(self, R_d_desc, out=None):
method d_desc_to_comp (line 473) | def d_desc_to_comp(self, R_d_desc):
method perm (line 510) | def perm(perm):
FILE: sgdml/utils/io.py
function z_str_to_z (line 154) | def z_str_to_z(z_str):
function z_to_z_str (line 158) | def z_to_z_str(z):
function train_dir_name (line 162) | def train_dir_name(dataset, n_train, use_sym, use_E, use_E_cstr):
function task_file_name (line 183) | def task_file_name(task):
function model_file_name (line 192) | def model_file_name(task_or_model, is_extended=False):
function dataset_md5 (line 208) | def dataset_md5(dataset):
function read_xyz (line 238) | def read_xyz(file_path):
function write_geometry (line 264) | def write_geometry(filename, r, z, comment_str=''):
function generate_xyz_str (line 278) | def generate_xyz_str(r, z, e=None, f=None, lattice=None):
function lattice_vec_to_par (line 303) | def lattice_vec_to_par(lat):
function is_file_type (line 327) | def is_file_type(arg, type):
function filter_file_type (line 414) | def filter_file_type(dir, type, md5_match=None):
function is_valid_file_type (line 464) | def is_valid_file_type(arg_in):
function is_dir_with_file_type (line 514) | def is_dir_with_file_type(arg, type, or_file=False):
function is_task_dir_resumeable (line 572) | def is_task_dir_resumeable(
function is_strict_pos_int (line 642) | def is_strict_pos_int(arg):
function parse_list_or_range (line 667) | def parse_list_or_range(arg):
FILE: sgdml/utils/perm.py
function share_array (line 48) | def share_array(arr_np, typecode):
function _bipartite_match_wkr (line 53) | def _bipartite_match_wkr(i, n_train, same_z_cost):
function bipartite_match (line 90) | def bipartite_match(R, z, lat_and_inv=None, max_processes=None, callback...
function sync_perm_mat (line 238) | def sync_perm_mat(match_perms_all, match_cost, n_atoms, callback=None):
function to_cycles (line 263) | def to_cycles(perm):
function salvage_subgroup (line 289) | def salvage_subgroup(perms):
function complete_sym_group (line 344) | def complete_sym_group(
function find_perms (line 384) | def find_perms(R, z, lat_and_inv=None, callback=None, max_processes=None):
function find_extra_perms (line 415) | def find_extra_perms(R, z, lat_and_inv=None, callback=None, max_processe...
function find_frags (line 527) | def find_frags(r, z, lat_and_inv=None):
function find_frag_perms (line 564) | def find_frag_perms(R, z, lat_and_inv=None, callback=None, max_processes...
function _frag_perm_to_perm (line 759) | def _frag_perm_to_perm(n_atoms, frag_idxs, frag_perms):
function find_perms_in_frag (line 774) | def find_perms_in_frag(R, z, frag_idxs, lat_and_inv=None, max_processes=...
function find_perms_via_alignment (line 790) | def find_perms_via_alignment(
function find_perms_via_reflection (line 917) | def find_perms_via_reflection(
function print_perm_colors (line 967) | def print_perm_colors(perm, pts, plane_3idxs=None):
function inv_perm (line 1035) | def inv_perm(perm):
FILE: sgdml/utils/ui.py
function yes_or_no (line 39) | def yes_or_no(question):
function callback (line 61) | def callback(
function sec_callback (line 150) | def sec_callback(
function color_str (line 189) | def color_str(str, fore_color=WHITE, back_color=BLACK, bold=False):
function blink_str (line 203) | def blink_str(str):
function unicode_str (line 208) | def unicode_str(s):
function gen_memory_str (line 218) | def gen_memory_str(bytes):
function gen_lattice_str (line 232) | def gen_lattice_str(lat):
function str_plen (line 244) | def str_plen(str):
function wrap_str (line 265) | def wrap_str(str, width=MAX_PRINT_WIDTH - LOG_LEVELNAME_WIDTH):
function indent_str (line 297) | def indent_str(str, indent):
function wrap_indent_str (line 318) | def wrap_indent_str(label, str, width=MAX_PRINT_WIDTH - LOG_LEVELNAME_WI...
function merge_col_str (line 346) | def merge_col_str(
function gen_mat_str (line 378) | def gen_mat_str(mat):
function gen_range_str (line 434) | def gen_range_str(min, max):
function print_step_title (line 457) | def print_step_title(title_str, sec_title_str='', underscore=True):
function print_two_column_str (line 474) | def print_two_column_str(str, sec_str=''):
function print_lattice (line 489) | def print_lattice(lat=None, inset=False):
Condensed preview: 28 files, each showing path, character count, and a content snippet. The full structured content is 404K chars.
[
{
"path": ".gitignore",
"chars": 198,
"preview": "\n.DS_Store\n\n# Compiled python modules.\n*.pyc\n\n# Setuptools distribution folder.\n/dist/\n\n# Python egg metadata, regenerat"
},
{
"path": "LICENSE.txt",
"chars": 1075,
"preview": "MIT License\n\nCopyright (c) 2018-2022 Stefan Chmiela\n\nPermission is hereby granted, free of charge, to any person obtaini"
},
{
"path": "README.md",
"chars": 3667,
"preview": "# Symmetric Gradient Domain Machine Learning (sGDML)\n\nFor more details visit: [sgdml.org](http://sgdml.org/) \nDocumenta"
},
{
"path": "pyproject.toml",
"chars": 91,
"preview": "[tool.black]\nskip-string-normalization = true\nskip-numeric-underscore-normalization = true\n"
},
{
"path": "scripts/sgdml_dataset_from_aims.py",
"chars": 5518,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "scripts/sgdml_dataset_from_extxyz.py",
"chars": 8022,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "scripts/sgdml_dataset_from_ipi.py",
"chars": 5828,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge,"
},
{
"path": "scripts/sgdml_dataset_to_extxyz.py",
"chars": 2950,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2019 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "scripts/sgdml_dataset_via_ase.py",
"chars": 6055,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "scripts/sgdml_datasets_from_model.py",
"chars": 3388,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge,"
},
{
"path": "setup.cfg",
"chars": 277,
"preview": "[flake8]\nmax-complexity = 12\nignore = E501,W503,E741\nselect = C,E,F,W\n\n[isort]\nmulti_line_output = 3\ninclude_trailing_co"
},
{
"path": "setup.py",
"chars": 2004,
"preview": "import os\nimport re\nfrom io import open\nfrom setuptools import setup, find_packages\n\n\ndef get_property(property, package"
},
{
"path": "sgdml/__init__.py",
"chars": 3591,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2019-2025 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "sgdml/cli.py",
"chars": 74967,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "sgdml/get.py",
"chars": 5223,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2023 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "sgdml/intf/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "sgdml/intf/ase_calc.py",
"chars": 4188,
"preview": "# MIT License\n#\n# Copyright (c) 2018-2020 Stefan Chmiela\n#\n# Permission is hereby granted, free of charge, to any person"
},
{
"path": "sgdml/predict.py",
"chars": 46891,
"preview": "\"\"\"\nThis module contains all routines for evaluating GDML and sGDML models.\n\"\"\"\n\n# MIT License\n#\n# Copyright (c) 2018-20"
},
{
"path": "sgdml/solvers/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "sgdml/solvers/analytic.py",
"chars": 5423,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2020-2022 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "sgdml/solvers/iterative.py",
"chars": 27847,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2020-2025 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "sgdml/torchtools.py",
"chars": 40652,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2019-2023 Stefan Chmiela, Jan Hermann\n#\n# Permission is hereby grante"
},
{
"path": "sgdml/train.py",
"chars": 60606,
"preview": "\"\"\"\nThis module contains all routines for training GDML and sGDML models.\n\"\"\"\n\n# MIT License\n#\n# Copyright (c) 2018-2022"
},
{
"path": "sgdml/utils/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "sgdml/utils/desc.py",
"chars": 17340,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2022 Stefan Chmiela, Luis Galvez\n#\n# Permission is hereby grante"
},
{
"path": "sgdml/utils/io.py",
"chars": 19139,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2021 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "sgdml/utils/perm.py",
"chars": 31377,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2021 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
},
{
"path": "sgdml/utils/ui.py",
"chars": 13361,
"preview": "#!/usr/bin/python\n\n# MIT License\n#\n# Copyright (c) 2018-2021 Stefan Chmiela\n#\n# Permission is hereby granted, free of ch"
}
]
About this extraction
This page contains the full source code of the stefanch/sGDML GitHub repository, extracted and formatted as plain text. The extraction includes 28 files (380.5 KB), approximately 92.7k tokens, and a symbol index with 169 extracted functions, classes, methods, constants, and types.
Extracted by GitExtract — free GitHub repo to text converter for AI. Built by Nikandr Surkov.