Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d922f01

Browse files
committed
fix fp16 siou bug
1 parent 309b744 commit d922f01

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

nets/yolo_training.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def box_iou(self, b1, b2):
125125
#----------------------------------------------------#
126126
# 计算中心的距离
127127
#----------------------------------------------------#
128-
sigma = torch.pow(center_wh[..., 0] ** 2 + center_wh[..., 1] ** 2, 0.5)
128+
sigma = torch.pow(torch.sum(torch.pow(center_wh, 2), axis=-1), 0.5)
129129

130130
#----------------------------------------------------#
131131
# 求h和w方向上的sin比值
@@ -264,8 +264,9 @@ def forward(self, l, input, targets=None):
264264
# loss_cls 分类损失
265265
#---------------------------------------------------------------#
266266
iou = self.box_iou(pred_boxes, y_true[..., :4]).type_as(x)
267-
# loss_loc = torch.mean((1 - iou)[obj_mask] * box_loss_scale[obj_mask])
267+
obj_mask = obj_mask & torch.logical_not(torch.isnan(iou))
268268
loss_loc = torch.mean((1 - iou)[obj_mask])
269+
# loss_loc = torch.mean((1 - iou)[obj_mask] * box_loss_scale[obj_mask])
269270

270271
loss_cls = torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))
271272
loss += loss_loc * self.box_ratio + loss_cls * self.cls_ratio

utils/utils_fit.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,9 @@ def fit_one_epoch(model_train, model, yolo_loss, loss_history, eval_callback, op
6060
# 计算损失
6161
#----------------------#
6262
for l in range(len(outputs)):
63-
loss_item = yolo_loss(l, outputs[l], targets)
63+
with torch.cuda.amp.autocast(enabled=False):
64+
predication = outputs[l].float()
65+
loss_item = yolo_loss(l, predication, targets)
6466
loss_value_all += loss_item
6567
loss_value = loss_value_all
6668

0 commit comments

Comments
 (0)