diff --git a/torchmdnet/models/model.py b/torchmdnet/models/model.py
index 7ee44e361..e80c6b93d 100644
--- a/torchmdnet/models/model.py
+++ b/torchmdnet/models/model.py
@@ -11,18 +11,15 @@ import warnings
 
 
-def create_model(args, prior_model=None, mean=None, std=None):
-    """Create a model from the given arguments.
+def create_representation_model(args):
+    """Create a representation model from the given arguments.
     See :func:`get_args` in scripts/train.py for a description of the arguments.
 
     Parameters
     ----------
         args (dict): Arguments for the model.
-        prior_model (nn.Module, optional): Prior model to use. Defaults to None.
-        mean (torch.Tensor, optional): Mean of the training data. Defaults to None.
-        std (torch.Tensor, optional): Standard deviation of the training data. Defaults to None.
 
     Returns
     -------
-        nn.Module: An instance of the TorchMD_Net model.
+        nn.Module: An instance of the TorchMD_Net representation model.
     """
     dtype = dtype_mapping[args["precision"]]
     shared_args = dict(
@@ -38,12 +35,8 @@ def create_model(args, prior_model=None, mean=None, std=None):
         max_num_neighbors=args["max_num_neighbors"],
         dtype=dtype
     )
-
-    # representation network
     if args["model"] == "graph-network":
         from torchmdnet.models.torchmd_gn import TorchMD_GN
-
-        is_equivariant = False
         representation_model = TorchMD_GN(
             num_filters=args["embedding_dimension"],
             aggr=args["aggr"],
@@ -52,8 +45,6 @@ def create_model(args, prior_model=None, mean=None, std=None):
         )
     elif args["model"] == "transformer":
         from torchmdnet.models.torchmd_t import TorchMD_T
-
-        is_equivariant = False
         representation_model = TorchMD_T(
             attn_activation=args["attn_activation"],
             num_heads=args["num_heads"],
@@ -63,8 +54,6 @@ def create_model(args, prior_model=None, mean=None, std=None):
         )
     elif args["model"] == "equivariant-transformer":
         from torchmdnet.models.torchmd_et import TorchMD_ET
-
-        is_equivariant = True
         representation_model = TorchMD_ET(
             attn_activation=args["attn_activation"],
             num_heads=args["num_heads"],
@@ -74,28 +63,41 @@ def create_model(args, prior_model=None, mean=None, std=None):
         )
     elif args["model"] == "tensornet":
         from torchmdnet.models.tensornet import TensorNet
-
-        # Setting is_equivariant to False to enforce the use of Scalar output module instead of EquivariantScalar
-        is_equivariant = False
         representation_model = TensorNet(
             equivariance_invariance_group=args["equivariance_invariance_group"],
             **shared_args,
         )
     else:
         raise ValueError(f'Unknown architecture: {args["model"]}')
+    return representation_model
 
+def create_model(args, prior_model=None, mean=None, std=None):
+    """Create a model from the given arguments.
+    See :func:`get_args` in scripts/train.py for a description of the arguments.
+    Parameters
+    ----------
+        args (dict): Arguments for the model.
+        prior_model (nn.Module, optional): Prior model to use. Defaults to None.
+        mean (torch.Tensor, optional): Mean of the training data. Defaults to None.
+        std (torch.Tensor, optional): Standard deviation of the training data. Defaults to None.
+    Returns
+    -------
+        nn.Module: An instance of the TorchMD_Net model.
+ """ + dtype = dtype_mapping[args["precision"]] + # representation network + representation_model = create_representation_model(args) # atom filter if not args["derivative"] and args["atom_filter"] > -1: representation_model = AtomFilter(representation_model, args["atom_filter"]) elif args["atom_filter"] > -1: raise ValueError("Derivative and atom filter can't be used together") - # prior model if args["prior_model"] and prior_model is None: # instantiate prior model if it was not passed to create_model (i.e. when loading a model) prior_model = create_prior_models(args) - # create output network + is_equivariant = args["model"] == "equivariant-transformer" output_prefix = "Equivariant" if is_equivariant else "" output_model = getattr(output_modules, output_prefix + args["output_model"])( args["embedding_dimension"], @@ -103,12 +105,11 @@ def create_model(args, prior_model=None, mean=None, std=None): reduce_op=args["reduce_op"], dtype=dtype, ) - # combine representation and output network - model = TorchMD_Net( + model = MultiHeadTorchMD_Net( representation_model, - output_model, - prior_model=prior_model, +# output_model, +# prior_model=prior_model, mean=mean, std=std, derivative=args["derivative"], @@ -118,6 +119,7 @@ def create_model(args, prior_model=None, mean=None, std=None): def load_model(filepath, args=None, device="cpu", **kwargs): + raise NotImplementedError("load_model is not implemented yet") ckpt = torch.load(filepath, map_location="cpu") if args is None: args = ckpt["hyper_parameters"] @@ -297,3 +299,138 @@ def forward( return y, -dy # TODO: return only `out` once Union typing works with TorchScript (https://github.com/pytorch/pytorch/pull/53180) return y, None + +from torchmdnet.models.utils import scatter, act_class_mapping + +class BaseHead(nn.Module): + def __init__(self, dtype=torch.float32): + super(BaseHead, self).__init__() + self.dtype = dtype + + def reset_parameters(self): + pass + + def per_point(self, point_features, results, z, pos, batch, extra_args): + return point_features, results + + def per_sample(self, point_features, results, z, pos, batch, extra_args): + return point_features, results + +class EnergyHead(BaseHead): + def __init__(self, + hidden_channels, + activation="silu", + dtype=torch.float32): + super(EnergyHead, self).__init__(dtype=dtype) + act_class = act_class_mapping[activation] + self.output_network = nn.Sequential( + nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype), + act_class(), + nn.Linear(hidden_channels // 2, 1, dtype=dtype), + ) + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.output_network[0].weight) + self.output_network[0].bias.data.fill_(0) + nn.init.xavier_uniform_(self.output_network[2].weight) + self.output_network[2].bias.data.fill_(0) + + def per_point(self, point_features, results, z, pos, batch, extra_args): + results["energy"] = self.output_network(point_features) + return point_features, results + + def per_sample(self, point_features, results, z, pos, batch, extra_args): + results["energy"] = scatter(results["energy"], batch, dim=0) + return point_features, results + +class PointChargeHead(BaseHead): + def __init__(self, + hidden_channels, + activation="silu", + dtype=torch.float32): + super(PointChargeHead, self).__init__(dtype=dtype) + act_class = act_class_mapping[activation] + self.output_network = nn.Sequential( + nn.Linear(hidden_channels, hidden_channels // 2, dtype=dtype), + act_class(), + nn.Linear(hidden_channels // 2, 1, dtype=dtype), + ) + 
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        nn.init.xavier_uniform_(self.output_network[0].weight)
+        self.output_network[0].bias.data.fill_(0)
+        nn.init.xavier_uniform_(self.output_network[2].weight)
+        self.output_network[2].bias.data.fill_(0)
+
+    def per_point(self, point_features, results, z, pos, batch, extra_args):
+        results["charge"] = self.output_network(point_features)
+        return point_features, results
+
+    def per_sample(self, point_features, results, z, pos, batch, extra_args):
+        return point_features, results
+
+class ForceHead(BaseHead):
+    def __init__(self,
+                 dtype=torch.float32):
+        super(ForceHead, self).__init__(dtype=dtype)
+        pass
+
+    def per_sample(self, point_features, results, z, pos, batch, extra_args):
+        grad_outputs: List[Optional[torch.Tensor]] = [torch.ones_like(results["energy"])]
+        results["force"] = -grad([results["energy"]],
+                                 [pos],
+                                 grad_outputs=grad_outputs,
+                                 create_graph=self.training,
+                                 retain_graph=self.training)[0]
+        return point_features, results
+
+class MultiHeadTorchMD_Net(nn.Module):
+    def __init__(
+        self,
+        representation_model,
+        head_list = None,
+        mean=None,
+        std=None,
+        derivative=False,
+        dtype=torch.float32,
+    ):
+        super(MultiHeadTorchMD_Net, self).__init__()
+        self.representation_model = representation_model.to(dtype=dtype)
+        self.derivative = derivative
+        self.head_list = nn.ModuleList([EnergyHead(representation_model.hidden_channels, dtype=dtype)])
+        if derivative:
+            self.head_list.append(ForceHead(dtype=dtype))
+        mean = torch.scalar_tensor(0) if mean is None else mean
+        self.register_buffer("mean", mean.to(dtype=dtype))
+        std = torch.scalar_tensor(1) if std is None else std
+        self.register_buffer("std", std.to(dtype=dtype))
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.representation_model.reset_parameters()
+        for head in self.head_list:
+            head.reset_parameters()
+
+    def forward(
+        self,
+        z: Tensor,
+        pos: Tensor,
+        batch: Optional[Tensor] = None,
+        q: Optional[Tensor] = None,
+        s: Optional[Tensor] = None,
+        extra_args: Optional[Dict[str, Tensor]] = None
+    ) -> Dict[str, Tensor]:
+        assert z.dim() == 1 and z.dtype == torch.long
+        batch = torch.zeros_like(z) if batch is None else batch
+        if self.derivative:
+            pos.requires_grad_(True)
+        results = {}
+        # run the potentially wrapped representation model
+        point_features = self.representation_model(z, pos, batch, q=q, s=s)
+        for head in self.head_list:
+            point_features, results = head.per_point(point_features, results, z=z, pos=pos, batch=batch, extra_args=extra_args)
+        for head in self.head_list:
+            point_features, results = head.per_sample(point_features, results, z=z, pos=pos, batch=batch, extra_args=extra_args)
+        return results
diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index e30a1e0ab..c675abcbf 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -29,6 +29,11 @@ def __init__(self, hparams, prior_model=None, mean=None, std=None):
         else:
             self.model = create_model(self.hparams, prior_model, mean, std)
 
+        self.weights = {}
+        for key in self.hparams:
+            if key.endswith("_weight"):
+                self.weights[key[:-len("_weight")]] = self.hparams[key]
+
         # initialize exponential smoothing
         self.ema = None
         self._reset_ema_dict()
@@ -66,7 +71,7 @@ def forward(
         q: Optional[Tensor] = None,
         s: Optional[Tensor] = None,
         extra_args: Optional[Dict[str, Tensor]] = None,
-    ) -> Tuple[Tensor, Optional[Tensor]]:
+    ) -> Dict[str, Tensor]:
         return self.model(z, pos, batch=batch, q=q, s=s, extra_args=extra_args)
 
     def training_step(self, batch, batch_idx):
@@ -87,29 +92,29 @@ def validation_step(self, batch, batch_idx, *args):
     def test_step(self, batch, batch_idx):
         return self.step(batch, [l1_loss], "test")
 
-    def _compute_losses(self, y, neg_y, batch, loss_fn, stage):
-        # Compute the loss for the predicted value and the negative derivative (if available)
-        # Args:
-        #   y: predicted value
-        #   neg_y: predicted negative derivative
-        #   batch: batch of data
-        #   loss_fn: loss function to compute
-        # Returns:
-        #   loss_y: loss for the predicted value
-        #   loss_neg_y: loss for the predicted negative derivative
-        loss_y, loss_neg_y = torch.tensor(0.0, device=self.device), torch.tensor(
-            0.0, device=self.device
-        )
+    def _compute_losses(self, outputs, batch, loss_fn, stage):
+        """
+        Compute the losses for each model output.
+
+        Args:
+            outputs: Dictionary of model outputs.
+            batch: Batch of data.
+            loss_fn: Loss function to compute.
+            stage: Current training stage.
+
+        Returns:
+            losses: Dictionary of computed losses for each model output.
+        """
+        losses = {}
         loss_name = loss_fn.__name__
-        if self.hparams.derivative and "neg_dy" in batch:
-            loss_neg_y = loss_fn(neg_y, batch.neg_dy)
-            loss_neg_y = self._update_loss_with_ema(
-                stage, "neg_dy", loss_name, loss_neg_y
-            )
-        if "y" in batch:
-            loss_y = loss_fn(y, batch.y)
-            loss_y = self._update_loss_with_ema(stage, "y", loss_name, loss_y)
-        return {"y": loss_y, "neg_dy": loss_neg_y}
+        for key in outputs:
+            if key in batch:
+                loss = loss_fn(outputs[key], getattr(batch, key))
+                loss = self._update_loss_with_ema(stage, key, loss_name, loss)
+                losses[key] = loss
+            else:
+                raise ValueError(f"Reference values for '{key}' are missing in the batch")
+        return losses
 
     def _update_loss_with_ema(self, stage, type, loss_name, loss):
         # Update the loss using an exponential moving average when applicable
@@ -147,7 +152,7 @@ def step(self, batch, loss_fn_list, stage):
                 del extra_args[a]
         # TODO: the model doesn't necessarily need to return a derivative once
         # Union typing works under TorchScript (https://github.com/pytorch/pytorch/pull/53180)
-        y, neg_dy = self(
+        outputs = self(
             batch.z,
             batch.pos,
             batch=batch.batch,
@@ -155,28 +160,22 @@ def step(self, batch, loss_fn_list, stage):
             s=batch.s if self.hparams.spin else None,
             extra_args=extra_args,
         )
-        if self.hparams.derivative and "y" not in batch:
-            # "use" both outputs of the model's forward function but discard the first
-            # to only use the negative derivative and avoid 'Expected to have finished reduction
-            # in the prior iteration before starting a new one.', which otherwise get's
-            # thrown because of setting 'find_unused_parameters=False' in the DDPPlugin
-            neg_dy = neg_dy + y.sum() * 0
-        if "y" in batch and batch.y.ndim == 1:
-            batch.y = batch.y.unsqueeze(1)
+        # Raul 2023, I don't think this is needed anymore
+        # if self.hparams.derivative and "energy" not in batch:
+        #     # "use" both outputs of the model's forward function but discard the first
+        #     # to only use the negative derivative and avoid 'Expected to have finished reduction
+        #     # in the prior iteration before starting a new one.', which otherwise get's
+        #     # thrown because of setting 'find_unused_parameters=False' in the DDPPlugin
+        #     outputs["forces"] += outputs["energy"].sum() * 0
         for loss_fn in loss_fn_list:
-            step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage)
+            step_losses = self._compute_losses(outputs, batch, loss_fn, stage)
             loss_name = loss_fn.__name__
-            if self.hparams.neg_dy_weight > 0:
-                self.losses[stage]["neg_dy"][loss_name].append(
-                    step_losses["neg_dy"].detach()
-                )
-            if self.hparams.y_weight > 0:
self.losses[stage]["y"][loss_name].append(step_losses["y"].detach()) - total_loss = ( - step_losses["y"] * self.hparams.y_weight - + step_losses["neg_dy"] * self.hparams.neg_dy_weight - ) + total_loss = torch.tensor(0.0, device=self.device, dtype=batch.pos.dtype) + for key, loss in step_losses.items(): + if key in self.weights and self.weights[key] > 0: + self.losses[stage][key][loss_name].append(loss.detach()) + total_loss += loss * self.weights[key] self.losses[stage]["total"][loss_name].append(total_loss.detach()) return total_loss @@ -218,8 +217,8 @@ def on_validation_epoch_end(self): "lr": self.trainer.optimizers[0].param_groups[0]["lr"], } result_dict.update(self._get_mean_loss_dict_for_type("total")) - result_dict.update(self._get_mean_loss_dict_for_type("y")) - result_dict.update(self._get_mean_loss_dict_for_type("neg_dy")) + for i in self.weights.keys(): + result_dict.update(self._get_mean_loss_dict_for_type(i)) self.log_dict(result_dict, sync_dist=True) self._reset_losses_dict() @@ -229,26 +228,28 @@ def on_test_epoch_end(self): if not self.trainer.sanity_checking: result_dict = {} result_dict.update(self._get_mean_loss_dict_for_type("total")) - result_dict.update(self._get_mean_loss_dict_for_type("y")) - result_dict.update(self._get_mean_loss_dict_for_type("neg_dy")) + for i in self.weights.keys(): + result_dict.update(self._get_mean_loss_dict_for_type(i)) # Get only test entries result_dict = {k: v for k, v in result_dict.items() if k.startswith("test")} self.log_dict(result_dict, sync_dist=True) def _reset_losses_dict(self): # Losses has an entry for each stage in ["train", "val", "test"] - # Each entry has an entry with "total", "y" and "neg_dy" + # Each entry has an entry with "total" and each of the weights keys (e.g. y, neg_dy) # Each of these entries has an entry for each loss_fn (e.g. mse_loss) # The loss_fn values are not known in advance self.losses = {} for stage in ["train", "val", "test"]: self.losses[stage] = {} - for loss_type in ["total", "y", "neg_dy"]: + types = ["total"] + list(self.weights.keys()) + for loss_type in types: self.losses[stage][loss_type] = defaultdict(list) def _reset_ema_dict(self): self.ema = {} for stage in ["train", "val"]: self.ema[stage] = {} - for loss_type in ["y", "neg_dy"]: + types = ["total"] + list(self.weights.keys()) + for loss_type in types: self.ema[stage][loss_type] = {}