Description
Hello! I am trying to wrap a Lightning model with your package using `Laplace()`. However, my input `X` is a list of three tensors of different shapes, so they cannot be stacked into a single torch tensor. This breaks `fit`, because one of its first steps is to call `X.to(device)`. I also tried passing the inputs as a dictionary instead of a list (to mimic the HuggingFace setup), but that did not work either. Do you have any advice on how I could make this work? Thank you so much, I would greatly appreciate the help. Here is the error from the dictionary attempt:
```
TypeError Traceback (most recent call last)
File ~/SCI/lib/python3.9/site-packages/laplace/baselaplace.py:902, in ParametricLaplace.fit(self, train_loader, override, progress_bar)
    901 try:
--> 902     out = self.model(X[:1].to(self._device))
    903 except (TypeError, AttributeError):

TypeError: unhashable type: 'slice'

During handling of the above exception, another exception occurred:

AttributeError Traceback (most recent call last)
Cell In[55], line 1
----> 1 lp_model.fit(dl.train_dataloader())

File ~/SCI/lib/python3.9/site-packages/laplace/lllaplace.py:204, in LLLaplace.fit(self, train_loader, override, progress_bar)
    201 self.prior_mean: float | torch.Tensor = self._prior_mean
    202 self._init_H()
--> 204 super().fit(train_loader, override=override)
    205 self.mean: torch.Tensor = parameters_to_vector(
    206     self.model.last_layer.parameters()
    207 )
    209 if not self.enable_backprop:

File ~/SCI/lib/python3.9/site-packages/laplace/baselaplace.py:1741, in KronLaplace.fit(self, train_loader, override, progress_bar)
   1736 # discount previous Kronecker factors to sum up properly together with new ones
   1737 self.H_facs = self._rescale_factors(
   1738     self.H_facs, n_data_old / (n_data_old + n_data_new)
   1739 )
-> 1741 super().fit(train_loader, override=override, progress_bar=progress_bar)
   1743 if self.H_facs is None:
   1744     self.H_facs = self.H

File ~/SCI/lib/python3.9/site-packages/laplace/baselaplace.py:904, in ParametricLaplace.fit(self, train_loader, override, progress_bar)
    902     out = self.model(X[:1].to(self._device))
    903 except (TypeError, AttributeError):
--> 904     out = self.model(X.to(self._device))
    905 self.n_outputs = out.shape[-1]
    906 setattr(self.model, "output_size", self.n_outputs)

AttributeError: 'dict' object has no attribute 'to'
```
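One possible workaround, sketched under assumptions: at the point where the traceback fails, `fit` calls `X[:1].to(self._device)` and falls back to `X.to(self._device)` before passing `X` to the model. A small dict subclass that implements both batch slicing and `.to()` should therefore get past this step. `BatchDict`, `collate`, `ThreeInputNet` and the toy data are hypothetical names for illustration, and the model's `forward` must accept the dict. This has not been tested against laplace-torch itself; whether later steps of `fit` (last-layer feature extraction, the Kronecker curvature backend) also accept a dict input is not verified here.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


class BatchDict(dict):
    """Hypothetical helper: a dict of batched tensors supporting X[:k] and X.to(device)."""

    def __getitem__(self, key):
        if isinstance(key, slice):
            # Slice every tensor along the batch dimension, e.g. X[:1].
            return BatchDict({k: v[key] for k, v in self.items()})
        return super().__getitem__(key)

    def to(self, device):
        # Move every tensor to `device`, returning a new BatchDict.
        return BatchDict({k: v.to(device) for k, v in self.items()})


def collate(batch):
    # batch: list of ((a, b, c), y) samples. Stack each input separately,
    # so the three inputs may have different shapes.
    inputs, targets = zip(*batch)
    a, b, c = (torch.stack(parts) for parts in zip(*inputs))
    return BatchDict(a=a, b=b, c=c), torch.stack(targets)


class ThreeInputNet(nn.Module):
    """Toy model: forward takes the BatchDict; the last layer is an nn.Linear."""

    def __init__(self, n_classes=2):
        super().__init__()
        self.enc_a = nn.Linear(4, 8)
        self.enc_b = nn.Linear(3 * 5, 8)
        self.enc_c = nn.Linear(7, 8)
        self.head = nn.Linear(24, n_classes)

    def forward(self, x):
        h = torch.cat(
            [self.enc_a(x["a"]), self.enc_b(x["b"].flatten(1)), self.enc_c(x["c"])],
            dim=-1,
        )
        return self.head(torch.relu(h))


# Toy data: 32 samples, each with inputs of shape (4,), (3, 5) and (7,).
samples = [
    ((torch.randn(4), torch.randn(3, 5), torch.randn(7)), torch.tensor(i % 2))
    for i in range(32)
]
loader = DataLoader(samples, batch_size=8, collate_fn=collate)
model = ThreeInputNet()
X, y = next(iter(loader))

# Mimic the probe from the traceback: try X[:1].to(device), fall back to X.to(device).
with torch.no_grad():
    try:
        out = model(X[:1].to("cpu"))
    except (TypeError, AttributeError):
        out = model(X.to("cpu"))
print(out.shape)  # torch.Size([1, 2])
```

Handling slices in `__getitem__` means the first call, `X[:1].to(...)`, already succeeds, so the probe runs on a single sample as intended. It also avoids a plain-dict lookup with a slice key, which on Python 3.12+ (where slice objects are hashable) raises `KeyError`, an exception the `except (TypeError, AttributeError)` clause shown above would not catch. A LightningModule is an `nn.Module`, so the same idea should carry over; the loader would then be passed to `lp_model.fit(...)` as in the question.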