132132from IPython .display import HTML
133133
134134# Set random seem for reproducibility
135- manualSeed = random .randint (1 , 10000 )
135+ manualSeed = 999
136+ #manualSeed = random.randint(1, 10000) # use if you want new results
136137print ("Random Seed: " , manualSeed )
137138random .seed (manualSeed )
138139torch .manual_seed (manualSeed )
@@ -369,7 +370,7 @@ def forward(self, input):
369370netG = Generator (ngpu ).to (device )
370371
371372# Handle multi-gpu if desired
372- if ('cuda' in str ( device ) ) and (ngpu > 1 ):
373+ if (device . type == 'cuda' ) and (ngpu > 1 ):
373374 netG = nn .DataParallel (netG , list (range (ngpu )))
374375
375376# Apply the weights_init function to randomly initialize all weights
@@ -440,7 +441,7 @@ def forward(self, input):
440441netD = Discriminator (ngpu ).to (device )
441442
442443# Handle multi-gpu if desired
443- if ('cuda' in str ( device ) ) and (ngpu > 1 ):
444+ if (device . type == 'cuda' ) and (ngpu > 1 ):
444445 netD = nn .DataParallel (netD , list (range (ngpu )))
445446
446447# Apply the weights_init function to randomly initialize all weights
@@ -592,7 +593,7 @@ def forward(self, input):
592593 b_size = real_cpu .size (0 )
593594 label = torch .full ((b_size ,), real_label , device = device )
594595 # Forward pass real batch through D
595- output = netD (real_cpu ).view (- 1 , 1 ). squeeze ( 1 )
596+ output = netD (real_cpu ).view (- 1 )
596597 # Calculate loss on all-real batch
597598 errD_real = criterion (output , label )
598599 # Calculate gradients for D in backward pass
@@ -606,7 +607,7 @@ def forward(self, input):
606607 fake = netG (noise )
607608 label .fill_ (fake_label )
608609 # Classify all fake batch with D
609- output = netD (fake .detach ()).view (- 1 , 1 ). squeeze ( 1 )
610+ output = netD (fake .detach ()).view (- 1 )
610611 # Calculate D's loss on the all-fake batch
611612 errD_fake = criterion (output , label )
612613 # Calculate the gradients for this batch
@@ -623,7 +624,7 @@ def forward(self, input):
623624 netG .zero_grad ()
624625 label .fill_ (real_label ) # fake labels are real for generator cost
625626 # Since we just updated D, perform another forward pass of all-fake batch through D
626- output = netD (fake ).view (- 1 , 1 ). squeeze ( 1 )
627+ output = netD (fake ).view (- 1 )
627628 # Calculate G's loss based on this output
628629 errG = criterion (output , label )
629630 # Calculate gradients for G
@@ -709,7 +710,7 @@ def forward(self, input):
709710plt .subplot (1 ,2 ,1 )
710711plt .axis ("off" )
711712plt .title ("Real Images" )
712- plt .imshow (np .transpose (vutils .make_grid (real_batch [0 ].to (device )[:64 ], padding = 5 , normalize = True ),(1 ,2 ,0 )))
713+ plt .imshow (np .transpose (vutils .make_grid (real_batch [0 ].to (device )[:64 ], padding = 5 , normalize = True ). cpu () ,(1 ,2 ,0 )))
713714
714715# Plot the fake images from the last epoch
715716plt .subplot (1 ,2 ,2 )
0 commit comments