Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 0679271

Browse files
albanD and brianjo authored
Cleanup of the neural style transfer to reduce memory usage (#1665)
* No need to clone the model, it is used only once
* Cleanup use of .data
* Set requires_grad properly to avoid computing un-needed gradients

Co-authored-by: Brian Johnson <[email protected]>
1 parent 011ae8a commit 0679271

1 file changed

Lines changed: 11 additions & 5 deletions

File tree

advanced_source/neural_style_tutorial.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,6 @@ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
309309
style_img, content_img,
310310
content_layers=content_layers_default,
311311
style_layers=style_layers_default):
312-
cnn = copy.deepcopy(cnn)
313-
314312
# normalization module
315313
normalization = Normalization(normalization_mean, normalization_std).to(device)
316314

@@ -394,7 +392,7 @@ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
394392

395393
def get_input_optimizer(input_img):
396394
# this line to show that input is a parameter that requires a gradient
397-
optimizer = optim.LBFGS([input_img.requires_grad_()])
395+
optimizer = optim.LBFGS([input_img])
398396
return optimizer
399397

400398

@@ -418,6 +416,12 @@ def run_style_transfer(cnn, normalization_mean, normalization_std,
418416
print('Building the style transfer model..')
419417
model, style_losses, content_losses = get_style_model_and_losses(cnn,
420418
normalization_mean, normalization_std, style_img, content_img)
419+
420+
# We want to optimize the input and not the model parameters so we
421+
# update all the requires_grad fields accordingly
422+
input_img.requires_grad_(True)
423+
model.requires_grad_(False)
424+
421425
optimizer = get_input_optimizer(input_img)
422426

423427
print('Optimizing..')
@@ -426,7 +430,8 @@ def run_style_transfer(cnn, normalization_mean, normalization_std,
426430

427431
def closure():
428432
# correct the values of updated input image
429-
input_img.data.clamp_(0, 1)
433+
with torch.no_grad():
434+
input_img.clamp_(0, 1)
430435

431436
optimizer.zero_grad()
432437
model(input_img)
@@ -456,7 +461,8 @@ def closure():
456461
optimizer.step(closure)
457462

458463
# a last correction...
459-
input_img.data.clamp_(0, 1)
464+
with torch.no_grad():
465+
input_img.clamp_(0, 1)
460466

461467
return input_img
462468

0 commit comments

Comments (0)