
Commit ad6ef01

inkawhich authored and JoelMarcey committed
tweaking dcgan tutorial
1 parent 882a0bd commit ad6ef01

1 file changed: 8 additions & 7 deletions


beginner_source/dcgan_faces_tutorial.py

@@ -132,7 +132,8 @@
 from IPython.display import HTML

 # Set random seem for reproducibility
-manualSeed = random.randint(1, 10000)
+manualSeed = 999
+#manualSeed = random.randint(1, 10000) # use if you want new results
 print("Random Seed: ", manualSeed)
 random.seed(manualSeed)
 torch.manual_seed(manualSeed)
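
Reviewer note: pinning manualSeed = 999 makes every run of the tutorial reproduce the same weights, noise vectors, and figures, while the commented-out randint line restores varied runs. A minimal sketch of what the seeding controls, assuming only the random and torch modules:

import random
import torch

manualSeed = 999                 # any fixed integer gives repeatable runs
random.seed(manualSeed)          # Python RNG (e.g. dataset shuffling helpers)
torch.manual_seed(manualSeed)    # PyTorch RNG (weight init, torch.randn noise)

print(torch.randn(2))            # identical output on every run of this script
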
@@ -369,7 +370,7 @@ def forward(self, input):
 netG = Generator(ngpu).to(device)

 # Handle multi-gpu if desired
-if ('cuda' in str(device)) and (ngpu > 1):
+if (device.type == 'cuda') and (ngpu > 1):
     netG = nn.DataParallel(netG, list(range(ngpu)))

 # Apply the weights_init function to randomly initialize all weights
@@ -440,7 +441,7 @@ def forward(self, input):
 netD = Discriminator(ngpu).to(device)

 # Handle multi-gpu if desired
-if ('cuda' in str(device)) and (ngpu > 1):
+if (device.type == 'cuda') and (ngpu > 1):
     netD = nn.DataParallel(netD, list(range(ngpu)))

 # Apply the weights_init function to randomly initialize all weights
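
Reviewer note on the device guard (same change in both hunks above): str(torch.device('cuda:0')) is 'cuda:0', so the old substring test also passed, but device.type compares the parsed device kind directly instead of relying on string matching. A minimal sketch of the pattern with a toy module (the ngpu value is hypothetical):

import torch
import torch.nn as nn

ngpu = 2  # hypothetical GPU count, for illustration only
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

print(str(device))   # 'cuda:0' on a GPU box -- what the old substring check scanned
print(device.type)   # 'cuda' or 'cpu' -- the structured field the new check reads

net = nn.Linear(8, 1).to(device)
if (device.type == 'cuda') and (ngpu > 1):
    # Replicate the module on GPUs 0..ngpu-1; inputs are scattered along dim 0
    net = nn.DataParallel(net, list(range(ngpu)))
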
@@ -592,7 +593,7 @@ def forward(self, input):
         b_size = real_cpu.size(0)
         label = torch.full((b_size,), real_label, device=device)
         # Forward pass real batch through D
-        output = netD(real_cpu).view(-1, 1).squeeze(1)
+        output = netD(real_cpu).view(-1)
         # Calculate loss on all-real batch
         errD_real = criterion(output, label)
         # Calculate gradients for D in backward pass
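
Reviewer note on the .view(-1) simplification (repeated in the next two hunks): the discriminator's final conv layer emits one score per image, so netD's output is shaped (N, 1, 1, 1). Both forms flatten it to (N,) to match label; .view(-1) just skips the intermediate (N, 1) step. A quick shape check with a dummy tensor standing in for netD's output:

import torch

out = torch.rand(128, 1, 1, 1)    # stand-in for netD(real_cpu): 128 scores

old = out.view(-1, 1).squeeze(1)  # (128, 1) -> (128,)
new = out.view(-1)                # (128,) directly

assert old.shape == new.shape == torch.Size([128])
assert torch.equal(old, new)      # same values, simpler call
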
@@ -606,7 +607,7 @@ def forward(self, input):
         fake = netG(noise)
         label.fill_(fake_label)
         # Classify all fake batch with D
-        output = netD(fake.detach()).view(-1, 1).squeeze(1)
+        output = netD(fake.detach()).view(-1)
         # Calculate D's loss on the all-fake batch
         errD_fake = criterion(output, label)
         # Calculate the gradients for this batch
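
One point worth flagging even though this hunk only touches the .view call: the fake.detach() in the context line is what keeps errD_fake.backward() from reaching into the generator during D's update. A tiny sketch of that behavior with stand-in linear layers (names hypothetical):

import torch
import torch.nn as nn

G = nn.Linear(4, 4)   # toy stand-ins for netG / netD
D = nn.Linear(4, 1)

fake = G(torch.randn(2, 4))
D(fake.detach()).sum().backward()  # detach blocks gradient flow into G

print(G.weight.grad)               # None: G untouched by D's update
print(D.weight.grad is not None)   # True: D received gradients
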
@@ -623,7 +624,7 @@ def forward(self, input):
         netG.zero_grad()
         label.fill_(real_label) # fake labels are real for generator cost
         # Since we just updated D, perform another forward pass of all-fake batch through D
-        output = netD(fake).view(-1, 1).squeeze(1)
+        output = netD(fake).view(-1)
         # Calculate G's loss based on this output
         errG = criterion(output, label)
         # Calculate gradients for G
@@ -709,7 +710,7 @@ def forward(self, input):
 plt.subplot(1,2,1)
 plt.axis("off")
 plt.title("Real Images")
-plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True),(1,2,0)))
+plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

 # Plot the fake images from the last epoch
 plt.subplot(1,2,2)
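
Why the appended .cpu() is needed: when device is a GPU, real_batch[0].to(device) produces a CUDA tensor, and np.transpose / plt.imshow convert their input to a NumPy array, which raises a TypeError for CUDA tensors. A minimal sketch of the host-transfer step (the grid shape is illustrative):

import torch

grid = torch.rand(3, 530, 530)   # stand-in for a vutils.make_grid(...) result
if torch.cuda.is_available():
    grid = grid.cuda()           # mimic a GPU run

# NumPy cannot ingest CUDA tensors; move to host memory first
img = grid.cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC for plt.imshow
print(img.shape)                             # (530, 530, 3)
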
