 import tempfile
 from torch.nn import TransformerEncoder, TransformerEncoderLayer
 
-if sys.platform == 'win32':
-    print('Windows platform is not supported for pipeline parallelism')
-    sys.exit(0)
-if torch.cuda.device_count() < 2:
-    print('Need at least two GPU devices for this tutorial')
-    sys.exit(0)
-
 class Encoder(nn.Module):
     def __init__(self, ntoken, ninp, dropout=0.5):
         super(Encoder, self).__init__()
@@ -162,38 +155,40 @@ def forward(self, x):
 from torchtext.data.utils import get_tokenizer
 from torchtext.vocab import build_vocab_from_iterator
 
-url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
-test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url))
-tokenizer = get_tokenizer('basic_english')
-vocab = build_vocab_from_iterator(map(tokenizer,
-                                      iter(io.open(train_filepath,
-                                                   encoding="utf8"))))
+def run():
+
+    url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'
+    test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url))
+    tokenizer = get_tokenizer('basic_english')
+    vocab = build_vocab_from_iterator(map(tokenizer,
+                                          iter(io.open(train_filepath,
+                                                       encoding="utf8"))))
 
-def data_process(raw_text_iter):
-    data = [torch.tensor([vocab[token] for token in tokenizer(item)],
-                         dtype=torch.long) for item in raw_text_iter]
-    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
+    def data_process(raw_text_iter):
+        data = [torch.tensor([vocab[token] for token in tokenizer(item)],
+                             dtype=torch.long) for item in raw_text_iter]
+        return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))
 
-train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
-val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
-test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))
+    train_data = data_process(iter(io.open(train_filepath, encoding="utf8")))
+    val_data = data_process(iter(io.open(valid_filepath, encoding="utf8")))
+    test_data = data_process(iter(io.open(test_filepath, encoding="utf8")))
 
-device = torch.device("cuda")
+    device = torch.device("cuda")
 
-def batchify(data, bsz):
-    # Divide the dataset into bsz parts.
-    nbatch = data.size(0) // bsz
-    # Trim off any extra elements that wouldn't cleanly fit (remainders).
-    data = data.narrow(0, 0, nbatch * bsz)
-    # Evenly divide the data across the bsz batches.
-    data = data.view(bsz, -1).t().contiguous()
-    return data.to(device)
+    def batchify(data, bsz):
+        # Divide the dataset into bsz parts.
+        nbatch = data.size(0) // bsz
+        # Trim off any extra elements that wouldn't cleanly fit (remainders).
+        data = data.narrow(0, 0, nbatch * bsz)
+        # Evenly divide the data across the bsz batches.
+        data = data.view(bsz, -1).t().contiguous()
+        return data.to(device)
 
-batch_size = 20
-eval_batch_size = 10
-train_data = batchify(train_data, batch_size)
-val_data = batchify(val_data, eval_batch_size)
-test_data = batchify(test_data, eval_batch_size)
+    batch_size = 20
+    eval_batch_size = 10
+    train_data = batchify(train_data, batch_size)
+    val_data = batchify(val_data, eval_batch_size)
+    test_data = batchify(test_data, eval_batch_size)
 
 
 ######################################################################
@@ -216,12 +211,12 @@ def batchify(data, bsz):
 # ``N`` is along dimension 1.
 #
 
-bptt = 35
-def get_batch(source, i):
-    seq_len = min(bptt, len(source) - 1 - i)
-    data = source[i:i+seq_len]
-    target = source[i+1:i+1+seq_len].view(-1)
-    return data, target
+    bptt = 35
+    def get_batch(source, i):
+        seq_len = min(bptt, len(source) - 1 - i)
+        data = source[i:i+seq_len]
+        target = source[i+1:i+1+seq_len].view(-1)
+        return data, target
 
 ######################################################################
 # Model scale and Pipe initialization
@@ -251,58 +246,58 @@ def get_batch(source, i):
 # allows the Pipe to work with only two partitions and avoid any
 # cross-partition overheads.
 
-ntokens = len(vocab.stoi) # the size of vocabulary
-emsize = 4096 # embedding dimension
-nhid = 4096 # the dimension of the feedforward network model in nn.TransformerEncoder
-nlayers = 16 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
-nhead = 16 # the number of heads in the multiheadattention models
-dropout = 0.2 # the dropout value
-
-from torch.distributed import rpc
-tmpfile = tempfile.NamedTemporaryFile()
-rpc.init_rpc(
-    name="worker",
-    rank=0,
-    world_size=1,
-    rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
-        init_method="file://{}".format(tmpfile.name),
-        _transports=["ibv", "uv"],
+    ntokens = len(vocab.stoi) # the size of vocabulary
+    emsize = 4096 # embedding dimension
+    nhid = 4096 # the dimension of the feedforward network model in nn.TransformerEncoder
+    nlayers = 16 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
+    nhead = 16 # the number of heads in the multiheadattention models
+    dropout = 0.2 # the dropout value
+
+    from torch.distributed import rpc
+    tmpfile = tempfile.NamedTemporaryFile()
+    rpc.init_rpc(
+        name="worker",
+        rank=0,
+        world_size=1,
+        rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
+            init_method="file://{}".format(tmpfile.name),
+            _transports=["ibv", "uv"],
+        )
     )
-)
 
-num_gpus = 2
-partition_len = ((nlayers - 1) // num_gpus) + 1
+    num_gpus = 2
+    partition_len = ((nlayers - 1) // num_gpus) + 1
 
-# Add encoder in the beginning.
-tmp_list = [Encoder(ntokens, emsize, dropout).cuda(0)]
-module_list = []
+    # Add encoder in the beginning.
+    tmp_list = [Encoder(ntokens, emsize, dropout).cuda(0)]
+    module_list = []
 
-# Add all the necessary transformer blocks.
-for i in range(nlayers):
-    transformer_block = TransformerEncoderLayer(emsize, nhead, nhid, dropout)
-    if i != 0 and i % (partition_len) == 0:
-        module_list.append(nn.Sequential(*tmp_list))
-        tmp_list = []
-    device = i // (partition_len)
-    tmp_list.append(transformer_block.to(device))
+    # Add all the necessary transformer blocks.
+    for i in range(nlayers):
+        transformer_block = TransformerEncoderLayer(emsize, nhead, nhid, dropout)
+        if i != 0 and i % (partition_len) == 0:
+            module_list.append(nn.Sequential(*tmp_list))
+            tmp_list = []
+        device = i // (partition_len)
+        tmp_list.append(transformer_block.to(device))
 
-# Add decoder in the end.
-tmp_list.append(Decoder(ntokens, emsize).cuda(num_gpus - 1))
-module_list.append(nn.Sequential(*tmp_list))
+    # Add decoder in the end.
+    tmp_list.append(Decoder(ntokens, emsize).cuda(num_gpus - 1))
+    module_list.append(nn.Sequential(*tmp_list))
 
-from torch.distributed.pipeline.sync import Pipe
+    from torch.distributed.pipeline.sync import Pipe
 
-# Build the pipeline.
-model = Pipe(torch.nn.Sequential(*module_list), chunks=8)
+    # Build the pipeline.
+    model = Pipe(torch.nn.Sequential(*module_list), chunks=8)
 
 
-def get_total_params(module: torch.nn.Module):
-    total_params = 0
-    for param in module.parameters():
-        total_params += param.numel()
-    return total_params
+    def get_total_params(module: torch.nn.Module):
+        total_params = 0
+        for param in module.parameters():
+            total_params += param.numel()
+        return total_params
 
-print('Total parameters in model: {:,}'.format(get_total_params(model)))
+    print('Total parameters in model: {:,}'.format(get_total_params(model)))
 
 ######################################################################
 # Run the model
@@ -322,88 +317,88 @@ def get_total_params(module: torch.nn.Module):
 # function to scale all the gradient together to prevent exploding.
 #
 
-criterion = nn.CrossEntropyLoss()
-lr = 5.0 # learning rate
-optimizer = torch.optim.SGD(model.parameters(), lr=lr)
-scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
-
-import time
-def train():
-    model.train() # Turn on the train mode
-    total_loss = 0.
-    start_time = time.time()
-    ntokens = len(vocab.stoi)
-
-    # Train only for 50 batches to keep script execution time low.
-    nbatches = min(50 * bptt, train_data.size(0) - 1)
-
-    for batch, i in enumerate(range(0, nbatches, bptt)):
-        data, targets = get_batch(train_data, i)
-        optimizer.zero_grad()
-        # Since the Pipe is only within a single host and process the ``RRef``
-        # returned by forward method is local to this node and can simply
-        # retrieved via ``RRef.local_value()``.
-        output = model(data).local_value()
-        # Need to move targets to the device where the output of the
-        # pipeline resides.
-        loss = criterion(output.view(-1, ntokens), targets.cuda(1))
-        loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
-        optimizer.step()
-
-        total_loss += loss.item()
-        log_interval = 10
-        if batch % log_interval == 0 and batch > 0:
-            cur_loss = total_loss / log_interval
-            elapsed = time.time() - start_time
-            print('| epoch {:3d} | {:5d}/{:5d} batches | '
-                  'lr {:02.2f} | ms/batch {:5.2f} | '
-                  'loss {:5.2f} | ppl {:8.2f}'.format(
-                    epoch, batch, nbatches // bptt, scheduler.get_lr()[0],
-                    elapsed * 1000 / log_interval,
-                    cur_loss, math.exp(cur_loss)))
-            total_loss = 0
-            start_time = time.time()
-
-def evaluate(eval_model, data_source):
-    eval_model.eval() # Turn on the evaluation mode
-    total_loss = 0.
-    ntokens = len(vocab.stoi)
-    # Evaluate only for 50 batches to keep script execution time low.
-    nbatches = min(50 * bptt, data_source.size(0) - 1)
-    with torch.no_grad():
-        for i in range(0, nbatches, bptt):
-            data, targets = get_batch(data_source, i)
-            output = eval_model(data).local_value()
-            output_flat = output.view(-1, ntokens)
+    criterion = nn.CrossEntropyLoss()
+    lr = 5.0 # learning rate
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
+
+    import time
+    def train():
+        model.train() # Turn on the train mode
+        total_loss = 0.
+        start_time = time.time()
+        ntokens = len(vocab.stoi)
+
+        # Train only for 50 batches to keep script execution time low.
+        nbatches = min(50 * bptt, train_data.size(0) - 1)
+
+        for batch, i in enumerate(range(0, nbatches, bptt)):
+            data, targets = get_batch(train_data, i)
+            optimizer.zero_grad()
+            # Since the Pipe is only within a single host and process the ``RRef``
+            # returned by forward method is local to this node and can simply
+            # retrieved via ``RRef.local_value()``.
+            output = model(data).local_value()
             # Need to move targets to the device where the output of the
             # pipeline resides.
-            total_loss += len(data) * criterion(output_flat, targets.cuda(1)).item()
-    return total_loss / (len(data_source) - 1)
+            loss = criterion(output.view(-1, ntokens), targets.cuda(1))
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+            optimizer.step()
+
+            total_loss += loss.item()
+            log_interval = 10
+            if batch % log_interval == 0 and batch > 0:
+                cur_loss = total_loss / log_interval
+                elapsed = time.time() - start_time
+                print('| epoch {:3d} | {:5d}/{:5d} batches | '
+                      'lr {:02.2f} | ms/batch {:5.2f} | '
+                      'loss {:5.2f} | ppl {:8.2f}'.format(
+                        epoch, batch, nbatches // bptt, scheduler.get_lr()[0],
+                        elapsed * 1000 / log_interval,
+                        cur_loss, math.exp(cur_loss)))
+                total_loss = 0
+                start_time = time.time()
+
+    def evaluate(eval_model, data_source):
+        eval_model.eval() # Turn on the evaluation mode
+        total_loss = 0.
+        ntokens = len(vocab.stoi)
+        # Evaluate only for 50 batches to keep script execution time low.
+        nbatches = min(50 * bptt, data_source.size(0) - 1)
+        with torch.no_grad():
+            for i in range(0, nbatches, bptt):
+                data, targets = get_batch(data_source, i)
+                output = eval_model(data).local_value()
+                output_flat = output.view(-1, ntokens)
+                # Need to move targets to the device where the output of the
+                # pipeline resides.
+                total_loss += len(data) * criterion(output_flat, targets.cuda(1)).item()
+        return total_loss / (len(data_source) - 1)
 
 ######################################################################
 # Loop over epochs. Save the model if the validation loss is the best
 # we've seen so far. Adjust the learning rate after each epoch.
 
-best_val_loss = float("inf")
-epochs = 3 # The number of epochs
-best_model = None
+    best_val_loss = float("inf")
+    epochs = 3 # The number of epochs
+    best_model = None
 
-for epoch in range(1, epochs + 1):
-    epoch_start_time = time.time()
-    train()
-    val_loss = evaluate(model, val_data)
-    print('-' * 89)
-    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
-          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
-                                     val_loss, math.exp(val_loss)))
-    print('-' * 89)
+    for epoch in range(1, epochs + 1):
+        epoch_start_time = time.time()
+        train()
+        val_loss = evaluate(model, val_data)
+        print('-' * 89)
+        print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
+              'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
+                                         val_loss, math.exp(val_loss)))
+        print('-' * 89)
 
-    if val_loss < best_val_loss:
-        best_val_loss = val_loss
-        best_model = model
+        if val_loss < best_val_loss:
+            best_val_loss = val_loss
+            best_model = model
 
-    scheduler.step()
+        scheduler.step()
 
 
 ######################################################################
@@ -415,11 +410,11 @@ def evaluate(eval_model, data_source):
 ######################################################################
 # Apply the best model to check the result with the test dataset.
 
-test_loss = evaluate(best_model, test_data)
-print('=' * 89)
-print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
-    test_loss, math.exp(test_loss)))
-print('=' * 89)
+    test_loss = evaluate(best_model, test_data)
+    print('=' * 89)
+    print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
+        test_loss, math.exp(test_loss)))
+    print('=' * 89)
 
 
 ######################################################################
@@ -456,3 +451,12 @@ def evaluate(eval_model, data_source):
 # =========================================================================================
 # | End of training | test loss 0.69 | test ppl 1.99
 # =========================================================================================
+
+if __name__ == "__main__":
+    if sys.platform == 'win32':
+        print('Windows platform is not supported for pipeline parallelism')
+        sys.exit(0)
+    if torch.cuda.device_count() < 2:
+        print('Need at least two GPU devices for this tutorial')
+        sys.exit(0)
+    run()
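
In short, the change wraps the tutorial's top-level code in run() and defers the Windows and GPU-count guards to an __main__ block, so importing the module no longer exits the interpreter; the checks only fire when the file is executed as a script. The snippet below is a minimal sketch of the resulting file layout (not part of the diff, with the run() body elided) to show the pattern:

import sys
import torch

def run():
    # Download WikiText-2, build the pipelined Transformer, then train and evaluate.
    # (Full body as shown in the diff above.)
    pass

if __name__ == "__main__":
    # Guards run only on direct execution, not on import.
    if sys.platform == 'win32':
        print('Windows platform is not supported for pipeline parallelism')
        sys.exit(0)
    if torch.cuda.device_count() < 2:
        print('Need at least two GPU devices for this tutorial')
        sys.exit(0)
    run()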