@@ -34,17 +34,18 @@ def get_anchors(anchors_path):
34
34
def fit_ont_epoch (net ,yolo_losses ,epoch ,epoch_size ,epoch_size_val ,gen ,genval ,Epoch ,cuda ):
35
35
total_loss = 0
36
36
val_loss = 0
37
- for iteration in range (epoch_size ):
37
+ for iteration , batch in enumerate (gen ):
38
+ if iteration == epoch_size :
39
+ break
38
40
start_time = time .time ()
39
- images , targets = next ( gen )
41
+ images , targets = batch [ 0 ], batch [ 1 ]
40
42
with torch .no_grad ():
41
43
if cuda :
42
44
images = Variable (torch .from_numpy (images ).type (torch .FloatTensor )).cuda ()
43
45
targets = [Variable (torch .from_numpy (ann ).type (torch .FloatTensor )) for ann in targets ]
44
46
else :
45
47
images = Variable (torch .from_numpy (images ).type (torch .FloatTensor ))
46
48
targets = [Variable (torch .from_numpy (ann ).type (torch .FloatTensor )) for ann in targets ]
47
- # print(images)
48
49
optimizer .zero_grad ()
49
50
outputs = net (images )
50
51
losses = []
@@ -61,8 +62,10 @@ def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epo
61
62
print ('iter:' + str (iteration ) + '/' + str (epoch_size ) + ' || Total Loss: %.4f || %.4fs/step' % (total_loss / (iteration + 1 ),waste_time ))
62
63
63
64
print ('Start Validation' )
64
- for iteration in range (epoch_size_val ):
65
- images_val , targets_val = next (genval )
65
+ for iteration , batch in enumerate (genval ):
66
+ if iteration == epoch_size_val :
67
+ break
68
+ images_val , targets_val = batch [0 ], batch [1 ]
66
69
67
70
with torch .no_grad ():
68
71
if cuda :
0 commit comments