Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 2dfdba2

Browse files
committed
minor bug fix
1 parent 20fb569 commit 2dfdba2

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch/src/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def image_classification_test(loader, model, test_10crop=True, gpu=True):
104104
all_label = torch.cat((all_label, labels.data.float()), 0)
105105

106106
_, predict = torch.max(all_output, 1)
107-
accuracy = torch.sum(torch.squeeze(predict).float() == all_label) / float(all_label.size()[0])
107+
accuracy = torch.sum(torch.squeeze(predict).float() == all_label).item() / float(all_label.size()[0])
108108
return accuracy
109109

110110

0 commit comments

Comments
 (0)