From 223037f984e68bd42e318758296ee3bb33f35026 Mon Sep 17 00:00:00 2001 From: Raimondas Galvelis Date: Wed, 19 Jul 2023 18:49:01 +0200 Subject: [PATCH 1/3] Implement element filtering in the Ace datasets --- torchmdnet/datasets/ace.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index e2470dcb2..64197fbaa 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -15,15 +15,17 @@ def __init__( pre_transform=None, pre_filter=None, paths=None, + atomic_numbers=None, max_gradient=None, subsample_molecules=1, ): assert isinstance(paths, (str, list)) - arg_hash = f"{paths}{max_gradient}{subsample_molecules}" + arg_hash = f"{paths}{atomic_numbers}{max_gradient}{subsample_molecules}" arg_hash = hashlib.md5(arg_hash.encode()).hexdigest() self.name = f"{self.__class__.__name__}-{arg_hash}" self.paths = paths + self.atomic_numbers = atomic_numbers self.max_gradient = max_gradient self.subsample_molecules = int(subsample_molecules) super().__init__(root, transform, pre_transform, pre_filter) @@ -180,6 +182,11 @@ def sample_iter(self, mol_ids=False): fq = pt.tensor(mol["formal_charges"], dtype=pt.long) q = fq.sum() + # Keep molecules with specific elements + if self.atomic_numbers: + if not set(z.numpy()).issubset(self.atomic_numbers): + continue + for i_conf, (pos, y, neg_dy, pq, dp) in enumerate(load_confs(mol, n_atoms=len(z))): # Skip samples with large forces @@ -220,6 +227,7 @@ def processed_file_names(self): def process(self): print("Arguments") + print(f" atomic_numbers: {self.atomic_numbers}") print(f" max_gradient: {self.max_gradient} eV/A") print(f" subsample_molecules: {self.subsample_molecules}\n") From 32ce2dc1b9bcd31ab6960b815a73c8c3444cc869 Mon Sep 17 00:00:00 2001 From: Raimondas Galvelis Date: Mon, 24 Jul 2023 16:03:37 +0200 Subject: [PATCH 2/3] Make the Ace loader to load energy in float64 --- torchmdnet/datasets/ace.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/torchmdnet/datasets/ace.py b/torchmdnet/datasets/ace.py index 64197fbaa..ac3db7b3c 100644 --- a/torchmdnet/datasets/ace.py +++ b/torchmdnet/datasets/ace.py @@ -317,9 +317,7 @@ def get(self, idx): atoms = slice(self.idx_mm[idx], self.idx_mm[idx + 1]) z = pt.tensor(self.z_mm[atoms], dtype=pt.long) pos = pt.tensor(self.pos_mm[atoms], dtype=pt.float32) - y = pt.tensor(self.y_mm[idx], dtype=pt.float32).view( - 1, 1 - ) # It would be better to use float64, but the trainer complaints + y = pt.tensor(self.y_mm[idx], dtype=pt.float64).view(1, 1) neg_dy = pt.tensor(self.neg_dy_mm[atoms], dtype=pt.float32) q = pt.tensor(self.q_mm[idx], dtype=pt.long) pq = pt.tensor(self.pq_mm[atoms], dtype=pt.float32) From 8ce19c6d50ac4fd436a3fd8c78ebc9626e1b28c2 Mon Sep 17 00:00:00 2001 From: Raimondas Galvelis Date: Mon, 24 Jul 2023 16:04:30 +0200 Subject: [PATCH 3/3] Convert energy to right precision for training --- torchmdnet/module.py | 5 +++-- torchmdnet/scripts/train.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/torchmdnet/module.py b/torchmdnet/module.py index a04142a61..8cdafc450 100644 --- a/torchmdnet/module.py +++ b/torchmdnet/module.py @@ -127,8 +127,9 @@ def step(self, batch, loss_fn, stage): if batch.y.ndim == 1: batch.y = batch.y.unsqueeze(1) - # y loss - loss_y = loss_fn(y, batch.y) + # y + y_dtype = {16: torch.float16, 32: torch.float32, 64: torch.float64}[self.hparams.precision] + loss_y = loss_fn(y, batch.y.to(y_dtype)) if stage in ["train", "val"] and self.hparams.ema_alpha_y < 1: if self.ema[stage + "_y"] is None: diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py index 6b120df86..4ad51275b 100644 --- a/torchmdnet/scripts/train.py +++ b/torchmdnet/scripts/train.py @@ -37,7 +37,7 @@ def get_args(): parser.add_argument('--ema-alpha-neg-dy', type=float, default=1.0, help='The amount of influence of new losses on the exponential moving average of dy') parser.add_argument('--ngpus', type=int, default=-1, help='Number of GPUs, -1 use all available. Use CUDA_VISIBLE_DEVICES=1, to decide gpus') parser.add_argument('--num-nodes', type=int, default=1, help='Number of nodes') - parser.add_argument('--precision', type=int, default=32, choices=[16, 32], help='Floating point precision') + parser.add_argument('--precision', type=int, default=32, choices=[16, 32, 64], help='Floating point precision') parser.add_argument('--log-dir', '-l', default='/tmp/logs', help='log file') parser.add_argument('--splits', default=None, help='Npz with splits idx_train, idx_val, idx_test') parser.add_argument('--train-size', type=number, default=None, help='Percentage/number of samples in training set (None to use all remaining samples)')