Metric with multiple inputs runs in an unexpected way. #2940

@lyhyl

❓ Questions/Help/Support

My custom loss requires two pairs of inputs:

from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

class MyLoss(nn.Module):
    def __init__(self, ca: float = 1.0, cb: float = 1.0) -> None:
        super().__init__()
        self.ca = ca  # weight of the MSE (regression) term
        self.cb = cb  # weight of the cross-entropy (classification) term

    def forward(self, y_pred: Tuple[torch.Tensor, torch.Tensor], y_true: Tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        a_true, b_true = y_true
        a_pred, b_pred = y_pred
        # Weighted sum of a regression loss on the first pair
        # and a classification loss on the second pair.
        return self.ca * F.mse_loss(a_pred, a_true) + self.cb * F.cross_entropy(b_pred, b_true)
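
For concreteness, this is how the loss is meant to be called. The shapes here are hypothetical, just to make the snippet runnable (batch of 8, a 4-dim regression pair, and 3-class logits with integer labels):

loss = MyLoss(ca=0.5, cb=1.0)

a_pred, a_true = torch.randn(8, 4), torch.randn(8, 4)          # regression pair
b_pred, b_true = torch.randn(8, 3), torch.randint(0, 3, (8,))  # classification pair

value = loss((a_pred, b_pred), (a_true, b_true))  # a single scalar tensor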

When I try to log the loss with the Loss metric:

from ignite.engine import create_supervised_evaluator
from ignite.metrics import Loss

loss = MyLoss(0.5, 1.0)
metrics = {
    "Loss": Loss(loss)
}
train_evaluator = create_supervised_evaluator(
    model, metrics=metrics, device=device, prepare_batch=prepare_batch
)

It crashes on this line (inside the base Metric.iteration_completed):

self.update((tensor_o1, tensor_o2))

because it treats the elements of the tuples as independent (y_pred, y) pairs and calls update once per pair, which is not what MyLoss needs.
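
For illustration (this is a paraphrase of the behavior, not the Ignite source), the splitting looks roughly like this:

import torch

# Both elements of (y_pred, y) are tuples of tensors, so the base
# Metric zips them and calls update() once per (pred, true) pair.
y_pred = (torch.randn(8, 4), torch.randn(8, 3))     # (a_pred, b_pred)
y = (torch.randn(8, 4), torch.randint(0, 3, (8,)))  # (a_true, b_true)

for tensor_o1, tensor_o2 in zip(y_pred, y):
    # First call: (a_pred, a_true); second call: (b_pred, b_true).
    # MyLoss never sees the two tuples together.
    print(tensor_o1.shape, tensor_o2.shape)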

Digging into the source code, I found that #2055 introduced a new feature which causes this issue.
So, what are the best practices for dealing with losses that take multiple inputs?
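
One possible workaround (a sketch under my own assumptions, not an official recommendation) is to subclass Loss and override iteration_completed so the transformed output is forwarded to update() unchanged, bypassing the per-pair unpacking. The batch_size argument is also overridden, because the default len(y) would count the two tensors in the target tuple rather than the batch dimension:

from ignite.engine import Engine
from ignite.metrics import Loss

class TupleLoss(Loss):
    # Hypothetical subclass, sketched for this issue: forwards the
    # raw (y_pred, y) output straight to update(), so tuple-valued
    # y_pred / y reach the wrapped loss intact.
    def iteration_completed(self, engine: Engine) -> None:
        output = self._output_transform(engine.state.output)
        self.update(output)

# Read the batch size from the first target tensor, assuming it
# carries the batch dimension.
metrics = {
    "Loss": TupleLoss(loss, batch_size=lambda y: y[0].shape[0])
}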
