Let's start with the imports:
"""
from functools import partial
- import numpy as np
import os
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from ray import tune
- from ray.tune import CLIReporter
+ from ray.air import Checkpoint, session
from ray.tune.schedulers import ASHAScheduler

######################################################################


def load_data(data_dir="./data"):
-     transform = transforms.Compose([
-         transforms.ToTensor(),
-         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
-     ])
+     transform = transforms.Compose(
+         [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
+     )

    trainset = torchvision.datasets.CIFAR10(
-         root=data_dir, train=True, download=True, transform=transform)
+         root=data_dir, train=True, download=True, transform=transform
+     )

    testset = torchvision.datasets.CIFAR10(
-         root=data_dir, train=False, download=True, transform=transform)
+         root=data_dir, train=False, download=True, transform=transform
+     )

    return trainset, testset

+
######################################################################
# Configurable neural network
# ---------------------------
- # We can only tune those parameters that are configurable. In this example, we can specify
+ # We can only tune those parameters that are configurable.
+ # In this example, we can specify
# the layer sizes of the fully connected layers:


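######################################################################
# For orientation, here is a minimal sketch of what the constructor of such a
# network can look like. The convolutional part follows the original CIFAR-10
# example linked below; only the fully connected layer sizes ``l1`` and ``l2``
# are exposed as tunable parameters:
#
# .. code-block:: python
#
#     class Net(nn.Module):
#         def __init__(self, l1=120, l2=84):
#             super(Net, self).__init__()
#             self.conv1 = nn.Conv2d(3, 6, 5)
#             self.pool = nn.MaxPool2d(2, 2)
#             self.conv2 = nn.Conv2d(6, 16, 5)
#             self.fc1 = nn.Linear(16 * 5 * 5, l1)
#             self.fc2 = nn.Linear(l1, l2)
#             self.fc3 = nn.Linear(l2, 10)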
@@ -97,32 +99,40 @@ def __init__(self, l1=120, l2=84):
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
-         x = x.view(-1, 16 * 5 * 5)
+         x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

+
######################################################################
# The train function
# ------------------
# Now it gets interesting, because we introduce some changes to the example `from the PyTorch
# documentation <https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html>`_.
#
- # We wrap the training script in a function ``train_cifar(config, checkpoint_dir=None, data_dir=None)``.
- # As you can guess, the ``config`` parameter will receive the hyperparameters we would like to
- # train with. The ``checkpoint_dir`` parameter is used to restore checkpoints. The ``data_dir`` specifies
- # the directory where we load and store the data, so multiple runs can share the same data source.
+ # We wrap the training script in a function ``train_cifar(config, data_dir=None)``.
+ # The ``config`` parameter will receive the hyperparameters we would like to
+ # train with. The ``data_dir`` specifies the directory where we load and store the data,
+ # so that multiple runs can share the same data source.
+ # We also load the model and optimizer state at the start of the run, if a checkpoint
+ # is provided. Further down in this tutorial you will find information on how
+ # to save the checkpoint and what it is used for.
#
# .. code-block:: python
#
#     net = Net(config["l1"], config["l2"])
#
- #     if checkpoint_dir:
- #         model_state, optimizer_state = torch.load(
- #             os.path.join(checkpoint_dir, "checkpoint"))
- #         net.load_state_dict(model_state)
- #         optimizer.load_state_dict(optimizer_state)
+ #     checkpoint = session.get_checkpoint()
+ #
+ #     if checkpoint:
+ #         checkpoint_state = checkpoint.to_dict()
+ #         start_epoch = checkpoint_state["epoch"]
+ #         net.load_state_dict(checkpoint_state["net_state_dict"])
+ #         optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
+ #     else:
+ #         start_epoch = 0
#
# The learning rate of the optimizer is made configurable, too:
#
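# For illustration, this is the same optimizer construction that appears in the
# full training function below:
#
# .. code-block:: python
#
#     optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)
#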
@@ -171,11 +181,17 @@ def forward(self, x):
#
# .. code-block:: python
#
- #     with tune.checkpoint_dir(epoch) as checkpoint_dir:
- #         path = os.path.join(checkpoint_dir, "checkpoint")
- #         torch.save((net.state_dict(), optimizer.state_dict()), path)
+ #     checkpoint_data = {
+ #         "epoch": epoch,
+ #         "net_state_dict": net.state_dict(),
+ #         "optimizer_state_dict": optimizer.state_dict(),
+ #     }
+ #     checkpoint = Checkpoint.from_dict(checkpoint_data)
#
- #     tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
+ #     session.report(
+ #         {"loss": val_loss / val_steps, "accuracy": correct / total},
+ #         checkpoint=checkpoint,
+ #     )
#
# Here we first save a checkpoint and then report some metrics back to Ray Tune. Specifically,
# we send the validation loss and accuracy back to Ray Tune. Ray Tune can then use these metrics
@@ -187,15 +203,16 @@ def forward(self, x):
# schedulers like
# `Population Based Training <https://docs.ray.io/en/master/tune/tutorials/tune-advanced-tutorial.html>`_.
# Also, by saving the checkpoint we can later load the trained models and validate them
- # on a test set.
+ # on a test set. Lastly, saving checkpoints is useful for fault tolerance, and it allows
+ # us to interrupt training and resume it later.
#
# Full training function
# ~~~~~~~~~~~~~~~~~~~~~~
#
# The full code example looks like this:


- def train_cifar(config, checkpoint_dir=None, data_dir=None):
+ def train_cifar(config, data_dir=None):
    net = Net(config["l1"], config["l2"])

    device = "cpu"
@@ -208,30 +225,31 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=config["lr"], momentum=0.9)

-     if checkpoint_dir:
-         model_state, optimizer_state = torch.load(
-             os.path.join(checkpoint_dir, "checkpoint"))
-         net.load_state_dict(model_state)
-         optimizer.load_state_dict(optimizer_state)
+     checkpoint = session.get_checkpoint()
+
+     if checkpoint:
+         checkpoint_state = checkpoint.to_dict()
+         start_epoch = checkpoint_state["epoch"]
+         net.load_state_dict(checkpoint_state["net_state_dict"])
+         optimizer.load_state_dict(checkpoint_state["optimizer_state_dict"])
+     else:
+         start_epoch = 0

    trainset, testset = load_data(data_dir)

    test_abs = int(len(trainset) * 0.8)
    train_subset, val_subset = random_split(
-         trainset, [test_abs, len(trainset) - test_abs])
+         trainset, [test_abs, len(trainset) - test_abs]
+     )

    trainloader = torch.utils.data.DataLoader(
-         train_subset,
-         batch_size=int(config["batch_size"]),
-         shuffle=True,
-         num_workers=8)
+         train_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
+     )
    valloader = torch.utils.data.DataLoader(
-         val_subset,
-         batch_size=int(config["batch_size"]),
-         shuffle=True,
-         num_workers=8)
+         val_subset, batch_size=int(config["batch_size"]), shuffle=True, num_workers=8
+     )

-     for epoch in range(10):  # loop over the dataset multiple times
+     for epoch in range(start_epoch, 10):  # loop over the dataset multiple times
        running_loss = 0.0
        epoch_steps = 0
        for i, data in enumerate(trainloader, 0):
@@ -252,8 +270,10 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
            running_loss += loss.item()
            epoch_steps += 1
            if i % 2000 == 1999:  # print every 2000 mini-batches
-                 print("[%d, %5d] loss: %.3f" % (epoch + 1, i + 1,
-                                                 running_loss / epoch_steps))
+                 print(
+                     "[%d, %5d] loss: %.3f"
+                     % (epoch + 1, i + 1, running_loss / epoch_steps)
+                 )
                running_loss = 0.0

        # Validation loss
@@ -275,13 +295,20 @@ def train_cifar(config, checkpoint_dir=None, data_dir=None):
                val_loss += loss.cpu().numpy()
                val_steps += 1

-         with tune.checkpoint_dir(epoch) as checkpoint_dir:
-             path = os.path.join(checkpoint_dir, "checkpoint")
-             torch.save((net.state_dict(), optimizer.state_dict()), path)
-
-         tune.report(loss=(val_loss / val_steps), accuracy=correct / total)
+         checkpoint_data = {
+             "epoch": epoch,
+             "net_state_dict": net.state_dict(),
+             "optimizer_state_dict": optimizer.state_dict(),
+         }
+         checkpoint = Checkpoint.from_dict(checkpoint_data)
+
+         session.report(
+             {"loss": val_loss / val_steps, "accuracy": correct / total},
+             checkpoint=checkpoint,
+         )
    print("Finished Training")

+
######################################################################
# As you can see, most of the code is adapted directly from the original example.
#
@@ -296,7 +323,8 @@ def test_accuracy(net, device="cpu"):
    trainset, testset = load_data()

    testloader = torch.utils.data.DataLoader(
-         testset, batch_size=4, shuffle=False, num_workers=2)
+         testset, batch_size=4, shuffle=False, num_workers=2
+     )

    correct = 0
    total = 0
@@ -311,6 +339,7 @@ def test_accuracy(net, device="cpu"):

    return correct / total

+
######################################################################
# The function also expects a ``device`` parameter, so we can do the
# test set validation on a GPU.
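# For example, the best model found by the search can later be evaluated on a
# GPU like this (the device string is just an illustration; pass ``"cpu"`` when
# no GPU is available):
#
# .. code-block:: python
#
#     test_acc = test_accuracy(best_trained_model, device="cuda:0")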
@@ -322,14 +351,14 @@ def test_accuracy(net, device="cpu"):
# .. code-block:: python
#
#     config = {
- #         "l1": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
- #         "l2": tune.sample_from(lambda _: 2**np.random.randint(2, 9)),
+ #         "l1": tune.choice([2 ** i for i in range(9)]),
+ #         "l2": tune.choice([2 ** i for i in range(9)]),
#         "lr": tune.loguniform(1e-4, 1e-1),
#         "batch_size": tune.choice([2, 4, 8, 16])
#     }
#
- # The ``tune.sample_from()`` function makes it possible to define your own sample
- # methods to obtain hyperparameters. In this example, the ``l1`` and ``l2`` parameters
+ # The ``tune.choice()`` function accepts a list of values that are uniformly sampled from.
+ # In this example, the ``l1`` and ``l2`` parameters
# should be powers of 2 between 1 and 256, so either 1, 2, 4, 8, 16, 32, 64, 128, or 256.
# The ``lr`` (learning rate) should be sampled log-uniformly between 0.0001 and 0.1. Lastly,
# the batch size is a choice between 2, 4, 8, and 16.
@@ -353,7 +382,6 @@ def test_accuracy(net, device="cpu"):
#         config=config,
#         num_samples=num_samples,
#         scheduler=scheduler,
- #         progress_reporter=reporter,
#         checkpoint_at_end=True)
#
# You can specify the number of CPUs, which are then available e.g.
@@ -377,34 +405,30 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    data_dir = os.path.abspath("./data")
    load_data(data_dir)
    config = {
-         "l1": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
-         "l2": tune.sample_from(lambda _: 2 ** np.random.randint(2, 9)),
+         "l1": tune.choice([2 ** i for i in range(9)]),
+         "l2": tune.choice([2 ** i for i in range(9)]),
        "lr": tune.loguniform(1e-4, 1e-1),
-         "batch_size": tune.choice([2, 4, 8, 16])
+         "batch_size": tune.choice([2, 4, 8, 16]),
    }
    scheduler = ASHAScheduler(
        metric="loss",
        mode="min",
        max_t=max_num_epochs,
        grace_period=1,
-         reduction_factor=2)
-     reporter = CLIReporter(
-         # ``parameter_columns=["l1", "l2", "lr", "batch_size"]``,
-         metric_columns=["loss", "accuracy", "training_iteration"])
+         reduction_factor=2,
+     )
    result = tune.run(
        partial(train_cifar, data_dir=data_dir),
        resources_per_trial={"cpu": 2, "gpu": gpus_per_trial},
        config=config,
        num_samples=num_samples,
        scheduler=scheduler,
-         progress_reporter=reporter)
+     )

    best_trial = result.get_best_trial("loss", "min", "last")
-     print("Best trial config: {}".format(best_trial.config))
-     print("Best trial final validation loss: {}".format(
-         best_trial.last_result["loss"]))
-     print("Best trial final validation accuracy: {}".format(
-         best_trial.last_result["accuracy"]))
+     print(f"Best trial config: {best_trial.config}")
+     print(f"Best trial final validation loss: {best_trial.last_result['loss']}")
+     print(f"Best trial final validation accuracy: {best_trial.last_result['accuracy']}")

    best_trained_model = Net(best_trial.config["l1"], best_trial.config["l2"])
    device = "cpu"
@@ -414,10 +438,10 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
        best_trained_model = nn.DataParallel(best_trained_model)
    best_trained_model.to(device)

-     best_checkpoint_dir = best_trial.checkpoint.value
-     model_state, optimizer_state = torch.load(os.path.join(
-         best_checkpoint_dir, "checkpoint"))
-     best_trained_model.load_state_dict(model_state)
+     best_checkpoint = best_trial.checkpoint.to_air_checkpoint()
+     best_checkpoint_data = best_checkpoint.to_dict()
+
+     best_trained_model.load_state_dict(best_checkpoint_data["net_state_dict"])

    test_acc = test_accuracy(best_trained_model, device)
    print("Best trial test set accuracy: {}".format(test_acc))
@@ -428,6 +452,7 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
    # Fixes ``AttributeError: '_LoggingTee' object has no attribute 'fileno'``.
    # This is only needed to run with sphinx-build.
    import sys
+
    sys.stdout.fileno = lambda: False
    # sphinx_gallery_end_ignore
    # You can change the number of GPUs per trial here:
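    # An illustrative call based on the ``main`` signature above; set
    # ``gpus_per_trial=0`` if no GPU is available:
    main(num_samples=10, max_num_epochs=10, gpus_per_trial=0)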
@@ -439,30 +464,29 @@ def main(num_samples=10, max_num_epochs=10, gpus_per_trial=2):
#
# ::
#
- #     Number of trials: 10 (10 TERMINATED)
- #     +-----+------+------+-------------+--------------+---------+------------+--------------------+
- #     | ... |   l1 |   l2 |          lr |   batch_size |    loss |   accuracy | training_iteration |
- #     |-----+------+------+-------------+--------------+---------+------------+--------------------|
- #     | ... |   64 |    4 |  0.00011629 |            2 | 1.87273 |      0.244 |                  2 |
- #     | ... |   32 |   64 | 0.000339763 |            8 | 1.23603 |      0.567 |                  8 |
- #     | ... |    8 |   16 |  0.00276249 |           16 |  1.1815 |     0.5836 |                 10 |
- #     | ... |    4 |   64 | 0.000648721 |            4 | 1.31131 |     0.5224 |                  8 |
- #     | ... |   32 |   16 | 0.000340753 |            8 | 1.26454 |     0.5444 |                  8 |
- #     | ... |    8 |    4 | 0.000699775 |            8 | 1.99594 |     0.1983 |                  2 |
- #     | ... |  256 |    8 |   0.0839654 |           16 |  2.3119 |     0.0993 |                  1 |
- #     | ... |   16 |  128 |   0.0758154 |           16 | 2.33575 |     0.1327 |                  1 |
- #     | ... |   16 |    8 |   0.0763312 |           16 | 2.31129 |     0.1042 |                  4 |
- #     | ... |  128 |   16 | 0.000124903 |            4 | 2.26917 |     0.1945 |                  1 |
- #     +-----+------+------+-------------+--------------+---------+------------+--------------------+
- #
- #
- #     Best trial config: {'l1': 8, 'l2': 16, 'lr': 0.00276249, 'batch_size': 16, 'data_dir': '...'}
- #     Best trial final validation loss: 1.181501
- #     Best trial final validation accuracy: 0.5836
- #     Best trial test set accuracy: 0.5806
+ #     Number of trials: 10/10 (10 TERMINATED)
+ #     +-----+--------------+------+------+-------------+--------+---------+------------+
+ #     | ... |   batch_size |   l1 |   l2 |          lr |   iter |    loss |   accuracy |
+ #     |-----+--------------+------+------+-------------+--------+---------+------------|
+ #     | ... |            2 |    1 |  256 | 0.000668163 |      1 | 2.31479 |     0.0977 |
+ #     | ... |            4 |   64 |    8 |   0.0331514 |      1 | 2.31605 |     0.0983 |
+ #     | ... |            4 |    2 |    1 | 0.000150295 |      1 | 2.30755 |     0.1023 |
+ #     | ... |           16 |   32 |   32 |   0.0128248 |     10 | 1.66912 |     0.4391 |
+ #     | ... |            4 |    8 |  128 |  0.00464561 |      2 |  1.7316 |     0.3463 |
+ #     | ... |            8 |  256 |    8 |  0.00031556 |      1 | 2.19409 |     0.1736 |
+ #     | ... |            4 |   16 |  256 |  0.00574329 |      2 | 1.85679 |     0.3368 |
+ #     | ... |            8 |    2 |    2 |  0.00325652 |      1 | 2.30272 |     0.0984 |
+ #     | ... |            2 |    2 |    2 | 0.000342987 |      2 | 1.76044 |      0.292 |
+ #     | ... |            4 |   64 |   32 |    0.003734 |      8 | 1.53101 |     0.4761 |
+ #     +-----+--------------+------+------+-------------+--------+---------+------------+
+ #
+ #     Best trial config: {'l1': 64, 'l2': 32, 'lr': 0.0037339984519545164, 'batch_size': 4}
+ #     Best trial final validation loss: 1.5310075663924216
+ #     Best trial final validation accuracy: 0.4761
+ #     Best trial test set accuracy: 0.4737
#
# Most trials have been stopped early in order to avoid wasting resources.
- # The best performing trial achieved a validation accuracy of about 58%, which could
+ # The best performing trial achieved a validation accuracy of about 47%, which could
# be confirmed on the test set.
#
# So that's it! You can now tune the parameters of your PyTorch models.