From 18adfdc295065dd4aae82f6bec3f04b73b27f3ee Mon Sep 17 00:00:00 2001
From: Peter Eastman
Date: Wed, 5 Oct 2022 16:15:55 -0700
Subject: [PATCH 01/12] Began implementing ZBL potential

---
 torchmdnet/data.py         |  5 ++++-
 torchmdnet/datasets/hdf.py | 33 +++++++++++++++++-----------
 torchmdnet/models/model.py |  8 +++++--
 torchmdnet/priors.py       | 45 ++++++++++++++++++++++++++++++++++----
 4 files changed, 71 insertions(+), 20 deletions(-)

diff --git a/torchmdnet/data.py b/torchmdnet/data.py
index cc4dade65..4e497475c 100644
--- a/torchmdnet/data.py
+++ b/torchmdnet/data.py
@@ -28,8 +28,11 @@ def setup(self, stage):
                 self.hparams["force_files"],
             )
         else:
+            args = self.hparams["dataset_arg"]
+            if args is None:
+                args = {}
             self.dataset = getattr(datasets, self.hparams["dataset"])(
-                self.hparams["dataset_root"], **self.hparams["dataset_arg"]
+                self.hparams["dataset_root"], **args
             )

         self.idx_train, self.idx_val, self.idx_test = make_splits(
diff --git a/torchmdnet/datasets/hdf.py b/torchmdnet/datasets/hdf.py
index 684e68002..87f5974d2 100644
--- a/torchmdnet/datasets/hdf.py
+++ b/torchmdnet/datasets/hdf.py
@@ -1,6 +1,7 @@
 import torch
 from torch_geometric.data import Dataset, Data
 import h5py
+import numpy as np


 class HDF5(Dataset):
@@ -27,7 +28,12 @@ def __init__(self, filename, **kwargs):
         files = [h5py.File(f, "r") for f in self.filename.split(";")]
         for file in files:
             for group_name in file:
-                self.num_molecules += len(file[group_name]["energy"])
+                if group_name == '_metadata':
+                    group = file[group_name]
+                    for name in group:
+                        setattr(self, name, torch.tensor(np.array(group[name])))
+                else:
+                    self.num_molecules += len(file[group_name]["energy"])
             file.close()

     def setup_index(self):
@@ -36,18 +42,19 @@ def setup_index(self):
         self.index = []
         for file in files:
             for group_name in file:
-                group = file[group_name]
-                types = group["types"]
-                pos = group["pos"]
-                energy = group["energy"]
-                if "forces" in group:
-                    self.has_forces = True
-                    forces = group["forces"]
-                    for i in range(len(energy)):
-                        self.index.append((types, pos, energy, forces, i))
-                else:
-                    for i in range(len(energy)):
-                        self.index.append((types, pos, energy, i))
+                if group_name != '_metadata':
+                    group = file[group_name]
+                    types = group["types"]
+                    pos = group["pos"]
+                    energy = group["energy"]
+                    if "forces" in group:
+                        self.has_forces = True
+                        forces = group["forces"]
+                        for i in range(len(energy)):
+                            self.index.append((types, pos, energy, forces, i))
+                    else:
+                        for i in range(len(energy)):
+                            self.index.append((types, pos, energy, i))

         assert self.num_molecules == len(self.index), (
             "Mismatch between previously calculated "
diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index 903e9774f..e3d6a3e54 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -178,13 +178,17 @@ def forward(
         if self.std is not None:
             x = x * self.std

-        # apply prior model
-        if self.prior_model is not None:
+        # apply atomwise prior model
+        if self.prior_model is not None and self.prior_model.atomwise:
             x = self.prior_model(x, z, pos, batch)

         # aggregate atoms
         out = scatter(x, batch, dim=0, reduce=self.reduce_op)

+        # apply non-atomwise prior model
+        if self.prior_model is not None and not self.prior_model.atomwise:
+            out += self.prior_model(x, z, pos, batch)
+
         # shift by data mean
         if self.mean is not None:
             out = out + self.mean
diff --git a/torchmdnet/priors.py b/torchmdnet/priors.py
index 85c27cd39..7169d5c23 100644
--- a/torchmdnet/priors.py
+++ b/torchmdnet/priors.py
@@ -2,9 +2,9 @@
 import torch
 from torch import nn
 from pytorch_lightning.utilities import rank_zero_warn
+from torchmdnet.models.utils import Distance

-
-__all__ = ["Atomref"]
+__all__ = ["Atomref", "ZBL"]


 class BasePrior(nn.Module, metaclass=ABCMeta):
@@ -13,8 +13,9 @@ class BasePrior(nn.Module, metaclass=ABCMeta):
     As an example, have a look at the `torchmdnet.priors.Atomref` prior.
     """

-    def __init__(self, dataset=None):
+    def __init__(self, dataset=None, atomwise=True):
         super(BasePrior, self).__init__()
+        self.atomwise = atomwise

     @abstractmethod
     def get_init_args(self):
@@ -36,7 +37,9 @@ def forward(self, x, z, pos, batch):
             batch (torch.Tensor): tensor containing the sample index for each atom.

         Returns:
-            torch.Tensor: updated scalar atomwise predictions
+            torch.Tensor: If this is an atom-wise prior (self.atomwise is True), the return value
+                contains updated scalar atomwise predictions. Otherwise, it is a single scalar that
+                is added to the result after summing over atoms.
         """
         return
@@ -76,3 +79,37 @@ def get_init_args(self):

     def forward(self, x, z, pos, batch):
         return x + self.atomref(z)
+
+
+class ZBL(BasePrior):
+    """This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion.
+    It is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147).  It
+    is an empirical potential that does a good job of describing the repulsion between atoms at very short
+    distances.
+
+    To use this prior, the Dataset must provide the following attributes.
+
+    atomic_number: 1D tensor of length max_z.  atomic_number[z] is the atomic number of atoms with atom type z.
+    distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters
+    energy_scale: multiply by this factor to convert energies stored in the dataset to Joules
+    """
+    def __init__(self, dataset=None):
+        super(ZBL, self).__init__(atomwise=False)
+        self.register_buffer("atomic_number", dataset.atomic_number)
+        self.distance = Distance(0, 10.0, max_num_neighbors=100)
+        self.distance_scale = dataset.distance_scale*1.88973e10 # convert to Bohr units
+        self.energy_scale = dataset.energy_scale # convert to Joules
+
+    def get_init_args(self):
+        return {}
+
+    def reset_parameters(self):
+        pass
+
+    def forward(self, x, z, pos, batch):
+        edge_index, distance, _ = self.distance(pos*self.distance_scale, batch)
+        atomic_number = self.atomic_number[z[edge_index]]
+        a = 0.8854/(atomic_number[0]**0.23 + atomic_number[0]**0.23)
+        d = distance/a
+        f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d)
+        return (2.30707755e-19/self.energy_scale)*torch.sum(f/distance, dim=-1)

From 78bf317989edf20387617e473daaf39e0b4e296d Mon Sep 17 00:00:00 2001
From: Peter Eastman
Date: Wed, 5 Oct 2022 16:36:24 -0700
Subject: [PATCH 02/12] Apply smooth cutoff to ZBL

---
 torchmdnet/priors.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchmdnet/priors.py b/torchmdnet/priors.py
index 7169d5c23..818b6ec47 100644
--- a/torchmdnet/priors.py
+++ b/torchmdnet/priors.py
@@ -2,7 +2,7 @@
 import torch
 from torch import nn
 from pytorch_lightning.utilities import rank_zero_warn
-from torchmdnet.models.utils import Distance
+from torchmdnet.models.utils import Distance, CosineCutoff

 __all__ = ["Atomref", "ZBL"]

@@ -95,8 +95,10 @@ class ZBL(BasePrior):
     """
     def __init__(self, dataset=None):
         super(ZBL, self).__init__(atomwise=False)
+        cutoff_distance = 10.0
         self.register_buffer("atomic_number", dataset.atomic_number)
-        self.distance = Distance(0, 10.0, max_num_neighbors=100)
+        self.distance = Distance(0, cutoff_distance, max_num_neighbors=100)
+        self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance)
         self.distance_scale = dataset.distance_scale*1.88973e10 # convert to Bohr units
         self.energy_scale = dataset.energy_scale # convert to Joules

@@ -112,4 +114,5 @@ def forward(self, x, z, pos, batch):
         a = 0.8854/(atomic_number[0]**0.23 + atomic_number[0]**0.23)
         d = distance/a
         f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d)
+        f *= self.cutoff(distance)
         return (2.30707755e-19/self.energy_scale)*torch.sum(f/distance, dim=-1)

From c63a1a39c3f65705cdd6e6bd1592abfc32ca4757 Mon Sep 17 00:00:00 2001
From: Peter Eastman
Date: Thu, 6 Oct 2022 12:44:28 -0700
Subject: [PATCH 03/12] Allow cutoff to be specified in config file

---
 tests/test_module.py        |  2 +-
 torchmdnet/models/model.py  |  6 +++---
 torchmdnet/priors.py        | 28 ++++++++++++++++++++--------
 torchmdnet/scripts/train.py |  8 ++++++--
 torchmdnet/utils.py         |  2 +-
 5 files changed, 31 insertions(+), 15 deletions(-)

diff --git a/tests/test_module.py b/tests/test_module.py
index 17002d7c8..c920c010e 100644
--- a/tests/test_module.py
+++ b/tests/test_module.py
@@ -43,7 +43,7 @@ def test_train(model_name, use_atomref, tmpdir):
     prior = None
     if use_atomref:
         prior = getattr(priors, args["prior_model"])(dataset=datamodule.dataset)
-        args["prior_args"] = prior.get_init_args()
+        args["prior_init_args"] = prior.get_init_args()

     module = LNNP(args, prior_model=prior)

diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index e3d6a3e54..0ac17b1ed 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -65,16 +65,16 @@ def create_model(args, prior_model=None, mean=None, std=None):

     # prior model
     if args["prior_model"] and prior_model is None:
-        assert "prior_args" in args, (
+        assert "prior_init_args" in args, (
             f"Requested prior model {args['prior_model']} but the "
-            f'arguments are lacking the key "prior_args".'
+            f'arguments are lacking the key "prior_init_args".'
         )
         assert hasattr(priors, args["prior_model"]), (
            f'Unknown prior model {args["prior_model"]}. '
            f'Available models are {", ".join(priors.__all__)}'
        )
        # instantiate prior model if it was not passed to create_model (i.e. when loading a model)
-        prior_model = getattr(priors, args["prior_model"])(**args["prior_args"])
+        prior_model = getattr(priors, args["prior_model"])(**args["prior_init_args"])

     # create output network
     output_prefix = "Equivariant" if is_equivariant else ""
diff --git a/torchmdnet/priors.py b/torchmdnet/priors.py
index 818b6ec47..662feb5d3 100644
--- a/torchmdnet/priors.py
+++ b/torchmdnet/priors.py
@@ -93,23 +93,35 @@ class ZBL(BasePrior):
     distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters
     energy_scale: multiply by this factor to convert energies stored in the dataset to Joules
     """
-    def __init__(self, dataset=None):
+    def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, distance_scale=None, energy_scale=None, dataset=None):
         super(ZBL, self).__init__(atomwise=False)
-        cutoff_distance = 10.0
-        self.register_buffer("atomic_number", dataset.atomic_number)
-        self.distance = Distance(0, cutoff_distance, max_num_neighbors=100)
+        if atomic_number is None:
+            atomic_number = dataset.atomic_number
+        if distance_scale is None:
+            distance_scale = dataset.distance_scale
+        if energy_scale is None:
+            energy_scale = dataset.energy_scale
+        atomic_number = torch.as_tensor(atomic_number, dtype=torch.int8)
+        self.register_buffer("atomic_number", atomic_number)
+        self.distance = Distance(0, cutoff_distance, max_num_neighbors=max_num_neighbors)
         self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance)
-        self.distance_scale = dataset.distance_scale*1.88973e10 # convert to Bohr units
-        self.energy_scale = dataset.energy_scale # convert to Joules
+        self.cutoff_distance = cutoff_distance
+        self.max_num_neighbors = max_num_neighbors
+        self.distance_scale = distance_scale
+        self.energy_scale = energy_scale

     def get_init_args(self):
-        return {}
+        return {'cutoff_distance': self.cutoff_distance,
+                'max_num_neighbors': self.max_num_neighbors,
+                'atomic_number': self.atomic_number,
+                'distance_scale': self.distance_scale,
+                'energy_scale': self.energy_scale}

     def reset_parameters(self):
         pass

     def forward(self, x, z, pos, batch):
-        edge_index, distance, _ = self.distance(pos*self.distance_scale, batch)
+        edge_index, distance, _ = self.distance(pos*self.distance_scale*1.88973e10, batch)
         atomic_number = self.atomic_number[z[edge_index]]
         a = 0.8854/(atomic_number[0]**0.23 + atomic_number[0]**0.23)
         d = distance/a
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index f6f4e9325..1345269cb 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -63,6 +63,7 @@ def get_args():
     parser.add_argument('--model', type=str, default='graph-network', choices=models.__all__, help='Which model to train')
     parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model')
     parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use')
+    parser.add_argument('--prior-args', default=None, type=str, help='Additional arguments for the prior model. Needs to be specified in JSON format, e.g. \'{"cutoff_distance": 10.0, "max_num_neighbors": 100}\'')

     # architectural args
     parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge')
@@ -125,8 +126,11 @@ def main():
             f"Available models are {', '.join(priors.__all__)}"
         )
         # initialize the prior model
-        prior = getattr(priors, args.prior_model)(dataset=data.dataset)
-        args.prior_args = prior.get_init_args()
+        prior_args = args.prior_args
+        if prior_args is None:
+            prior_args = {}
+        prior = getattr(priors, args.prior_model)(dataset=data.dataset, **prior_args)
+        args.prior_init_args = prior.get_init_args()

     # initialize lightning module
     model = LNNP(args, prior_model=prior, mean=data.mean, std=data.std)
diff --git a/torchmdnet/utils.py b/torchmdnet/utils.py
index c3fc67afe..82321273b 100644
--- a/torchmdnet/utils.py
+++ b/torchmdnet/utils.py
@@ -176,7 +176,7 @@ def __call__(self, parser, namespace, values, option_string=None):
         with open(hparams_path, "r") as f:
             config = yaml.load(f, Loader=yaml.FullLoader)
         for key in config.keys():
-            if key not in namespace and key != "prior_args":
+            if key not in namespace and key != "prior_init_args":
                 raise ValueError(f"Unknown argument in the model checkpoint: {key}")
         namespace.__dict__.update(config)
         namespace.__dict__.update(load_model=values)

From b17d504fbdc411fd76b17e9071d1e16c6cbd0335 Mon Sep 17 00:00:00 2001
From: Peter Eastman
Date: Mon, 10 Oct 2022 12:48:09 -0700
Subject: [PATCH 04/12] Workaround for pytorch bug

---
 torchmdnet/models/model.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index 0ac17b1ed..580e7b314 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -179,15 +179,17 @@ def forward(
         x = x * self.std

         # apply atomwise prior model
-        if self.prior_model is not None and self.prior_model.atomwise:
-            x = self.prior_model(x, z, pos, batch)
+        if self.prior_model is not None:
+            if self.prior_model.atomwise:
+                x = self.prior_model(x, z, pos, batch)

         # aggregate atoms
         out = scatter(x, batch, dim=0, reduce=self.reduce_op)

         # apply non-atomwise prior model
-        if self.prior_model is not None and not self.prior_model.atomwise:
-            out += self.prior_model(x, z, pos, batch)
+        if self.prior_model is not None:
+            if not self.prior_model.atomwise:
+                out += self.prior_model(x, z, pos, batch)

         # shift by data mean
         if self.mean is not None:

From d6709d8607f2515b4f48c3920f224b62f588ae1f Mon Sep 17 00:00:00 2001
From: Peter Eastman
Date: Thu, 27 Oct 2022 16:50:09 -0700
Subject: [PATCH 05/12] Bug fixes to ZBL potential

---
 torchmdnet/priors.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/torchmdnet/priors.py b/torchmdnet/priors.py
index 662feb5d3..16484513a 100644
--- a/torchmdnet/priors.py
+++ b/torchmdnet/priors.py
@@ -121,10 +121,13 @@ def reset_parameters(self):
         pass

     def forward(self, x, z, pos, batch):
-        edge_index, distance, _ = self.distance(pos*self.distance_scale*1.88973e10, batch)
+        edge_index, distance, _ = self.distance(pos, batch)
         atomic_number = self.atomic_number[z[edge_index]]
-        a = 0.8854/(atomic_number[0]**0.23 + atomic_number[0]**0.23)
-        d = distance/a
+        # 5.29e-11 is the Bohr radius in meters.  All other numbers are magic constants from the ZBL potential.
+        a = 0.8854*5.29177210903e-11/(atomic_number[0]**0.23 + atomic_number[1]**0.23)
+        d = distance*self.distance_scale/a
         f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d)
         f *= self.cutoff(distance)
-        return (2.30707755e-19/self.energy_scale)*torch.sum(f/distance, dim=-1)
+        # Compute the energy, converting to the dataset's units.  Multiply by 0.5 because every atom pair
+        # appears twice.
+        return 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*torch.sum(f*atomic_number[0]*atomic_number[1]/distance, dim=-1)

From d23e6500f2cef1fa56d6c99ce5fdb983f1379bca Mon Sep 17 00:00:00 2001
From: Peter Eastman
Date: Fri, 28 Oct 2022 17:29:22 -0700
Subject: [PATCH 06/12] Adapted to new API for priors

---
 examples/ET-ANI1.yaml               |   6 +-
 examples/ET-MD17.yaml               |   6 +-
 examples/ET-QM9.yaml                |   6 +-
 examples/ET-SPICE.yaml              |   6 +-
 tests/test_dataset_comp6.py         |   2 +-
 tests/test_datasets.py              |   2 +-
 tests/utils.py                      |   2 +-
 torchmdnet/data.py                  |  10 +-
 torchmdnet/datasets/ace.py          | 189 ++++++++++++++++++++--------
 torchmdnet/datasets/ani.py          |  94 ++++++++------
 torchmdnet/datasets/comp6.py        |  47 +++----
 torchmdnet/datasets/custom.py       |   2 +-
 torchmdnet/datasets/hdf.py          |   2 +-
 torchmdnet/datasets/md17.py         |   6 +-
 torchmdnet/datasets/qm9q.py         |  45 +++----
 torchmdnet/datasets/spice.py        |  71 +++++++----
 torchmdnet/models/model.py          |  35 +++---
 torchmdnet/models/output_modules.py |  49 ++++++--
 torchmdnet/module.py                |  64 +++++-----
 torchmdnet/priors.py                | 133 --------------------
 torchmdnet/priors/__init__.py       |   2 +
 torchmdnet/priors/atomref.py        |  41 ++++++
 torchmdnet/priors/base.py           |  47 +++++++
 torchmdnet/priors/zbl.py            |  54 ++++++++
 torchmdnet/scripts/train.py         |   6 +-
 25 files changed, 543 insertions(+), 384 deletions(-)
 delete mode 100644 torchmdnet/priors.py
 create mode 100644 torchmdnet/priors/__init__.py
 create mode 100644 torchmdnet/priors/atomref.py
 create mode 100644 torchmdnet/priors/base.py
 create mode 100644 torchmdnet/priors/zbl.py

diff --git a/examples/ET-ANI1.yaml b/examples/ET-ANI1.yaml
index dc377eed1..0270c48f5 100644
--- a/examples/ET-ANI1.yaml
+++ b/examples/ET-ANI1.yaml
@@ -12,14 +12,14 @@ dataset_root: ~/data
 derivative: false
 distance_influence: both
 early_stopping_patience: 500
-ema_alpha_dy: 1.0
+ema_alpha_neg_dy: 1.0
 ema_alpha_y: 1.0
 embed_files: null
 embedding_dimension: 128
 energy_files: null
-energy_weight: 1.0
+y_weight: 1.0
 force_files: null
-force_weight: 1.0
+neg_dy_weight: 1.0
 inference_batch_size: 2048
 load_model: null
 log_dir: logs/
diff --git a/examples/ET-MD17.yaml b/examples/ET-MD17.yaml
index 9b5e4708b..5c38b9eb2 100644
--- a/examples/ET-MD17.yaml
+++ b/examples/ET-MD17.yaml
@@ -13,14 +13,14 @@ dataset_root: ~/data
 derivative: true
 distance_influence: both
 early_stopping_patience: 300
-ema_alpha_dy: 1.0
+ema_alpha_neg_dy: 1.0
 ema_alpha_y: 0.05
 embed_files: null
 embedding_dimension: 128
 energy_files: null
-energy_weight: 0.2
+y_weight: 0.2
 force_files: null
-force_weight: 0.8
+neg_dy_weight: 0.8
 inference_batch_size: 64
 load_model: null
 log_dir: logs/
diff --git a/examples/ET-QM9.yaml b/examples/ET-QM9.yaml
index b7be13866..ef9048578 100644
--- a/examples/ET-QM9.yaml
+++ b/examples/ET-QM9.yaml
@@ -13,14 +13,14 @@ dataset_root: ~/data
 derivative: false
 distance_influence: both
 early_stopping_patience: 150
-ema_alpha_dy: 1.0
+ema_alpha_neg_dy: 1.0
 ema_alpha_y: 1.0
 embed_files: null
 embedding_dimension: 256
 energy_files: null
-energy_weight: 1.0
+y_weight: 1.0
 force_files: null
-force_weight: 1.0
+neg_dy_weight: 1.0
 inference_batch_size: 128
 load_model: null
 log_dir: logs/
diff --git a/examples/ET-SPICE.yaml b/examples/ET-SPICE.yaml
index 9ac87d30a..f2e5b189f 100644
--- a/examples/ET-SPICE.yaml
+++ b/examples/ET-SPICE.yaml
@@ -13,14 +13,14 @@ dataset_root: data
 derivative: true
 distance_influence: both
 early_stopping_patience: 50
-ema_alpha_dy: 1.0
+ema_alpha_neg_dy: 1.0
 ema_alpha_y: 1.0
 embed_files: null
 embedding_dimension: 128
 energy_files: null
-energy_weight: 0.5
+y_weight: 0.5
 force_files: null
-force_weight: 0.5
+neg_dy_weight: 0.5
 inference_batch_size: 16
 load_model: null
 log_dir: logs/
diff --git a/tests/test_dataset_comp6.py b/tests/test_dataset_comp6.py
index 3f2dda8e7..bf6906402 100644
--- a/tests/test_dataset_comp6.py
+++ b/tests/test_dataset_comp6.py
@@ -37,7 +37,7 @@ def test_dataset_s66x8():
     )
     assert pt.allclose(sample.y, pt.tensor([[-47.5919]]))
     assert pt.allclose(
-        sample.dy,
+        sample.neg_dy,
         pt.tensor(
             [
                 [0.2739, -0.2190, -0.0012],
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 7d1cd9af2..b8c7ce960 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -51,7 +51,7 @@ def test_custom(energy, forces, num_files, tmpdir, num_samples=100):
     if energy:
         assert hasattr(sample, "y"), "Sample doesn't contain energy"
     if forces:
-        assert hasattr(sample, "dy"), "Sample doesn't contain forces"
+        assert hasattr(sample, "neg_dy"), "Sample doesn't contain forces"


 def test_hdf5_multiprocessing(tmpdir, num_entries=100):
diff --git a/tests/utils.py b/tests/utils.py
index c4165eeda..d8cd8322b 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -75,7 +75,7 @@ def get(self, idx):
         if self.energies is not None:
             features["y"] = self.energies[idx].clone()
         if self.forces is not None:
-            features["dy"] = self.forces[idx].clone()
+            features["neg_dy"] = self.forces[idx].clone()
         return Data(**features)

     def len(self):
diff --git a/torchmdnet/data.py b/torchmdnet/data.py
index 4e497475c..516158faf 100644
--- a/torchmdnet/data.py
+++ b/torchmdnet/data.py
@@ -28,11 +28,11 @@ def setup(self, stage):
                 self.hparams["force_files"],
             )
         else:
-            args = self.hparams["dataset_arg"]
-            if args is None:
-                args = {}
+            dataset_arg = {}
+            if self.hparams["dataset_arg"] is not None:
+                dataset_arg = self.hparams["dataset_arg"]
             self.dataset = getattr(datasets, self.hparams["dataset"])(
-                self.hparams["dataset_root"], **args
+                self.hparams["dataset_root"], **dataset_arg
             )

         self.idx_train, self.idx_val, self.idx_test = make_splits(
@@ -114,7 +114,7 @@ def _get_dataloader(self, dataset, stage, store_dataloader=True):

     def _standardize(self):
         def get_energy(batch, atomref):
-            if batch.y is None:
+            if "y" not in batch or batch.y is None:
                 raise MissingEnergyException()

             if atomref is None:
diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py
index 077be4c97..9ab3ecbaf 100644
--- a/torchmdnet/datasets/ace.py
+++ b/torchmdnet/datasets/ace.py
@@ -8,7 +8,6 @@


 class Ace(Dataset):
-
     def __init__(
         self,
         root=None,
@@ -17,13 +16,16 @@ def __init__(
         pre_filter=None,
         paths=None,
         max_gradient=None,
+        subsample_molecules=1,
     ):
+
+        assert isinstance(paths, (str, list))

-        arg_hash = f"{paths}{max_gradient}"
+        arg_hash = f"{paths}{max_gradient}{subsample_molecules}"
         arg_hash = hashlib.md5(arg_hash.encode()).hexdigest()
         self.name = f"{self.__class__.__name__}-{arg_hash}"
         self.paths = paths
         self.max_gradient = max_gradient
+        self.subsample_molecules = int(subsample_molecules)

         super().__init__(root, transform, pre_transform, pre_filter)
         (
@@ -31,7 +33,7 @@ def __init__(
             z_name,
             pos_name,
             y_name,
-            dy_name,
+            neg_dy_name,
             q_name,
             pq_name,
             dp_name,
@@ -42,8 +44,8 @@ def __init__(
             pos_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
         )
         self.y_mm = np.memmap(y_name, mode="r", dtype=np.float64)
-        self.dy_mm = np.memmap(
-            dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
+        self.neg_dy_mm = np.memmap(
+            neg_dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3)
         )
         self.q_mm = np.memmap(q_name, mode="r", dtype=np.int8)
         self.pq_mm = np.memmap(pq_name, mode="r", dtype=np.float32)
@@ -58,59 +60,140 @@ def __init__(

     @property
     def raw_paths(self):
-        if os.path.isfile(self.paths):
-            return [self.paths]
-        if os.path.isdir(self.paths):
-            return [
-                os.path.join(self.paths, file_)
-                for file_ in os.listdir(self.paths)
-                if file_.endswith(".h5")
-            ]
+        paths_init = self.paths if isinstance(self.paths, list) else [self.paths]
+        paths = []
+        for path in paths_init:

-        raise RuntimeError(f"Cannot load {self.paths}")
+            if os.path.isfile(path):
+                paths.append(path)
+                continue

-    def sample_iter(self):
+            if os.path.isdir(path):
+                for file_ in os.listdir(path):
+                    if file_.endswith(".h5"):
+                        paths.append(os.path.join(path, file_))
+                continue

-        for path in tqdm(self.raw_paths, desc="Files"):
-            molecules = list(h5py.File(path).values())
+            raise RuntimeError(f"{path} is neither a directory nor a file")

-            for mol in tqdm(molecules, desc="Molecules", leave=False):
-                z = pt.tensor(mol["atomic_numbers"], dtype=pt.long)
-                fq = pt.tensor(mol["formal_charges"], dtype=pt.long)
-                q = fq.sum()
+        return paths

-                for conf in mol["conformations"].values():
+    @staticmethod
+    def _load_confs_1_0(mol, n_atoms):

-                    # Skip failed calculations
-                    if "formation_energy" not in conf:
-                        continue
+        for conf in mol["conformations"].values():
+
+            # Skip failed calculations
+            if "formation_energy" not in conf:
+                continue
+
+            assert conf["positions"].attrs["units"] == "Å"
+            pos = pt.tensor(conf["positions"][...], dtype=pt.float32)
+            assert pos.shape == (n_atoms, 3)
+
+            assert conf["formation_energy"].attrs["units"] == "eV"
+            y = pt.tensor(conf["formation_energy"][()], dtype=pt.float64)
+            assert y.shape == ()
+
+            assert conf["forces"].attrs["units"] == "eV/Å"
+            neg_dy = pt.tensor(conf["forces"][...], dtype=pt.float32)
+            assert neg_dy.shape == pos.shape
+
+            assert conf["partial_charges"].attrs["units"] == "e"
+            pq = pt.tensor(conf["partial_charges"][:], dtype=pt.float32)
+            assert pq.shape == (n_atoms,)
+
+            assert conf["dipole_moment"].attrs["units"] == "e*Å"
+            dp = pt.tensor(conf["dipole_moment"][:], dtype=pt.float32)
+            assert dp.shape == (3,)
+
+            yield pos, y, neg_dy, pq, dp
+
+    @staticmethod
+    def _load_confs_2_0(mol, n_atoms):

-                    assert conf["positions"].attrs["units"] == "Å"
-                    pos = pt.tensor(conf["positions"], dtype=pt.float32)
-                    assert pos.shape == (z.shape[0], 3)
+        assert mol["positions"].attrs["units"] == "Å"
+        all_pos = pt.tensor(mol["positions"][...], dtype=pt.float32)
+        n_confs = all_pos.shape[0]
+        assert all_pos.shape == (n_confs, n_atoms, 3)

-                    assert conf["formation_energy"].attrs["units"] == "eV"
-                    y = pt.tensor(conf["formation_energy"][()], dtype=pt.float64)
-                    assert y.shape == ()
+        assert mol["formation_energies"].attrs["units"] == "eV"
+        all_y = pt.tensor(mol["formation_energies"][:], dtype=pt.float64)
+        assert all_y.shape == (n_confs,)

-                    assert conf["forces"].attrs["units"] == "eV/Å"
-                    dy = -pt.tensor(conf["forces"], dtype=pt.float32)
-                    assert dy.shape == pos.shape
+        assert mol["forces"].attrs["units"] == "eV/Å"
+        all_neg_dy = pt.tensor(mol["forces"][...], dtype=pt.float32)
+        assert all_neg_dy.shape == all_pos.shape

-                    assert conf["partial_charges"].attrs["units"] == "e"
-                    pq = pt.tensor(conf["partial_charges"], dtype=pt.float32)
-                    assert pq.shape == z.shape
+        assert mol["partial_charges"].attrs["units"] == "e"
+        all_pq = pt.tensor(mol["partial_charges"][...], dtype=pt.float32)
+        assert all_pq.shape == (n_confs, n_atoms)

-                    assert conf["dipole_moment"].attrs["units"] == "e*Å"
-                    dp = pt.tensor(conf["dipole_moment"], dtype=pt.float32)
-                    assert dp.shape == (3,)
+        assert mol["dipole_moments"].attrs["units"] == "e*Å"
+        all_dp = pt.tensor(mol["dipole_moments"][...], dtype=pt.float32)
+        assert all_dp.shape == (n_confs, 3)
+
+        for pos, y, neg_dy, pq, dp in zip(all_pos, all_y, all_neg_dy, all_pq, all_dp):
+
+            # Skip failed calculations
+            if y.isnan():
+                continue
+
+            yield pos, y, neg_dy, pq, dp
+
+    def sample_iter(self, mol_ids=False):
+
+        assert self.subsample_molecules > 0
+
+        for path in tqdm(self.raw_paths, desc="Files"):
+
+            h5 = h5py.File(path)
+            assert h5.attrs["layout"] == "Ace"
+            version = h5.attrs["layout_version"]
+
+            mols = None
+            load_confs = None
+            if version == "1.0":
+                assert "name" in h5.attrs
+                mols = h5.items()
+                load_confs = self._load_confs_1_0
+            elif version == "2.0":
+                assert len(h5.keys()) == 1
+                mols = list(h5.values())[0].items()
+                load_confs = self._load_confs_2_0
+            else:
+                raise RuntimeError(f"Unsupported layout version: {version}")
+
+            # Iterate over the molecules
+            for i_mol, (mol_id, mol) in tqdm(
+                enumerate(mols),
+                desc="Molecules",
+                total=len(mols),
+                leave=False,
+            ):
+
+                # Subsample molecules
+                if i_mol % self.subsample_molecules != 0:
+                    continue
+
+                z = pt.tensor(mol["atomic_numbers"], dtype=pt.long)
+                fq = pt.tensor(mol["formal_charges"], dtype=pt.long)
+                q = fq.sum()
+
+                for pos, y, neg_dy, pq, dp in load_confs(mol, n_atoms=len(z)):

                     # Skip samples with large forces
                     if self.max_gradient:
-                        if dy.norm(dim=1).max() > float(self.max_gradient):
+                        if neg_dy.norm(dim=1).max() > float(self.max_gradient):
                             continue

-                    data = Data(z=z, pos=pos, y=y.view(1, 1), dy=dy, q=q, pq=pq, dp=dp)
+                    # Create a sample
+                    args = dict(
+                        z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp
+                    )
+                    if mol_ids:
+                        args["mol_id"] = mol_id
+                    data = Data(**args)

                     if self.pre_filter is not None and not self.pre_filter(data):
                         continue
@@ -127,7 +210,7 @@ def processed_file_names(self):
             f"{self.name}.z.mmap",
             f"{self.name}.pos.mmap",
             f"{self.name}.y.mmap",
-            f"{self.name}.dy.mmap",
+            f"{self.name}.neg_dy.mmap",
             f"{self.name}.q.mmap",
             f"{self.name}.pq.mmap",
             f"{self.name}.dp.mmap",
@@ -135,6 +218,10 @@ def processed_file_names(self):

     def process(self):

+        print("Arguments")
+        print(f"  max_gradient: {self.max_gradient} eV/A")
+        print(f"  subsample_molecules: {self.subsample_molecules}\n")
+
         print("Gathering statistics...")
         num_all_confs = 0
         num_all_atoms = 0
@@ -150,7 +237,7 @@ def process(self):
             z_name,
             pos_name,
             y_name,
-            dy_name,
+            neg_dy_name,
             q_name,
             pq_name,
             dp_name,
@@ -165,8 +252,8 @@ def process(self):
         y_mm = np.memmap(
             y_name + ".tmp", mode="w+", dtype=np.float64, shape=num_all_confs
         )
-        dy_mm = np.memmap(
-            dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
+        neg_dy_mm = np.memmap(
+            neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3)
         )
         q_mm = np.memmap(q_name + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs)
         pq_mm = np.memmap(
@@ -185,7 +272,7 @@ def process(self):
             z_mm[i_atom:i_next_atom] = data.z.to(pt.int8)
             pos_mm[i_atom:i_next_atom] = data.pos
             y_mm[i_conf] = data.y
-            dy_mm[i_atom:i_next_atom] = data.dy
+            neg_dy_mm[i_atom:i_next_atom] = data.neg_dy
q_mm[i_conf] = data.q.to(pt.int8) pq_mm[i_atom:i_next_atom] = data.pq dp_mm[i_conf] = data.dp @@ -199,7 +286,7 @@ def process(self): z_mm.flush() pos_mm.flush() y_mm.flush() - dy_mm.flush() + neg_dy_mm.flush() q_mm.flush() pq_mm.flush() dp_mm.flush() @@ -208,7 +295,7 @@ def process(self): os.rename(z_mm.filename, z_name) os.rename(pos_mm.filename, pos_name) os.rename(y_mm.filename, y_name) - os.rename(dy_mm.filename, dy_name) + os.rename(neg_dy_mm.filename, neg_dy_name) os.rename(q_mm.filename, q_name) os.rename(pq_mm.filename, pq_name) os.rename(dp_mm.filename, dp_name) @@ -224,9 +311,9 @@ def get(self, idx): y = pt.tensor(self.y_mm[idx], dtype=pt.float32).view( 1, 1 ) # It would be better to use float64, but the trainer complaints - dy = pt.tensor(self.dy_mm[atoms], dtype=pt.float32) + neg_dy = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32) q = pt.tensor(self.q_mm[idx], dtype=pt.long) pq = pt.tensor(self.pq_mm[atoms], dtype=pt.float32) dp = pt.tensor(self.dp_mm[idx], dtype=pt.float32) - return Data(z=z, pos=pos, y=y, dy=dy, q=q, pq=pq, dp=dp) + return Data(z=z, pos=pos, y=y, neg_dy=neg_dy, q=q, pq=pq, dp=dp) diff --git a/torchmdnet/datasets/ani.py b/torchmdnet/datasets/ani.py index e00202625..febfa7168 100644 --- a/torchmdnet/datasets/ani.py +++ b/torchmdnet/datasets/ani.py @@ -24,7 +24,7 @@ def compute_reference_energy(self, atomic_numbers): energy = sum(self.ELEMENT_ENERGIES[z] for z in atomic_numbers) return energy * ANIBase.HARTREE_TO_EV - def sample_iter(self): + def sample_iter(self, mol_ids=False): raise NotImplementedError() def get_atomref(self, max_z=100): @@ -40,18 +40,18 @@ def __init__( self.name = self.__class__.__name__ super().__init__(root, transform, pre_transform, pre_filter) - idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths + idx_name, z_name, pos_name, y_name, neg_dy_name = self.processed_paths self.idx_mm = np.memmap(idx_name, mode="r", dtype=np.int64) self.z_mm = np.memmap(z_name, mode="r", dtype=np.int8) self.pos_mm = np.memmap( pos_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) ) self.y_mm = np.memmap(y_name, mode="r", dtype=np.float64) - self.dy_mm = ( + self.neg_dy_mm = ( np.memmap( - dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) + neg_dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) ) - if os.path.getsize(dy_name) > 0 + if os.path.getsize(neg_dy_name) > 0 else None ) @@ -66,7 +66,7 @@ def processed_file_names(self): f"{self.name}.z.mmap", f"{self.name}.pos.mmap", f"{self.name}.y.mmap", - f"{self.name}.dy.mmap", + f"{self.name}.neg_dy.mmap", ] def filter_and_pre_transform(self, data): @@ -80,20 +80,19 @@ def filter_and_pre_transform(self, data): return data def process(self): - print("Gathering statistics...") num_all_confs = 0 num_all_atoms = 0 for data in self.sample_iter(): num_all_confs += 1 num_all_atoms += data.z.shape[0] - has_dy = "dy" in data + has_neg_dy = "neg_dy" in data print(f" Total number of conformers: {num_all_confs}") print(f" Total number of atoms: {num_all_atoms}") - print(f" Forces available: {has_dy}") + print(f" Forces available: {has_neg_dy}") - idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths + idx_name, z_name, pos_name, y_name, neg_dy_name = self.processed_paths idx_mm = np.memmap( idx_name + ".tmp", mode="w+", dtype=np.int64, shape=(num_all_confs + 1,) ) @@ -106,12 +105,12 @@ def process(self): y_mm = np.memmap( y_name + ".tmp", mode="w+", dtype=np.float64, shape=(num_all_confs,) ) - dy_mm = ( + neg_dy_mm = ( np.memmap( - dy_name + 
".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3) + neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3) ) - if has_dy - else open(dy_name, "w") + if has_neg_dy + else open(neg_dy_name, "w") ) print("Storing data...") @@ -123,8 +122,8 @@ def process(self): z_mm[i_atom:i_next_atom] = data.z.to(pt.int8) pos_mm[i_atom:i_next_atom] = data.pos y_mm[i_conf] = data.y - if has_dy: - dy_mm[i_atom:i_next_atom] = data.dy + if has_neg_dy: + neg_dy_mm[i_atom:i_next_atom] = data.neg_dy i_atom = i_next_atom @@ -135,15 +134,15 @@ def process(self): z_mm.flush() pos_mm.flush() y_mm.flush() - if has_dy: - dy_mm.flush() + if has_neg_dy: + neg_dy_mm.flush() os.rename(idx_mm.filename, idx_name) os.rename(z_mm.filename, z_name) os.rename(pos_mm.filename, pos_name) os.rename(y_mm.filename, y_name) - if has_dy: - os.rename(dy_mm.filename, dy_name) + if has_neg_dy: + os.rename(neg_dy_mm.filename, neg_dy_name) def len(self): return len(self.y_mm) @@ -158,11 +157,11 @@ def get(self, idx): ) # It would be better to use float64, but the trainer complaints y -= self.compute_reference_energy(z) - if self.dy_mm is None: + if self.neg_dy_mm is None: return Data(z=z, pos=pos, y=y) else: - dy = pt.tensor(self.dy_mm[atoms], dtype=pt.float32) - return Data(z=z, pos=pos, y=y, dy=dy) + neg_dy = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32) + return Data(z=z, pos=pos, y=y, neg_dy=neg_dy) class ANI1(ANIBase): @@ -188,14 +187,14 @@ def download(self): extract_tar(archive, self.raw_dir) os.remove(archive) - def sample_iter(self): + def sample_iter(self, mol_ids=False): atomic_numbers = {b"H": 1, b"C": 6, b"N": 7, b"O": 8} for path in tqdm(self.raw_paths, desc="Files"): - molecules = list(h5py.File(path).values())[0].values() + molecules = list(h5py.File(path).values())[0].items() - for mol in tqdm(molecules, desc="Molecules", leave=False): + for mol_id, mol in tqdm(molecules, desc="Molecules", leave=False): z = pt.tensor( [atomic_numbers[atom] for atom in mol["species"]], dtype=pt.long ) @@ -209,7 +208,13 @@ def sample_iter(self): assert all_pos.shape[2] == 3 for pos, y in zip(all_pos, all_y): - data = Data(z=z, pos=pos, y=y.view(1, 1)) + + # Create a sample + args = dict(z=z, pos=pos, y=y.view(1, 1)) + if mol_ids: + args["mol_id"] = mol_id + data = Data(**args) + if data := self.filter_and_pre_transform(data): yield data @@ -264,19 +269,19 @@ class ANI1X(ANI1XBase): 8: -75.0362229210, } - def sample_iter(self): + def sample_iter(self, mol_ids=False): assert len(self.raw_paths) == 1 with h5py.File(self.raw_paths[0]) as h5: - for mol in tqdm(h5.values(), desc="Molecules"): + for mol_id, mol in tqdm(h5.items(), desc="Molecules"): z = pt.tensor(mol["atomic_numbers"][:], dtype=pt.long) all_pos = pt.tensor(mol["coordinates"][:], dtype=pt.float32) all_y = pt.tensor( mol["wb97x_dz.energy"][:] * self.HARTREE_TO_EV, dtype=pt.float64 ) - all_dy = pt.tensor( + all_neg_dy = pt.tensor( mol["wb97x_dz.forces"][:] * self.HARTREE_TO_EV, dtype=pt.float32 ) @@ -284,16 +289,21 @@ def sample_iter(self): assert all_pos.shape[1] == z.shape[0] assert all_pos.shape[2] == 3 - assert all_dy.shape[0] == all_y.shape[0] - assert all_dy.shape[1] == z.shape[0] - assert all_dy.shape[2] == 3 + assert all_neg_dy.shape[0] == all_y.shape[0] + assert all_neg_dy.shape[1] == z.shape[0] + assert all_neg_dy.shape[2] == 3 - for pos, y, dy in zip(all_pos, all_y, all_dy): + for pos, y, neg_dy in zip(all_pos, all_y, all_neg_dy): - if y.isnan() or dy.isnan().any(): + if y.isnan() or neg_dy.isnan().any(): continue - data = Data(z=z, 
pos=pos, y=y.view(1, 1), dy=dy) + # Create a sample + args = dict(z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy) + if mol_ids: + args["mol_id"] = mol_id + data = Data(**args) + if data := self.filter_and_pre_transform(data): yield data @@ -309,12 +319,13 @@ def process(self): class ANI1CCX(ANI1XBase): - def sample_iter(self): + + def sample_iter(self, mol_ids=False): assert len(self.raw_paths) == 1 with h5py.File(self.raw_paths[0]) as h5: - for mol in tqdm(h5.values(), desc="Molecules"): + for mol_id, mol in tqdm(h5.items(), desc="Molecules"): z = pt.tensor(mol["atomic_numbers"][:], dtype=pt.long) all_pos = pt.tensor(mol["coordinates"][:], dtype=pt.float32) @@ -331,7 +342,12 @@ def sample_iter(self): if y.isnan(): continue - data = Data(z=z, pos=pos, y=y.view(1, 1)) + # Create a sample + args = dict(z=z, pos=pos, y=y.view(1, 1)) + if mol_ids: + args["mol_id"] = mol_id + data = Data(**args) + if data := self.filter_and_pre_transform(data): yield data diff --git a/torchmdnet/datasets/comp6.py b/torchmdnet/datasets/comp6.py index 3f0affc25..3d29665fa 100644 --- a/torchmdnet/datasets/comp6.py +++ b/torchmdnet/datasets/comp6.py @@ -37,15 +37,15 @@ def __init__( self.name = self.__class__.__name__ super().__init__(root, transform, pre_transform, pre_filter) - idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths + idx_name, z_name, pos_name, y_name, neg_dy_name = self.processed_paths self.idx_mm = np.memmap(idx_name, mode="r", dtype=np.int64) self.z_mm = np.memmap(z_name, mode="r", dtype=np.int8) self.pos_mm = np.memmap( pos_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) ) self.y_mm = np.memmap(y_name, mode="r", dtype=np.float64) - self.dy_mm = np.memmap( - dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) + self.neg_dy_mm = np.memmap( + neg_dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) ) assert self.idx_mm[0] == 0 @@ -80,15 +80,15 @@ def processed_file_names(self): f"{self.name}.z.mmap", f"{self.name}.pos.mmap", f"{self.name}.y.mmap", - f"{self.name}.dy.mmap", + f"{self.name}.neg_dy.mmap", ] - def sample_iter(self): + def sample_iter(self, mol_ids=False): for path in tqdm(self.raw_paths, desc="Files"): - molecules = list(h5py.File(path).values())[0].values() + molecules = list(h5py.File(path).values())[0].items() - for mol in tqdm(molecules, desc="Molecules", leave=False): + for mol_id, mol in tqdm(molecules, desc="Molecules", leave=False): z = pt.tensor( [self.ATOMIC_NUMBERS[atom] for atom in mol["species"]], dtype=pt.long, @@ -97,7 +97,7 @@ def sample_iter(self): all_y = pt.tensor( mol["energies"][:] * self.HARTREE_TO_EV, dtype=pt.float64 ) - all_dy = pt.tensor( + all_neg_dy = pt.tensor( mol["forces"][:] * self.HARTREE_TO_EV, dtype=pt.float32 ) all_y -= self.compute_reference_energy(z) @@ -106,12 +106,17 @@ def sample_iter(self): assert all_pos.shape[1] == z.shape[0] assert all_pos.shape[2] == 3 - assert all_dy.shape[0] == all_y.shape[0] - assert all_dy.shape[1] == z.shape[0] - assert all_dy.shape[2] == 3 + assert all_neg_dy.shape[0] == all_y.shape[0] + assert all_neg_dy.shape[1] == z.shape[0] + assert all_neg_dy.shape[2] == 3 - for pos, y, dy in zip(all_pos, all_y, all_dy): - data = Data(z=z, pos=pos, y=y.view(1, 1), dy=dy) + for pos, y, neg_dy in zip(all_pos, all_y, all_neg_dy): + + # Create a sample + args = dict(z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy) + if mol_ids: + args["mol_id"] = mol_id + data = Data(**args) if self.pre_filter is not None and not self.pre_filter(data): continue @@ -133,7 +138,7 @@ def 
process(self): print(f" Total number of conformers: {num_all_confs}") print(f" Total number of atoms: {num_all_atoms}") - idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths + idx_name, z_name, pos_name, y_name, neg_dy_name = self.processed_paths idx_mm = np.memmap( idx_name + ".tmp", mode="w+", dtype=np.int64, shape=(num_all_confs + 1,) ) @@ -146,8 +151,8 @@ def process(self): y_mm = np.memmap( y_name + ".tmp", mode="w+", dtype=np.float64, shape=(num_all_confs,) ) - dy_mm = np.memmap( - dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3) + neg_dy_mm = np.memmap( + neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3) ) print("Storing data...") @@ -159,7 +164,7 @@ def process(self): z_mm[i_atom:i_next_atom] = data.z.to(pt.int8) pos_mm[i_atom:i_next_atom] = data.pos y_mm[i_conf] = data.y - dy_mm[i_atom:i_next_atom] = data.dy + neg_dy_mm[i_atom:i_next_atom] = data.neg_dy i_atom = i_next_atom @@ -170,13 +175,13 @@ def process(self): z_mm.flush() pos_mm.flush() y_mm.flush() - dy_mm.flush() + neg_dy_mm.flush() os.rename(idx_mm.filename, idx_name) os.rename(z_mm.filename, z_name) os.rename(pos_mm.filename, pos_name) os.rename(y_mm.filename, y_name) - os.rename(dy_mm.filename, dy_name) + os.rename(neg_dy_mm.filename, neg_dy_name) def len(self): return len(self.y_mm) @@ -189,9 +194,9 @@ def get(self, idx): y = pt.tensor(self.y_mm[idx], dtype=pt.float32).view( 1, 1 ) # It would be better to use float64, but the trainer complaints - dy = pt.tensor(self.dy_mm[atoms], dtype=pt.float32) + neg_dy = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32) - return Data(z=z, pos=pos, y=y, dy=dy) + return Data(z=z, pos=pos, y=y, neg_dy=neg_dy) class ANIMD(COMP6Base): diff --git a/torchmdnet/datasets/custom.py b/torchmdnet/datasets/custom.py index a3200439a..b34d9cc31 100644 --- a/torchmdnet/datasets/custom.py +++ b/torchmdnet/datasets/custom.py @@ -97,7 +97,7 @@ def get(self, idx): force_data = np.array( np.load(self.forcefiles[fileid], mmap_mode="r")[index] ) - features["dy"] = torch.from_numpy(force_data) + features["neg_dy"] = torch.from_numpy(force_data) return Data(**features) diff --git a/torchmdnet/datasets/hdf.py b/torchmdnet/datasets/hdf.py index 87f5974d2..4bfb1edf8 100644 --- a/torchmdnet/datasets/hdf.py +++ b/torchmdnet/datasets/hdf.py @@ -73,7 +73,7 @@ def get(self, idx): pos=torch.from_numpy(pos[i]), z=torch.from_numpy(types[i]).to(torch.long), y=torch.tensor([[energy[i]]]), - dy=torch.from_numpy(forces[i]), + neg_dy=torch.from_numpy(forces[i]), ) else: types, pos, energy, i = self.index[idx] diff --git a/torchmdnet/datasets/md17.py b/torchmdnet/datasets/md17.py index 6a8f8deda..98d8c12ce 100644 --- a/torchmdnet/datasets/md17.py +++ b/torchmdnet/datasets/md17.py @@ -36,7 +36,7 @@ def __init__(self, root, transform=None, pre_transform=None, molecules=None): molecules = ",".join(MD17.available_molecules) self.molecules = molecules.split(",") - for mol in molecules: + for mol in self.molecules: if mol not in MD17.available_molecules: raise RuntimeError(f"Molecule '{mol}' does not exist in MD17") @@ -92,8 +92,8 @@ def process(self): forces = torch.from_numpy(data_npz["F"]).float() samples = [] - for pos, y, dy in zip(positions, energies, forces): - samples.append(Data(z=z, pos=pos, y=y.unsqueeze(1), dy=dy)) + for pos, y, neg_dy in zip(positions, energies, forces): + samples.append(Data(z=z, pos=pos, y=y.unsqueeze(1), neg_dy=neg_dy)) if self.pre_filter is not None: samples = [data for data in samples if self.pre_filter(data)] diff --git 
a/torchmdnet/datasets/qm9q.py b/torchmdnet/datasets/qm9q.py index d41755ef2..431964b69 100644 --- a/torchmdnet/datasets/qm9q.py +++ b/torchmdnet/datasets/qm9q.py @@ -45,7 +45,7 @@ def __init__( z_name, pos_name, y_name, - dy_name, + neg_dy_name, q_name, pq_name, dp_name, @@ -56,8 +56,8 @@ def __init__( pos_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) ) self.y_mm = np.memmap(y_name, mode="r", dtype=np.float64) - self.dy_mm = np.memmap( - dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) + self.neg_dy_mm = np.memmap( + neg_dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) ) self.q_mm = np.memmap(q_name, mode="r", dtype=np.int8) self.pq_mm = np.memmap(pq_name, mode="r", dtype=np.float32) @@ -120,12 +120,12 @@ def compute_reference_energy(atomic_numbers, charge): return energy * QM9q.HARTREE_TO_EV - def sample_iter(self): + def sample_iter(self, mol_ids=False): for path in tqdm(self.raw_paths, desc="Files"): - molecules = list(h5py.File(path).values())[0].values() + molecules = list(h5py.File(path).values())[0].items() - for mol in tqdm(molecules, desc="Molecules", leave=False): + for mol_id, mol in tqdm(molecules, desc="Molecules", leave=False): z = pt.tensor(mol["atomic_numbers"], dtype=pt.long) for conf in mol["energy"]: @@ -144,13 +144,13 @@ def sample_iter(self): mol["gradient_vector"].attrs["units"] == "vector : Hartree/Bohr " ) - dy = ( + neg_dy = ( -pt.tensor(mol["gradient_vector"][conf], dtype=pt.float32) * self.HARTREE_TO_EV / self.BORH_TO_ANGSTROM ) - assert z.shape[0] == dy.shape[0] - assert dy.shape[1] == 3 + assert z.shape[0] == neg_dy.shape[0] + assert neg_dy.shape[1] == 3 assert ( mol["electronic_charge"].attrs["units"] @@ -168,10 +168,14 @@ def sample_iter(self): y -= self.compute_reference_energy(z, q) # Skip samples with large forces - if dy.norm(dim=1).max() > 100: # eV/A + if neg_dy.norm(dim=1).max() > 100: # eV/A continue - data = Data(z=z, pos=pos, y=y.view(1, 1), dy=dy, q=q, pq=pq, dp=dp) + # Create a sample + args = dict(z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy, q=q, pq=pq, dp=dp) + if mol_ids: + args["mol_id"] = mol_id + data = Data(**args) if self.pre_filter is not None and not self.pre_filter(data): continue @@ -188,7 +192,7 @@ def processed_file_names(self): f"{self.name}.z.mmap", f"{self.name}.pos.mmap", f"{self.name}.y.mmap", - f"{self.name}.dy.mmap", + f"{self.name}.neg_dy.mmap", f"{self.name}.q.mmap", f"{self.name}.pq.mmap", f"{self.name}.dp.mmap", @@ -211,7 +215,7 @@ def process(self): z_name, pos_name, y_name, - dy_name, + neg_dy_name, q_name, pq_name, dp_name, @@ -226,8 +230,8 @@ def process(self): y_mm = np.memmap( y_name + ".tmp", mode="w+", dtype=np.float64, shape=num_all_confs ) - dy_mm = np.memmap( - dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3) + neg_dy_mm = np.memmap( + neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3) ) q_mm = np.memmap(q_name + ".tmp", mode="w+", dtype=np.int8, shape=num_all_confs) pq_mm = np.memmap( @@ -246,7 +250,7 @@ def process(self): z_mm[i_atom:i_next_atom] = data.z.to(pt.int8) pos_mm[i_atom:i_next_atom] = data.pos y_mm[i_conf] = data.y - dy_mm[i_atom:i_next_atom] = data.dy + neg_dy_mm[i_atom:i_next_atom] = data.neg_dy q_mm[i_conf] = data.q.to(pt.int8) pq_mm[i_atom:i_next_atom] = data.pq dp_mm[i_conf] = data.dp @@ -260,7 +264,7 @@ def process(self): z_mm.flush() pos_mm.flush() y_mm.flush() - dy_mm.flush() + neg_dy_mm.flush() q_mm.flush() pq_mm.flush() dp_mm.flush() @@ -269,7 +273,7 @@ def process(self): 
os.rename(z_mm.filename, z_name) os.rename(pos_mm.filename, pos_name) os.rename(y_mm.filename, y_name) - os.rename(dy_mm.filename, dy_name) + os.rename(neg_dy_mm.filename, neg_dy_name) os.rename(q_mm.filename, q_name) os.rename(pq_mm.filename, pq_name) os.rename(dp_mm.filename, dp_name) @@ -278,16 +282,15 @@ def len(self): return len(self.y_mm) def get(self, idx): - atoms = slice(self.idx_mm[idx], self.idx_mm[idx + 1]) z = pt.tensor(self.z_mm[atoms], dtype=pt.long) pos = pt.tensor(self.pos_mm[atoms], dtype=pt.float32) y = pt.tensor(self.y_mm[idx], dtype=pt.float32).view( 1, 1 ) # It would be better to use float64, but the trainer complaints - dy = pt.tensor(self.dy_mm[atoms], dtype=pt.float32) + neg_dy = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32) q = pt.tensor(self.q_mm[idx], dtype=pt.long) pq = pt.tensor(self.pq_mm[atoms], dtype=pt.float32) dp = pt.tensor(self.dp_mm[idx], dtype=pt.float32) - return Data(z=z, pos=pos, y=y, dy=dy, q=q, pq=pq, dp=dp) + return Data(z=z, pos=pos, y=y, neg_dy=neg_dy, q=q, pq=pq, dp=dp) diff --git a/torchmdnet/datasets/spice.py b/torchmdnet/datasets/spice.py index 269e2e236..f459b2369 100644 --- a/torchmdnet/datasets/spice.py +++ b/torchmdnet/datasets/spice.py @@ -26,8 +26,14 @@ class SPICE(Dataset): The loader can filter conformations with large gradients. The maximum gradient norm threshold can be set with `max_gradient`. By default, the filter is not applied. - For examples, the filter the threshold is set to 100 eV/A: + For example, the filter the threshold is set to 100 eV/A: >>> ds = SPICE(".", max_gradient=100) + + The molecules can be subsampled by loading only every `subsample_molecules`-th molecule. + By default is `subsample_molecules` is set to 1 (load all the molecules). + + For example, only every 10th molecule is loaded: + >>> ds = SPICE(".", subsample_molecules=10) """ HARTREE_TO_EV = 27.211386246 @@ -52,7 +58,7 @@ def processed_file_names(self): f"{self.name}.z.mmap", f"{self.name}.pos.mmap", f"{self.name}.y.mmap", - f"{self.name}.dy.mmap", + f"{self.name}.neg_dy.mmap", ] def __init__( @@ -61,43 +67,50 @@ def __init__( transform=None, pre_transform=None, pre_filter=None, - version="1.0", + version="1.1.1", subsets=None, max_gradient=None, + subsample_molecules=1, ): - arg_hash = f"{version}{subsets}{max_gradient}" + arg_hash = f"{version}{subsets}{max_gradient}{subsample_molecules}" arg_hash = hashlib.md5(arg_hash.encode()).hexdigest() self.name = f"{self.__class__.__name__}-{arg_hash}" self.version = version self.subsets = subsets self.max_gradient = max_gradient + self.subsample_molecules = int(subsample_molecules) super().__init__(root, transform, pre_transform, pre_filter) - idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths + idx_name, z_name, pos_name, y_name, neg_dy_name = self.processed_paths self.idx_mm = np.memmap(idx_name, mode="r", dtype=np.int64) self.z_mm = np.memmap(z_name, mode="r", dtype=np.int8) self.pos_mm = np.memmap( pos_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) ) self.y_mm = np.memmap(y_name, mode="r", dtype=np.float64) - self.dy_mm = np.memmap( - dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) + self.neg_dy_mm = np.memmap( + neg_dy_name, mode="r", dtype=np.float32, shape=(self.z_mm.shape[0], 3) ) assert self.idx_mm[0] == 0 assert self.idx_mm[-1] == len(self.z_mm) assert len(self.idx_mm) == len(self.y_mm) + 1 - def sample_iter(self): - + def sample_iter(self, mol_ids=False): assert len(self.raw_paths) == 1 + assert self.subsample_molecules > 0 - for mol in 
tqdm(h5py.File(self.raw_paths[0]).values(), desc="Molecules"): + molecules = h5py.File(self.raw_paths[0]).items() + for i_mol, (mol_id, mol) in tqdm(enumerate(molecules), desc="Molecules"): if self.subsets: if mol["subset"][0].decode() not in list(self.subsets): continue + # Subsample molecules + if i_mol % self.subsample_molecules != 0: + continue + z = pt.tensor(mol["atomic_numbers"], dtype=pt.long) all_pos = ( pt.tensor(mol["conformations"], dtype=pt.float32) @@ -107,7 +120,7 @@ def sample_iter(self): pt.tensor(mol["formation_energy"], dtype=pt.float64) * self.HARTREE_TO_EV ) - all_dy = ( + all_neg_dy = ( -pt.tensor(mol["dft_total_gradient"], dtype=pt.float32) * self.HARTREE_TO_EV / self.BORH_TO_ANGSTROM @@ -117,18 +130,22 @@ def sample_iter(self): assert all_pos.shape[1] == z.shape[0] assert all_pos.shape[2] == 3 - assert all_dy.shape[0] == all_y.shape[0] - assert all_dy.shape[1] == z.shape[0] - assert all_dy.shape[2] == 3 + assert all_neg_dy.shape[0] == all_y.shape[0] + assert all_neg_dy.shape[1] == z.shape[0] + assert all_neg_dy.shape[2] == 3 - for pos, y, dy in zip(all_pos, all_y, all_dy): + for pos, y, neg_dy in zip(all_pos, all_y, all_neg_dy): # Skip samples with large forces if self.max_gradient: - if dy.norm(dim=1).max() > float(self.max_gradient): + if neg_dy.norm(dim=1).max() > float(self.max_gradient): continue - data = Data(z=z, pos=pos, y=y.view(1, 1), dy=dy) + # Create a sample + args = dict(z=z, pos=pos, y=y.view(1, 1), neg_dy=neg_dy) + if mol_ids: + args["mol_id"] = mol_id + data = Data(**args) if self.pre_filter is not None and not self.pre_filter(data): continue @@ -146,7 +163,8 @@ def process(self): print("Arguments") print(f" version: {self.version}") print(f" subsets: {self.subsets}") - print(f" max_gradient: {self.max_gradient} eV/A\n") + print(f" max_gradient: {self.max_gradient} eV/A") + print(f" subsample_molecules: {self.subsample_molecules}\n") print("Gathering statistics...") num_all_confs = 0 @@ -158,7 +176,7 @@ def process(self): print(f" Total number of conformers: {num_all_confs}") print(f" Total number of atoms: {num_all_atoms}") - idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths + idx_name, z_name, pos_name, y_name, neg_dy_name = self.processed_paths idx_mm = np.memmap( idx_name + ".tmp", mode="w+", dtype=np.int64, shape=(num_all_confs + 1,) ) @@ -171,8 +189,8 @@ def process(self): y_mm = np.memmap( y_name + ".tmp", mode="w+", dtype=np.float64, shape=(num_all_confs,) ) - dy_mm = np.memmap( - dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3) + neg_dy_mm = np.memmap( + neg_dy_name + ".tmp", mode="w+", dtype=np.float32, shape=(num_all_atoms, 3) ) print("Storing data...") @@ -184,7 +202,7 @@ def process(self): z_mm[i_atom:i_next_atom] = data.z.to(pt.int8) pos_mm[i_atom:i_next_atom] = data.pos y_mm[i_conf] = data.y - dy_mm[i_atom:i_next_atom] = data.dy + neg_dy_mm[i_atom:i_next_atom] = data.neg_dy i_atom = i_next_atom @@ -195,25 +213,24 @@ def process(self): z_mm.flush() pos_mm.flush() y_mm.flush() - dy_mm.flush() + neg_dy_mm.flush() os.rename(idx_mm.filename, idx_name) os.rename(z_mm.filename, z_name) os.rename(pos_mm.filename, pos_name) os.rename(y_mm.filename, y_name) - os.rename(dy_mm.filename, dy_name) + os.rename(neg_dy_mm.filename, neg_dy_name) def len(self): return len(self.y_mm) def get(self, idx): - atoms = slice(self.idx_mm[idx], self.idx_mm[idx + 1]) z = pt.tensor(self.z_mm[atoms], dtype=pt.long) pos = pt.tensor(self.pos_mm[atoms], dtype=pt.float32) y = pt.tensor(self.y_mm[idx], dtype=pt.float32).view( 1, 1 
) # It would be better to use float64, but the trainer complaints - dy = pt.tensor(self.dy_mm[atoms], dtype=pt.float32) + neg_dy = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32) - return Data(z=z, pos=pos, y=y, dy=dy) + return Data(z=z, pos=pos, y=y, neg_dy=neg_dy) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index 580e7b314..be6e32dfb 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -79,7 +79,9 @@ def create_model(args, prior_model=None, mean=None, std=None): # create output network output_prefix = "Equivariant" if is_equivariant else "" output_model = getattr(output_modules, output_prefix + args["output_model"])( - args["embedding_dimension"], args["activation"] + args["embedding_dimension"], + activation=args["activation"], + reduce_op=args["reduce_op"], ) # combine representation and output network @@ -87,7 +89,6 @@ def create_model(args, prior_model=None, mean=None, std=None): representation_model, output_model, prior_model=prior_model, - reduce_op=args["reduce_op"], mean=mean, std=std, derivative=args["derivative"], @@ -118,7 +119,6 @@ def __init__( representation_model, output_model, prior_model=None, - reduce_op="add", mean=None, std=None, derivative=False, @@ -137,7 +137,6 @@ def __init__( ) ) - self.reduce_op = reduce_op self.derivative = derivative mean = torch.scalar_tensor(0) if mean is None else mean @@ -178,31 +177,29 @@ def forward( if self.std is not None: x = x * self.std - # apply atomwise prior model + # apply atom-wise prior model if self.prior_model is not None: - if self.prior_model.atomwise: - x = self.prior_model(x, z, pos, batch) + x = self.prior_model.pre_reduce(x, z, pos, batch) # aggregate atoms - out = scatter(x, batch, dim=0, reduce=self.reduce_op) - - # apply non-atomwise prior model - if self.prior_model is not None: - if not self.prior_model.atomwise: - out += self.prior_model(x, z, pos, batch) + x = self.output_model.reduce(x, batch) # shift by data mean if self.mean is not None: - out = out + self.mean + x = x + self.mean # apply output model after reduction - out = self.output_model.post_reduce(out) + y = self.output_model.post_reduce(x) + + # apply molecular-wise prior model + if self.prior_model is not None: + y = self.prior_model.post_reduce(y, z, pos, batch) # compute gradients with respect to coordinates if self.derivative: - grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(out)] + grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(y)] dy = grad( - [out], + [y], [pos], grad_outputs=grad_outputs, create_graph=True, @@ -210,6 +207,6 @@ def forward( )[0] if dy is None: raise RuntimeError("Autograd returned None for the force prediction.") - return out, -dy + return y, -dy # TODO: return only `out` once Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180) - return out, None + return y, None diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py index ab003048b..0bcec935f 100644 --- a/torchmdnet/models/output_modules.py +++ b/torchmdnet/models/output_modules.py @@ -1,4 +1,5 @@ from abc import abstractmethod, ABCMeta +from torch_scatter import scatter from typing import Optional from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock from torchmdnet.utils import atomic_masses @@ -11,9 +12,10 @@ class OutputModel(nn.Module, metaclass=ABCMeta): - def __init__(self, allow_prior_model): + def __init__(self, allow_prior_model, reduce_op): super(OutputModel, self).__init__() self.allow_prior_model = 
diff --git a/torchmdnet/models/output_modules.py b/torchmdnet/models/output_modules.py
index ab003048b..0bcec935f 100644
--- a/torchmdnet/models/output_modules.py
+++ b/torchmdnet/models/output_modules.py
@@ -1,4 +1,5 @@
 from abc import abstractmethod, ABCMeta
+from torch_scatter import scatter
 from typing import Optional
 from torchmdnet.models.utils import act_class_mapping, GatedEquivariantBlock
 from torchmdnet.utils import atomic_masses
@@ -11,9 +12,10 @@

 class OutputModel(nn.Module, metaclass=ABCMeta):
-    def __init__(self, allow_prior_model):
+    def __init__(self, allow_prior_model, reduce_op):
         super(OutputModel, self).__init__()
         self.allow_prior_model = allow_prior_model
+        self.reduce_op = reduce_op

     def reset_parameters(self):
         pass
@@ -22,13 +24,24 @@ def reset_parameters(self):
     def pre_reduce(self, x, v, z, pos, batch):
         return

+    def reduce(self, x, batch):
+        return scatter(x, batch, dim=0, reduce=self.reduce_op)
+
     def post_reduce(self, x):
         return x


 class Scalar(OutputModel):
-    def __init__(self, hidden_channels, activation="silu", allow_prior_model=True):
-        super(Scalar, self).__init__(allow_prior_model=allow_prior_model)
+    def __init__(
+        self,
+        hidden_channels,
+        activation="silu",
+        allow_prior_model=True,
+        reduce_op="sum",
+    ):
+        super(Scalar, self).__init__(
+            allow_prior_model=allow_prior_model, reduce_op=reduce_op
+        )
         act_class = act_class_mapping[activation]
         self.output_network = nn.Sequential(
             nn.Linear(hidden_channels, hidden_channels // 2),
@@ -49,8 +62,16 @@ def pre_reduce(self, x, v: Optional[torch.Tensor], z, pos, batch):


 class EquivariantScalar(OutputModel):
-    def __init__(self, hidden_channels, activation="silu", allow_prior_model=True):
-        super(EquivariantScalar, self).__init__(allow_prior_model=allow_prior_model)
+    def __init__(
+        self,
+        hidden_channels,
+        activation="silu",
+        allow_prior_model=True,
+        reduce_op="sum",
+    ):
+        super(EquivariantScalar, self).__init__(
+            allow_prior_model=allow_prior_model, reduce_op=reduce_op
+        )
         self.output_network = nn.ModuleList(
             [
                 GatedEquivariantBlock(
@@ -77,9 +98,9 @@ def pre_reduce(self, x, v, z, pos, batch):


 class DipoleMoment(Scalar):
-    def __init__(self, hidden_channels, activation="silu"):
+    def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
         super(DipoleMoment, self).__init__(
-            hidden_channels, activation, allow_prior_model=False
+            hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op
         )
         atomic_mass = torch.from_numpy(atomic_masses).float()
         self.register_buffer("atomic_mass", atomic_mass)
@@ -98,9 +119,9 @@ def post_reduce(self, x):


 class EquivariantDipoleMoment(EquivariantScalar):
-    def __init__(self, hidden_channels, activation="silu"):
+    def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
         super(EquivariantDipoleMoment, self).__init__(
-            hidden_channels, activation, allow_prior_model=False
+            hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op
         )
         atomic_mass = torch.from_numpy(atomic_masses).float()
         self.register_buffer("atomic_mass", atomic_mass)
@@ -120,8 +141,10 @@ def post_reduce(self, x):


 class ElectronicSpatialExtent(OutputModel):
-    def __init__(self, hidden_channels, activation="silu"):
-        super(ElectronicSpatialExtent, self).__init__(allow_prior_model=False)
+    def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
+        super(ElectronicSpatialExtent, self).__init__(
+            allow_prior_model=False, reduce_op=reduce_op
+        )
         act_class = act_class_mapping[activation]
         self.output_network = nn.Sequential(
             nn.Linear(hidden_channels, hidden_channels // 2),
@@ -155,9 +178,9 @@ class EquivariantElectronicSpatialExtent(ElectronicSpatialExtent):


 class EquivariantVectorOutput(EquivariantScalar):
-    def __init__(self, hidden_channels, activation="silu"):
+    def __init__(self, hidden_channels, activation="silu", reduce_op="sum"):
         super(EquivariantVectorOutput, self).__init__(
-            hidden_channels, activation, allow_prior_model=False
+            hidden_channels, activation, allow_prior_model=False, reduce_op=reduce_op
         )

     def pre_reduce(self, x, v, z, pos, batch):
diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index 429fb8b31..f079f2cbf 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ 
-72,7 +72,7 @@ def step(self, batch, loss_fn, stage): with torch.set_grad_enabled(stage == "train" or self.hparams.derivative): # TODO: the model doesn't necessarily need to return a derivative once # Union typing works under TorchScript (https://github.com/pytorch/pytorch/pull/53180) - pred, deriv = self( + y, neg_dy = self( batch.z, batch.pos, batch=batch.batch, @@ -80,37 +80,37 @@ def step(self, batch, loss_fn, stage): s=batch.s if self.hparams.spin else None, ) - loss_y, loss_dy = 0, 0 + loss_y, loss_neg_dy = 0, 0 if self.hparams.derivative: if "y" not in batch: # "use" both outputs of the model's forward function but discard the first - # to only use the derivative and avoid 'Expected to have finished reduction + # to only use the negative derivative and avoid 'Expected to have finished reduction # in the prior iteration before starting a new one.', which otherwise get's # thrown because of setting 'find_unused_parameters=False' in the DDPPlugin - deriv = deriv + pred.sum() * 0 - - # force/derivative loss - loss_dy = loss_fn(deriv, batch.dy) - - if stage in ["train", "val"] and self.hparams.ema_alpha_dy < 1: - if self.ema[stage + "_dy"] is None: - self.ema[stage + "_dy"] = loss_dy.detach() - # apply exponential smoothing over batches to dy - loss_dy = ( - self.hparams.ema_alpha_dy * loss_dy - + (1 - self.hparams.ema_alpha_dy) * self.ema[stage + "_dy"] + neg_dy = neg_dy + y.sum() * 0 + + # negative derivative loss + loss_neg_dy = loss_fn(neg_dy, batch.neg_dy) + + if stage in ["train", "val"] and self.hparams.ema_alpha_neg_dy < 1: + if self.ema[stage + "_neg_dy"] is None: + self.ema[stage + "_neg_dy"] = loss_neg_dy.detach() + # apply exponential smoothing over batches to neg_dy + loss_neg_dy = ( + self.hparams.ema_alpha_neg_dy * loss_neg_dy + + (1 - self.hparams.ema_alpha_neg_dy) * self.ema[stage + "_neg_dy"] ) - self.ema[stage + "_dy"] = loss_dy.detach() + self.ema[stage + "_neg_dy"] = loss_neg_dy.detach() - if self.hparams.force_weight > 0: - self.losses[stage + "_dy"].append(loss_dy.detach()) + if self.hparams.neg_dy_weight > 0: + self.losses[stage + "_neg_dy"].append(loss_neg_dy.detach()) if "y" in batch: if batch.y.ndim == 1: batch.y = batch.y.unsqueeze(1) - # energy/prediction loss - loss_y = loss_fn(pred, batch.y) + # y loss + loss_y = loss_fn(y, batch.y) if stage in ["train", "val"] and self.hparams.ema_alpha_y < 1: if self.ema[stage + "_y"] is None: @@ -122,11 +122,11 @@ def step(self, batch, loss_fn, stage): ) self.ema[stage + "_y"] = loss_y.detach() - if self.hparams.energy_weight > 0: + if self.hparams.y_weight > 0: self.losses[stage + "_y"].append(loss_y.detach()) # total loss - loss = loss_y * self.hparams.energy_weight + loss_dy * self.hparams.force_weight + loss = loss_y * self.hparams.y_weight + loss_neg_dy * self.hparams.neg_dy_weight self.losses[stage].append(loss.detach()) return loss @@ -172,20 +172,20 @@ def validation_epoch_end(self, validation_step_outputs): result_dict["test_loss"] = torch.stack(self.losses["test"]).mean() # if prediction and derivative are present, also log them separately - if len(self.losses["train_y"]) > 0 and len(self.losses["train_dy"]) > 0: + if len(self.losses["train_y"]) > 0 and len(self.losses["train_neg_dy"]) > 0: result_dict["train_loss_y"] = torch.stack(self.losses["train_y"]).mean() - result_dict["train_loss_dy"] = torch.stack( - self.losses["train_dy"] + result_dict["train_loss_neg_dy"] = torch.stack( + self.losses["train_neg_dy"] ).mean() result_dict["val_loss_y"] = torch.stack(self.losses["val_y"]).mean() - 
result_dict["val_loss_dy"] = torch.stack(self.losses["val_dy"]).mean() + result_dict["val_loss_neg_dy"] = torch.stack(self.losses["val_neg_dy"]).mean() if len(self.losses["test"]) > 0: result_dict["test_loss_y"] = torch.stack( self.losses["test_y"] ).mean() - result_dict["test_loss_dy"] = torch.stack( - self.losses["test_dy"] + result_dict["test_loss_neg_dy"] = torch.stack( + self.losses["test_neg_dy"] ).mean() self.log_dict(result_dict, sync_dist=True) @@ -199,10 +199,10 @@ def _reset_losses_dict(self): "train_y": [], "val_y": [], "test_y": [], - "train_dy": [], - "val_dy": [], - "test_dy": [], + "train_neg_dy": [], + "val_neg_dy": [], + "test_neg_dy": [], } def _reset_ema_dict(self): - self.ema = {"train_y": None, "val_y": None, "train_dy": None, "val_dy": None} + self.ema = {"train_y": None, "val_y": None, "train_neg_dy": None, "val_neg_dy": None} diff --git a/torchmdnet/priors.py b/torchmdnet/priors.py deleted file mode 100644 index 16484513a..000000000 --- a/torchmdnet/priors.py +++ /dev/null @@ -1,133 +0,0 @@ -from abc import abstractmethod, ABCMeta -import torch -from torch import nn -from pytorch_lightning.utilities import rank_zero_warn -from torchmdnet.models.utils import Distance, CosineCutoff - -__all__ = ["Atomref", "ZBL"] - - -class BasePrior(nn.Module, metaclass=ABCMeta): - r"""Base class for prior models. - Derive this class to make custom prior models, which take some arguments and a dataset as input. - As an example, have a look at the `torchmdnet.priors.Atomref` prior. - """ - - def __init__(self, dataset=None, atomwise=True): - super(BasePrior, self).__init__() - self.atomwise = atomwise - - @abstractmethod - def get_init_args(self): - r"""A function that returns all required arguments to construct a prior object. - The values should be returned inside a dict with the keys being the arguments' names. - All values should also be saveable in a .yaml file as this is used to reconstruct the - prior model from a checkpoint file. - """ - return - - @abstractmethod - def forward(self, x, z, pos, batch): - r"""Forward method of the prior model. - - Args: - x (torch.Tensor): scalar atomwise predictions from the model. - z (torch.Tensor): atom types of all atoms. - pos (torch.Tensor): 3D atomic coordinates. - batch (torch.Tensor): tensor containing the sample index for each atom. - - Returns: - torch.Tensor: If this is an atom-wise prior (self.atomwise is True), the return value - contains updated scalar atomwise predictions. Otherwise, it is a single scalar that - is added to the result after summing over atoms. - """ - return - - -class Atomref(BasePrior): - r"""Atomref prior model. - When using this in combination with some dataset, the dataset class must implement - the function `get_atomref`, which returns the atomic reference values as a tensor. - """ - - def __init__(self, max_z=None, dataset=None): - super(Atomref, self).__init__() - if max_z is None and dataset is None: - raise ValueError("Can't instantiate Atomref prior, all arguments are None.") - if dataset is None: - atomref = torch.zeros(max_z, 1) - else: - atomref = dataset.get_atomref() - if atomref is None: - rank_zero_warn( - "The atomref returned by the dataset is None, defaulting to zeros with max. " - "atomic number 99. Maybe atomref is not defined for the current target." 
- ) - atomref = torch.zeros(100, 1) - - if atomref.ndim == 1: - atomref = atomref.view(-1, 1) - self.register_buffer("initial_atomref", atomref) - self.atomref = nn.Embedding(len(atomref), 1) - self.atomref.weight.data.copy_(atomref) - - def reset_parameters(self): - self.atomref.weight.data.copy_(self.initial_atomref) - - def get_init_args(self): - return dict(max_z=self.initial_atomref.size(0)) - - def forward(self, x, z, pos, batch): - return x + self.atomref(z) - - -class ZBL(BasePrior): - """This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion. - Is is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147). It - is an empirical potential that does a good job of describing the repulsion between atoms at very short - distances. - - To use this prior, the Dataset must provide the following attributes. - - atomic_number: 1D tensor of length max_z. atomic_number[z] is the atomic number of atoms with atom type z. - distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters - energy_scale: multiply by this factor to convert energies stored in the dataset to Joules - """ - def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, distance_scale=None, energy_scale=None, dataset=None): - super(ZBL, self).__init__(atomwise=False) - if atomic_number is None: - atomic_number = dataset.atomic_number - if distance_scale is None: - distance_scale = dataset.distance_scale - if energy_scale is None: - energy_scale = dataset.energy_scale - atomic_number = torch.as_tensor(atomic_number, dtype=torch.int8) - self.register_buffer("atomic_number", atomic_number) - self.distance = Distance(0, cutoff_distance, max_num_neighbors=max_num_neighbors) - self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance) - self.cutoff_distance = cutoff_distance - self.max_num_neighbors = max_num_neighbors - self.distance_scale = distance_scale - self.energy_scale = energy_scale - - def get_init_args(self): - return {'cutoff_distance': self.cutoff_distance, - 'max_num_neighbors': self.max_num_neighbors, - 'atomic_number': self.atomic_number, - 'distance_scale': self.distance_scale, - 'energy_scale': self.energy_scale} - - def reset_parameters(self): - pass - - def forward(self, x, z, pos, batch): - edge_index, distance, _ = self.distance(pos, batch) - atomic_number = self.atomic_number[z[edge_index]] - # 5.29e-11 is the Bohr radius in meters. All other numbers are magic constants from the ZBL potential. - a = 0.8854*5.29177210903e-11/(atomic_number[0]**0.23 + atomic_number[1]**0.23) - d = distance*self.distance_scale/a - f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d) - f *= self.cutoff(distance) - # Compute the energy, converting to the dataset's units. Multiply by 0.5 because every atom pair - # appears twice. 
-        return 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*torch.sum(f*atomic_number[0]*atomic_number[1]/distance, dim=-1)
diff --git a/torchmdnet/priors/__init__.py b/torchmdnet/priors/__init__.py
new file mode 100644
index 000000000..3f961329d
--- /dev/null
+++ b/torchmdnet/priors/__init__.py
@@ -0,0 +1,2 @@
+from torchmdnet.priors.atomref import Atomref
+from torchmdnet.priors.zbl import ZBL
diff --git a/torchmdnet/priors/atomref.py b/torchmdnet/priors/atomref.py
new file mode 100644
index 000000000..bf5e9478f
--- /dev/null
+++ b/torchmdnet/priors/atomref.py
@@ -0,0 +1,41 @@
+from torchmdnet.priors.base import BasePrior
+import torch
+from torch import nn
+from pytorch_lightning.utilities import rank_zero_warn
+
+
+class Atomref(BasePrior):
+    r"""Atomref prior model.
+    When using this in combination with some dataset, the dataset class must implement
+    the function `get_atomref`, which returns the atomic reference values as a tensor.
+    """
+
+    def __init__(self, max_z=None, dataset=None):
+        super().__init__()
+        if max_z is None and dataset is None:
+            raise ValueError("Can't instantiate Atomref prior, all arguments are None.")
+        if dataset is None:
+            atomref = torch.zeros(max_z, 1)
+        else:
+            atomref = dataset.get_atomref()
+            if atomref is None:
+                rank_zero_warn(
+                    "The atomref returned by the dataset is None, defaulting to zeros with max. "
+                    "atomic number 99. Maybe atomref is not defined for the current target."
+                )
+                atomref = torch.zeros(100, 1)
+
+        if atomref.ndim == 1:
+            atomref = atomref.view(-1, 1)
+        self.register_buffer("initial_atomref", atomref)
+        self.atomref = nn.Embedding(len(atomref), 1)
+        self.atomref.weight.data.copy_(atomref)
+
+    def reset_parameters(self):
+        self.atomref.weight.data.copy_(self.initial_atomref)
+
+    def get_init_args(self):
+        return dict(max_z=self.initial_atomref.size(0))
+
+    def pre_reduce(self, x, z, pos, batch):
+        return x + self.atomref(z)
diff --git a/torchmdnet/priors/base.py b/torchmdnet/priors/base.py
new file mode 100644
index 000000000..c43cdcf2b
--- /dev/null
+++ b/torchmdnet/priors/base.py
@@ -0,0 +1,47 @@
+from torch import nn
+
+
+class BasePrior(nn.Module):
+    r"""Base class for prior models.
+    Derive this class to make custom prior models, which take some arguments and a dataset as input.
+    As an example, have a look at the `torchmdnet.priors.atomref.Atomref` prior.
+    """
+
+    def __init__(self, dataset=None):
+        super().__init__()
+
+    def get_init_args(self):
+        r"""A function that returns all required arguments to construct a prior object.
+        The values should be returned inside a dict with the keys being the arguments' names.
+        All values should also be saveable in a .yaml file as this is used to reconstruct the
+        prior model from a checkpoint file.
+        """
+        return {}
+
+    def pre_reduce(self, x, z, pos, batch):
+        r"""Pre-reduce method of the prior model.
+
+        Args:
+            x (torch.Tensor): scalar atom-wise predictions from the model.
+            z (torch.Tensor): atom types of all atoms.
+            pos (torch.Tensor): 3D atomic coordinates.
+            batch (torch.Tensor): tensor containing the sample index for each atom.
+
+        Returns:
+            torch.Tensor: updated scalar atom-wise predictions
+        """
+        return x
+
+    def post_reduce(self, y, z, pos, batch):
+        r"""Post-reduce method of the prior model.
+
+        Args:
+            y (torch.Tensor): scalar molecule-wise predictions from the model.
+            z (torch.Tensor): atom types of all atoms.
+            pos (torch.Tensor): 3D atomic coordinates.
+            batch (torch.Tensor): tensor containing the sample index for each atom.
+
+        Returns:
+            torch.Tensor: updated scalar molecule-wise predictions
+        """
+        return y
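Both hooks default to the identity, so a subclass only overrides the one it needs. As a minimal illustration (this EnergyOffset class and its offset argument are hypothetical, not part of this patch), a prior that shifts every molecular energy by a constant would be:

    from torchmdnet.priors.base import BasePrior

    class EnergyOffset(BasePrior):
        """Hypothetical example prior: adds a fixed offset to each molecule's energy."""

        def __init__(self, offset=0.0, dataset=None):
            super().__init__(dataset)
            self.offset = offset

        def get_init_args(self):
            # Everything needed to rebuild this prior from a checkpoint.
            return {"offset": self.offset}

        def post_reduce(self, y, z, pos, batch):
            # pre_reduce is inherited as the identity, so only the
            # reduced molecule-wise prediction is modified.
            return y + self.offset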
diff --git a/torchmdnet/priors/zbl.py b/torchmdnet/priors/zbl.py
new file mode 100644
index 000000000..5e735d7e4
--- /dev/null
+++ b/torchmdnet/priors/zbl.py
@@ -0,0 +1,54 @@
+import torch
+from torchmdnet.priors.base import BasePrior
+from torchmdnet.models.utils import Distance, CosineCutoff
+
+class ZBL(BasePrior):
+    """This class implements the Ziegler-Biersack-Littmark (ZBL) potential for screened nuclear repulsion.
+    It is described in https://doi.org/10.1007/978-3-642-68779-2_5 (equations 9 and 10 on page 147).  It
+    is an empirical potential that does a good job of describing the repulsion between atoms at very short
+    distances.
+
+    To use this prior, the Dataset must provide the following attributes.
+
+    atomic_number: 1D tensor of length max_z.  atomic_number[z] is the atomic number of atoms with atom type z.
+    distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters
+    energy_scale: multiply by this factor to convert energies stored in the dataset to Joules
+    """
+    def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, distance_scale=None, energy_scale=None, dataset=None):
+        super(ZBL, self).__init__()
+        if atomic_number is None:
+            atomic_number = dataset.atomic_number
+        if distance_scale is None:
+            distance_scale = dataset.distance_scale
+        if energy_scale is None:
+            energy_scale = dataset.energy_scale
+        atomic_number = torch.as_tensor(atomic_number, dtype=torch.int8)
+        self.register_buffer("atomic_number", atomic_number)
+        self.distance = Distance(0, cutoff_distance, max_num_neighbors=max_num_neighbors)
+        self.cutoff = CosineCutoff(cutoff_upper=cutoff_distance)
+        self.cutoff_distance = cutoff_distance
+        self.max_num_neighbors = max_num_neighbors
+        self.distance_scale = distance_scale
+        self.energy_scale = energy_scale
+
+    def get_init_args(self):
+        return {'cutoff_distance': self.cutoff_distance,
+                'max_num_neighbors': self.max_num_neighbors,
+                'atomic_number': self.atomic_number,
+                'distance_scale': self.distance_scale,
+                'energy_scale': self.energy_scale}
+
+    def reset_parameters(self):
+        pass
+
+    def post_reduce(self, y, z, pos, batch):
+        edge_index, distance, _ = self.distance(pos, batch)
+        atomic_number = self.atomic_number[z[edge_index]]
+        # 5.29e-11 is the Bohr radius in meters.  All other numbers are magic constants from the ZBL potential.
+        a = 0.8854*5.29177210903e-11/(atomic_number[0]**0.23 + atomic_number[1]**0.23)
+        d = distance*self.distance_scale/a
+        f = 0.1818*torch.exp(-3.2*d) + 0.5099*torch.exp(-0.9423*d) + 0.2802*torch.exp(-0.4029*d) + 0.02817*torch.exp(-0.2016*d)
+        f *= self.cutoff(distance)
+        # Compute the energy, converting to the dataset's units.  Multiply by 0.5 because every atom pair
+        # appears twice.
+        return y + 0.5*(2.30707755e-28/self.energy_scale/self.distance_scale)*torch.sum(f*atomic_number[0]*atomic_number[1]/distance, dim=-1)
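In SI terms the per-pair energy above is E(r) = phi(r/a) * (e^2/(4*pi*eps0)) * Z1*Z2 / r, where 2.30707755e-28 J*m is e^2/(4*pi*eps0) and a is the screening length; the edge sum is multiplied by 0.5 because each pair appears twice in edge_index, and dividing by energy_scale and distance_scale converts the result back into the dataset's units. A self-contained sketch of a single pair interaction, without the cosine cutoff that the class additionally applies (the values in the example call are arbitrary):

    import torch

    def zbl_pair_energy(z1, z2, r):
        """Screened nuclear repulsion for one atom pair, in Joules.
        z1, z2: atomic numbers; r: separation in meters."""
        a = 0.8854 * 5.29177210903e-11 / (z1 ** 0.23 + z2 ** 0.23)  # screening length (m)
        d = torch.tensor(r / a)
        phi = (0.1818 * torch.exp(-3.2 * d) + 0.5099 * torch.exp(-0.9423 * d)
               + 0.2802 * torch.exp(-0.4029 * d) + 0.02817 * torch.exp(-0.2016 * d))
        return 2.30707755e-28 * z1 * z2 * phi / r  # e^2/(4*pi*eps0) = 2.307e-28 J*m

    print(zbl_pair_energy(6, 6, 1e-10))  # two carbons 1 Angstrom apart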
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 1345269cb..879f4eda3 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -33,7 +33,7 @@ def get_args():
     parser.add_argument('--reset-trainer', type=bool, default=False, help='Reset training metrics (e.g. early stopping, lr) when loading a model checkpoint')
     parser.add_argument('--weight-decay', type=float, default=0.0, help='Weight decay strength')
     parser.add_argument('--ema-alpha-y', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of y')
-    parser.add_argument('--ema-alpha-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy')
+    parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of neg_dy')
     parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus')
     parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes')
     parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision')
@@ -56,8 +56,8 @@ def get_args():
     parser.add_argument('--embed-files', default=None, type=str, help='Custom embedding files glob')
     parser.add_argument('--energy-files', default=None, type=str, help='Custom energy files glob')
     parser.add_argument('--force-files', default=None, type=str, help='Custom force files glob')
-    parser.add_argument('--energy-weight', default=1.0, type=float, help='Weighting factor for energies in the loss function')
-    parser.add_argument('--force-weight', default=1.0, type=float, help='Weighting factor for forces in the loss function')
+    parser.add_argument('--y-weight', default=1.0, type=float, help='Weighting factor for y label in the loss function')
+    parser.add_argument('--neg-dy-weight', default=1.0, type=float, help='Weighting factor for neg_dy label in the loss function')

     # model architecture
     parser.add_argument('--model', type=str, default='graph-network', choices=models.__all__, help='Which model to train')

From 9e035cb3c1acd4c1b23819798ac92e4af712b362 Mon Sep 17 00:00:00 2001
From: Peter Eastman
Date: Sat, 29 Oct 2022 13:32:49 -0700
Subject: [PATCH 07/12] Added test case for ZBL

---
 tests/test_priors.py | 30 +++++++++++++++++++++++++++++-
 1 file changed, 29 insertions(+), 1 deletion(-)

diff --git a/tests/test_priors.py b/tests/test_priors.py
index 31481233d..c0a192351 100644
--- a/tests/test_priors.py
+++ b/tests/test_priors.py
@@ -4,7 +4,7 @@
 import pytorch_lightning as pl
 from torchmdnet import models
 from torchmdnet.models.model import create_model
-from torchmdnet.priors import Atomref
+from torchmdnet.priors import Atomref, ZBL
 from torch_scatter import scatter
 from utils import load_example_args, create_example_batch, DummyDataset

@@ -31,3 +31,31 @@ def test_atomref(model_name):
     # check if the output of both models differs by the expected atomref contribution
     expected_offset = scatter(dataset.get_atomref().squeeze()[z], batch).unsqueeze(1)
     torch.testing.assert_allclose(x_atomref, x_no_atomref + expected_offset)
+
+def test_zbl():
+    pos = torch.tensor([[1.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, -1.0]], dtype=torch.float32)  # Atom positions in Bohr
+    types = torch.tensor([0, 1, 2, 1], dtype=torch.long)  # Atom types
+    atomic_number = torch.tensor([1, 6, 8], dtype=torch.int8)  # Mapping of atom types to atomic numbers
+    distance_scale = 5.29177210903e-11  # Convert Bohr to meters
+    energy_scale = 1000.0/6.02214076e23  # Convert kJ/mol to Joules
+
+    # Use the ZBL class to compute the energy.
+ + zbl = ZBL(10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale) + energy = zbl.post_reduce(torch.zeros((1,)), types, pos, torch.zeros(types.shape))[0] + + # Compare to the expected value. + + def compute_interaction(pos1, pos2, z1, z2): + delta = pos1-pos2 + r = torch.sqrt(torch.dot(delta, delta)) + x = r / (0.8854/(z1**0.23 + z2**0.23)) + phi = 0.1818*torch.exp(-3.2*x) + 0.5099*torch.exp(-0.9423*x) + 0.2802*torch.exp(-0.4029*x) + 0.02817*torch.exp(-0.2016*x) + cutoff = 0.5*(torch.cos(r*torch.pi/10.0) + 1.0) + return cutoff*phi*(138.935/5.29177210903e-2)*z1*z2/r + + expected = 0 + for i in range(len(pos)): + for j in range(i): + expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]]) + torch.testing.assert_allclose(expected, energy) \ No newline at end of file From 53d1e4d1d00a846893f0d2404218a7692e26bcc9 Mon Sep 17 00:00:00 2001 From: Peter Eastman Date: Mon, 31 Oct 2022 11:20:26 -0700 Subject: [PATCH 08/12] Attempt at fixing test failure on CI --- tests/test_priors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_priors.py b/tests/test_priors.py index c0a192351..815e3bb6c 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -42,7 +42,7 @@ def test_zbl(): # Use the ZBL class to compute the energy. zbl = ZBL(10.0, 5, atomic_number, distance_scale=distance_scale, energy_scale=energy_scale) - energy = zbl.post_reduce(torch.zeros((1,)), types, pos, torch.zeros(types.shape))[0] + energy = zbl.post_reduce(torch.zeros((1,)), types, pos, torch.zeros_like(types))[0] # Compare to the expected value. From 7016f5fbccca25b7e63d89d31c2d582805896905 Mon Sep 17 00:00:00 2001 From: Peter Eastman Date: Wed, 2 Nov 2022 15:23:37 -0700 Subject: [PATCH 09/12] Support multiple prior models --- tests/test_module.py | 2 +- torchmdnet/models/model.py | 55 ++++++++++++++++++++++++++--------- torchmdnet/priors/__init__.py | 2 ++ torchmdnet/scripts/train.py | 18 +++--------- torchmdnet/utils.py | 2 +- 5 files changed, 49 insertions(+), 30 deletions(-) diff --git a/tests/test_module.py b/tests/test_module.py index c920c010e..17002d7c8 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -43,7 +43,7 @@ def test_train(model_name, use_atomref, tmpdir): prior = None if use_atomref: prior = getattr(priors, args["prior_model"])(dataset=datamodule.dataset) - args["prior_init_args"] = prior.get_init_args() + args["prior_args"] = prior.get_init_args() module = LNNP(args, prior_model=prior) diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py index be6e32dfb..bc93c4516 100644 --- a/torchmdnet/models/model.py +++ b/torchmdnet/models/model.py @@ -65,16 +65,8 @@ def create_model(args, prior_model=None, mean=None, std=None): # prior model if args["prior_model"] and prior_model is None: - assert "prior_init_args" in args, ( - f"Requested prior model {args['prior_model']} but the " - f'arguments are lacking the key "prior_init_args".' - ) - assert hasattr(priors, args["prior_model"]), ( - f'Unknown prior model {args["prior_model"]}. ' - f'Available models are {", ".join(priors.__all__)}' - ) # instantiate prior model if it was not passed to create_model (i.e. 
when loading a model) - prior_model = getattr(priors, args["prior_model"])(**args["prior_init_args"]) + prior_model = create_prior_models(args) # create output network output_prefix = "Equivariant" if is_equivariant else "" @@ -113,6 +105,36 @@ def load_model(filepath, args=None, device="cpu", **kwargs): return model.to(device) +def create_prior_models(args, dataset=None): + """Parse the prior_model configuration option and create the prior models.""" + prior_models = [] + if args.prior_model: + prior_model = args.prior_model + prior_names = [] + prior_args = [] + if not isinstance(prior_model, list): + prior_model = [prior_model] + for prior in prior_model: + if isinstance(prior, dict): + for key, value in prior.items(): + prior_names.append(key) + if value is None: + prior_args.append({}) + else: + prior_args.append(value) + else: + prior_names.append(prior) + prior_args.append({}) + for name, arg in zip(prior_names, prior_args): + assert hasattr(priors, name), ( + f"Unknown prior model {name}. " + f"Available models are {', '.join(priors.__all__)}" + ) + # initialize the prior model + prior_models.append(getattr(priors, name)(dataset=dataset, **arg)) + return prior_models + + class TorchMD_Net(nn.Module): def __init__( self, @@ -127,15 +149,17 @@ def __init__( self.representation_model = representation_model self.output_model = output_model - self.prior_model = prior_model if not output_model.allow_prior_model and prior_model is not None: - self.prior_model = None + prior_model = None rank_zero_warn( ( "Prior model was given but the output model does " "not allow prior models. Dropping the prior model." ) ) + if isinstance(prior_model, priors.base.BasePrior): + prior_model = [prior_model] + self.prior_model = None if prior_model is None else torch.nn.ModuleList(prior_model) self.derivative = derivative @@ -150,7 +174,8 @@ def reset_parameters(self): self.representation_model.reset_parameters() self.output_model.reset_parameters() if self.prior_model is not None: - self.prior_model.reset_parameters() + for prior in self.prior_model: + prior.reset_parameters() def forward( self, @@ -179,7 +204,8 @@ def forward( # apply atom-wise prior model if self.prior_model is not None: - x = self.prior_model.pre_reduce(x, z, pos, batch) + for prior in self.prior_model: + x = prior.pre_reduce(x, z, pos, batch) # aggregate atoms x = self.output_model.reduce(x, batch) @@ -193,7 +219,8 @@ def forward( # apply molecular-wise prior model if self.prior_model is not None: - y = self.prior_model.post_reduce(y, z, pos, batch) + for prior in self.prior_model: + y = prior.post_reduce(y, z, pos, batch) # compute gradients with respect to coordinates if self.derivative: diff --git a/torchmdnet/priors/__init__.py b/torchmdnet/priors/__init__.py index 3f961329d..c5cfecd94 100644 --- a/torchmdnet/priors/__init__.py +++ b/torchmdnet/priors/__init__.py @@ -1,2 +1,4 @@ from torchmdnet.priors.atomref import Atomref from torchmdnet.priors.zbl import ZBL + +__all__ = ['Atomref', 'ZBL'] \ No newline at end of file diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 879f4eda3..6f8a53bbe 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -11,6 +11,7 @@ from torchmdnet import datasets, priors, models from torchmdnet.data import DataModule from torchmdnet.models import output_modules +from torchmdnet.models.model import create_prior_models from torchmdnet.models.utils import rbf_class_mapping, act_class_mapping from torchmdnet.utils import LoadFromFile, LoadFromCheckpoint, 
save_argparse, number @@ -63,7 +64,6 @@ def get_args(): parser.add_argument('--model', type=str, default='graph-network', choices=models.__all__, help='Which model to train') parser.add_argument('--output-model', type=str, default='Scalar', choices=output_modules.__all__, help='The type of output model') parser.add_argument('--prior-model', type=str, default=None, choices=priors.__all__, help='Which prior model to use') - parser.add_argument('--prior-args', default=None, type=str, help='Additional arguments for the prior model. Need to be specified in JSON format i.e. \'{"cutoff_distance": 10.0, "max_num_neighbors": 100}\'') # architectural args parser.add_argument('--charge', type=bool, default=False, help='Model needs a total charge') @@ -119,21 +119,11 @@ def main(): data.prepare_data() data.setup("fit") - prior = None - if args.prior_model: - assert hasattr(priors, args.prior_model), ( - f"Unknown prior model {args['prior_model']}. " - f"Available models are {', '.join(priors.__all__)}" - ) - # initialize the prior model - prior_args = args.prior_args - if prior_args is None: - prior_args = {} - prior = getattr(priors, args.prior_model)(dataset=data.dataset, **prior_args) - args.prior_init_args = prior.get_init_args() + prior_models = create_prior_models(args, data.dataset) + args.prior_args = [p.get_init_args() for p in prior_models] # initialize lightning module - model = LNNP(args, prior_model=prior, mean=data.mean, std=data.std) + model = LNNP(args, prior_model=prior_models, mean=data.mean, std=data.std) checkpoint_callback = ModelCheckpoint( dirpath=args.log_dir, diff --git a/torchmdnet/utils.py b/torchmdnet/utils.py index 82321273b..c3fc67afe 100644 --- a/torchmdnet/utils.py +++ b/torchmdnet/utils.py @@ -176,7 +176,7 @@ def __call__(self, parser, namespace, values, option_string=None): with open(hparams_path, "r") as f: config = yaml.load(f, Loader=yaml.FullLoader) for key in config.keys(): - if key not in namespace and key != "prior_init_args": + if key not in namespace and key != "prior_args": raise ValueError(f"Unknown argument in the model checkpoint: {key}") namespace.__dict__.update(config) namespace.__dict__.update(load_model=values) From 7069d52cb248c78cc1cca31a24c02a0a79aafe5d Mon Sep 17 00:00:00 2001 From: Peter Eastman Date: Wed, 2 Nov 2022 17:30:42 -0700 Subject: [PATCH 10/12] Tests and fixes for multiple priors --- tests/priors.yaml | 56 ++++++++++++++++++++++++++++++++++++++ tests/test_priors.py | 39 ++++++++++++++++++++++++-- tests/utils.py | 9 ++++-- torchmdnet/models/model.py | 8 ++++-- 4 files changed, 106 insertions(+), 6 deletions(-) create mode 100644 tests/priors.yaml diff --git a/tests/priors.yaml b/tests/priors.yaml new file mode 100644 index 000000000..4ec754cac --- /dev/null +++ b/tests/priors.yaml @@ -0,0 +1,56 @@ +activation: silu +aggr: add +atom_filter: -1 +attn_activation: silu +batch_size: 128 +coord_files: null +cutoff_lower: 0.0 +cutoff_upper: 5.0 +derivative: false +distance_influence: both +early_stopping_patience: 150 +ema_alpha_neg_dy: 1.0 +ema_alpha_y: 1.0 +embed_files: null +embedding_dimension: 256 +energy_files: null +y_weight: 1.0 +force_files: null +neg_dy_weight: 1.0 +inference_batch_size: 128 +load_model: null +lr: 0.0004 +lr_factor: 0.8 +lr_min: 1.0e-07 +lr_patience: 15 +lr_warmup_steps: 10000 +max_num_neighbors: 64 +max_z: 100 +model: equivariant-transformer +neighbor_embedding: true +ngpus: -1 +num_epochs: 3000 +num_heads: 8 +num_layers: 8 +num_nodes: 1 +num_rbf: 64 +num_workers: 6 +output_model: Scalar +precision: 32 +prior_model: 
+ - ZBL: + cutoff_distance: 4.0 + max_num_neighbors: 50 + - Atomref +rbf_type: expnorm +redirect: false +reduce_op: add +save_interval: 10 +splits: null +standardize: false +test_interval: 10 +test_size: null +train_size: 110000 +trainable_rbf: false +val_size: 10000 +weight_decay: 0.0 diff --git a/tests/test_priors.py b/tests/test_priors.py index 815e3bb6c..374718929 100644 --- a/tests/test_priors.py +++ b/tests/test_priors.py @@ -3,10 +3,13 @@ import torch import pytorch_lightning as pl from torchmdnet import models -from torchmdnet.models.model import create_model +from torchmdnet.models.model import create_model, create_prior_models +from torchmdnet.module import LNNP from torchmdnet.priors import Atomref, ZBL from torch_scatter import scatter from utils import load_example_args, create_example_batch, DummyDataset +from os.path import dirname, join +import tempfile @mark.parametrize("model_name", models.__all__) @@ -58,4 +61,36 @@ def compute_interaction(pos1, pos2, z1, z2): for i in range(len(pos)): for j in range(i): expected += compute_interaction(pos[i], pos[j], atomic_number[types[i]], atomic_number[types[j]]) - torch.testing.assert_allclose(expected, energy) \ No newline at end of file + torch.testing.assert_allclose(expected, energy) + +def test_multiple_priors(): + # Create a model from a config file. + + dataset = DummyDataset(has_atomref=True) + config_file = join(dirname(__file__), 'priors.yaml') + args = load_example_args('equivariant-transformer', config_file=config_file) + prior_models = create_prior_models(args, dataset) + args['prior_args'] = [p.get_init_args() for p in prior_models] + model = LNNP(args, prior_model=prior_models) + priors = model.model.prior_model + + # Make sure the priors were created correctly. + + assert len(priors) == 2 + assert isinstance(priors[0], ZBL) + assert isinstance(priors[1], Atomref) + assert priors[0].cutoff_distance == 4.0 + assert priors[0].max_num_neighbors == 50 + + # Save and load a checkpoint, and make sure the priors are correct. 
+
+    with tempfile.NamedTemporaryFile() as f:
+        torch.save(model, f)
+        f.seek(0)
+        model2 = torch.load(f)
+        priors2 = model2.model.prior_model
+        assert len(priors2) == 2
+        assert isinstance(priors2[0], ZBL)
+        assert isinstance(priors2[1], Atomref)
+        assert priors2[0].cutoff_distance == priors[0].cutoff_distance
+        assert priors2[0].max_num_neighbors == priors[0].max_num_neighbors
diff --git a/tests/utils.py b/tests/utils.py
index d8cd8322b..b297b5217 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -4,8 +4,10 @@
 from torch_geometric.data import Dataset, Data


-def load_example_args(model_name, remove_prior=False, **kwargs):
-    with open(join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml"), "r") as f:
+def load_example_args(model_name, remove_prior=False, config_file=None, **kwargs):
+    if config_file is None:
+        config_file = join(dirname(dirname(__file__)), "examples", "ET-QM9.yaml")
+    with open(config_file, "r") as f:
         args = yaml.load(f, Loader=yaml.FullLoader)
     args["model"] = model_name
     args["seed"] = 1234
@@ -69,6 +71,9 @@ def _get_atomref(self):
                 return self.atomref

             DummyDataset.get_atomref = _get_atomref
+        self.atomic_number = torch.arange(max(atom_types)+1)
+        self.distance_scale = 1.0
+        self.energy_scale = 1.0

     def get(self, idx):
         features = dict(z=self.z[idx].clone(), pos=self.pos[idx].clone())
diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index bc93c4516..2fede80b0 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -108,8 +108,8 @@ def load_model(filepath, args=None, device="cpu", **kwargs):
 def create_prior_models(args, dataset=None):
     """Parse the prior_model configuration option and create the prior models."""
     prior_models = []
-    if args.prior_model:
-        prior_model = args.prior_model
+    if args['prior_model']:
+        prior_model = args['prior_model']
         prior_names = []
         prior_args = []
         if not isinstance(prior_model, list):
@@ -125,6 +125,10 @@ def create_prior_models(args, dataset=None):
             else:
                 prior_names.append(prior)
                 prior_args.append({})
+        if 'prior_args' in args:
+            prior_args = args['prior_args']
+            if not isinstance(prior_args, list):
+                prior_args = [prior_args]
         for name, arg in zip(prior_names, prior_args):
             assert hasattr(priors, name), (
                 f"Unknown prior model {name}. "

From 4dc5268f00e32491858284040861b125c718e1e4 Mon Sep 17 00:00:00 2001
From: Peter Eastman
Date: Wed, 2 Nov 2022 19:12:30 -0700
Subject: [PATCH 11/12] Clarification to docstring

---
 torchmdnet/priors/zbl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchmdnet/priors/zbl.py b/torchmdnet/priors/zbl.py
index 5e735d7e4..544d25192 100644
--- a/torchmdnet/priors/zbl.py
+++ b/torchmdnet/priors/zbl.py
@@ -12,7 +12,7 @@ class ZBL(BasePrior):

     atomic_number: 1D tensor of length max_z.  atomic_number[z] is the atomic number of atoms with atom type z.
distance_scale: multiply by this factor to convert coordinates stored in the dataset to meters - energy_scale: multiply by this factor to convert energies stored in the dataset to Joules + energy_scale: multiply by this factor to convert energies stored in the dataset to Joules (*not* J/mol) """ def __init__(self, cutoff_distance, max_num_neighbors, atomic_number=None, distance_scale=None, energy_scale=None, dataset=None): super(ZBL, self).__init__() From 906693508143ad884c10036df0f84df0d7de2517 Mon Sep 17 00:00:00 2001 From: Peter Eastman Date: Thu, 3 Nov 2022 10:15:40 -0700 Subject: [PATCH 12/12] Fixed error --- torchmdnet/scripts/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 6f8a53bbe..62ead204b 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -119,7 +119,7 @@ def main(): data.prepare_data() data.setup("fit") - prior_models = create_prior_models(args, data.dataset) + prior_models = create_prior_models(vars(args), data.dataset) args.prior_args = [p.get_init_args() for p in prior_models] # initialize lightning module
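
As a usage note, the prior_model option now accepts either a single name or a list mixing bare names and name-to-arguments mappings (as in tests/priors.yaml above). A sketch of driving the same parsing logic directly from Python — my_dataset is a placeholder for any dataset exposing atomic_number, distance_scale, energy_scale and get_atomref:

    from torchmdnet.models.model import create_prior_models

    args = {
        "prior_model": [
            {"ZBL": {"cutoff_distance": 4.0, "max_num_neighbors": 50}},
            "Atomref",
        ],
    }
    prior_models = create_prior_models(args, dataset=my_dataset)
    # Persist the constructor arguments so the priors can be rebuilt
    # from a checkpoint without access to the dataset.
    args["prior_args"] = [p.get_init_args() for p in prior_models]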