-
Couldn't load subscription status.
- Fork 9
Open
Description
In the paper, you state that the codebook is reset every 20 iterations to prevent codebook collapse. However, the training loop below does not appear to perform any explicit codebook reset:
Lines 156 to 193 in 68d8500
| for nb_iter in tqdm(range(1, args.total_iter + 1)): | |
| gt_motion = next(train_loader_iter) | |
| gt_motion = gt_motion.cuda().float() # bs, nb_joints, joints_dim, seq_len | |
| if args.sep_uplow: | |
| pred_motion, loss_commit, perplexity = net(gt_motion, idx_noise=0) | |
| else: | |
| pred_motion, loss_commit, perplexity = net(gt_motion) | |
| loss_motion = Loss(pred_motion, gt_motion) | |
| loss_vel = Loss.forward_joint(pred_motion, gt_motion) | |
| loss = loss_motion + args.commit * loss_commit + args.loss_vel * loss_vel | |
| optimizer.zero_grad() | |
| loss.backward() | |
| optimizer.step() | |
| scheduler.step() | |
| avg_recons += loss_motion.item() | |
| avg_perplexity += perplexity.item() | |
| avg_commit += loss_commit.item() | |
| if nb_iter % args.print_iter == 0 : | |
| avg_recons /= args.print_iter | |
| avg_perplexity /= args.print_iter | |
| avg_commit /= args.print_iter | |
| writer.add_scalar('./Train/L1', avg_recons, nb_iter) | |
| writer.add_scalar('./Train/PPL', avg_perplexity, nb_iter) | |
| writer.add_scalar('./Train/Commit', avg_commit, nb_iter) | |
| logger.info(f"Train. Iter {nb_iter} : \t Commit. {avg_commit:.5f} \t PPL. {avg_perplexity:.2f} \t Recons. {avg_recons:.5f}") | |
| avg_recons, avg_perplexity, avg_commit = 0., 0., 0., | |
| if nb_iter % args.eval_iter==0 : | |
| best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, writer, logger = eval_trans.evaluation_vqvae(args.out_dir, val_loader, net, logger, writer, nb_iter, best_fid, best_iter, best_div, best_top1, best_top2, best_top3, best_matching, eval_wrapper=eval_wrapper) |
Metadata
Metadata
Assignees
Labels
No labels