Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 9b9ed8a

Browse files
authored
Update train.py
1 parent 8a3769d commit 9b9ed8a

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

train.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,18 @@ def get_anchors(anchors_path):
3434
def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epoch,cuda):
3535
total_loss = 0
3636
val_loss = 0
37-
for iteration in range(epoch_size):
37+
for iteration, batch in enumerate(gen):
38+
if iteration == epoch_size:
39+
break
3840
start_time = time.time()
39-
images, targets = next(gen)
41+
images, targets = batch[0], batch[1]
4042
with torch.no_grad():
4143
if cuda:
4244
images = Variable(torch.from_numpy(images).type(torch.FloatTensor)).cuda()
4345
targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
4446
else:
4547
images = Variable(torch.from_numpy(images).type(torch.FloatTensor))
4648
targets = [Variable(torch.from_numpy(ann).type(torch.FloatTensor)) for ann in targets]
47-
# print(images)
4849
optimizer.zero_grad()
4950
outputs = net(images)
5051
losses = []
@@ -61,8 +62,10 @@ def fit_ont_epoch(net,yolo_losses,epoch,epoch_size,epoch_size_val,gen,genval,Epo
6162
print('iter:' + str(iteration) + '/' + str(epoch_size) + ' || Total Loss: %.4f || %.4fs/step' % (total_loss/(iteration+1),waste_time))
6263

6364
print('Start Validation')
64-
for iteration in range(epoch_size_val):
65-
images_val, targets_val = next(genval)
65+
for iteration, batch in enumerate(genval):
66+
if iteration == epoch_size_val:
67+
break
68+
images_val, targets_val = batch[0], batch[1]
6669

6770
with torch.no_grad():
6871
if cuda:

0 commit comments

Comments
 (0)