❓ Questions/Help/Support
My custom loss takes two pairs of inputs:
```python
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyLoss(nn.Module):
    def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
        super().__init__()
        self.ca = ca
        self.cb = cb

    def forward(self, y_pred: Tuple[torch.Tensor, torch.Tensor], y_true: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        # Both arguments are pairs: the a_* tensors go to MSE, the b_* tensors to cross-entropy.
        a_true, b_true = y_true
        a_pred, b_pred = y_pred
        return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)
```

When I try to log the loss with the `Loss` metric:
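```python
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Loss

loss = MyLoss(0.5, 1.0)
metrics = {
    "Loss": Loss(loss)
}
train_evaluator = create_supervised_evaluator(model, metrics, device, prepare_batch=prepare_batch)
```

it crashes on this line: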
ignite/ignite/metrics/metric.py, line 308 in 4825bb6:

```python
self.update((tensor_o1, tensor_o2))
```
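This happens because the metric treats the elements of `y_pred` and `y` as independent `(y_pred, y)` pairs and calls `update` once per pair, i.e. with `(a_pred, a_true)` and then `(b_pred, b_true)`, whereas `MyLoss` needs both tuples in a single call.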
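The splitting described above can be illustrated with a few lines of plain Python. This is a hypothetical, simplified model, not ignite's source: floats stand in for tensors, and the names `route_output` and `is_flat_sequence_of_numbers` are made up for illustration. It shows why a loss expecting `((a_pred, b_pred), (a_true, b_true))` instead receives two separate single-tensor pairs.

```python
from collections.abc import Sequence
from typing import Any, Callable


def is_flat_sequence_of_numbers(x: Any) -> bool:
    # Stand-in for a "list of tensors or numbers" check; floats play the role of tensors.
    return (isinstance(x, Sequence) and not isinstance(x, str)
            and all(isinstance(t, (int, float)) for t in x))


def route_output(output: Any, update: Callable[[Any], None]) -> None:
    """Hypothetical, simplified model of how a metric forwards an engine output to update()."""
    if isinstance(output, Sequence) and all(is_flat_sequence_of_numbers(o) for o in output):
        y_pred, y = output
        # Element i of y_pred is paired with element i of y and sent to update() on its own.
        for o1, o2 in zip(y_pred, y):
            update((o1, o2))
    else:
        # Anything else reaches update() unchanged, in a single call.
        update(output)


calls = []
# y_pred = (a_pred, b_pred), y = (a_true, b_true), mirroring MyLoss.
route_output(((0.2, 0.7), (0.0, 1.0)), calls.append)
print(calls)  # [(0.2, 0.0), (0.7, 1.0)]
```

In this model, `update` never sees the two tuples together, which is exactly the mismatch with `MyLoss.forward`.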
Digging into the source code, I found that #2055 introduced a new feature that causes this behavior.
So what is the best practice for handling losses that take multiple inputs?
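One possible direction, sketched here only as a plain-Python model and not as ignite code: forward the whole `(y_pred, y)` output to `update` in one call, and count the batch from the first target rather than with `len(y)`, which for a tuple of targets returns the number of terms (2), not the number of samples. In ignite this would presumably mean subclassing `Loss`, overriding `iteration_completed` so it calls `update` directly, and passing a `batch_size` callable such as `lambda y: len(y[0])`; that mapping is an assumption to check against the installed version. `RunningLossModel`, `my_loss_model` and `mse` are hypothetical names, and MSE stands in for cross-entropy in the second term.

```python
from typing import Any, Callable, Sequence, Tuple


def mse(pred: Sequence[float], true: Sequence[float]) -> float:
    # Mean squared error over one batch.
    return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true)


def my_loss_model(y_pred: Tuple[Sequence[float], Sequence[float]],
                  y_true: Tuple[Sequence[float], Sequence[float]],
                  ca: float = 0.5, cb: float = 1.0) -> float:
    # Plain-Python stand-in for MyLoss; MSE replaces cross-entropy for the second term.
    a_pred, b_pred = y_pred
    a_true, b_true = y_true
    return ca * mse(a_pred, a_true) + cb * mse(b_pred, b_true)


class RunningLossModel:
    """Hypothetical model of a running-average loss metric that never splits its input."""

    def __init__(self, loss_fn: Callable[[Any, Any], float],
                 batch_size: Callable[[Any], int]) -> None:
        self.loss_fn = loss_fn
        self.batch_size = batch_size
        self._sum = 0.0
        self._num_examples = 0

    def iteration_completed(self, output: Tuple[Any, Any]) -> None:
        # Forward the whole (y_pred, y) pair in a single call instead of zipping it.
        self.update(output)

    def update(self, output: Tuple[Any, Any]) -> None:
        y_pred, y = output
        n = self.batch_size(y)
        # loss_fn returns a batch average, so weight it by the batch size.
        self._sum += self.loss_fn(y_pred, y) * n
        self._num_examples += n

    def compute(self) -> float:
        return self._sum / self._num_examples


batch = (([1.0, 2.0, 3.0], [0.0, 1.0, 0.0]),  # y_pred = (a_pred, b_pred)
         ([1.0, 2.0, 6.0], [0.0, 0.0, 0.0]))  # y_true = (a_true, b_true)
# Count samples from the first target, not len(y), which is the number of loss terms.
metric = RunningLossModel(my_loss_model, batch_size=lambda y: len(y[0]))
metric.iteration_completed(batch)
print(metric.compute())
```

With `batch_size=len`, this model would count 2 examples per batch regardless of the real batch size, which skews the running average once batches differ in size.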
vfdev-5