@@ -73,8 +73,11 @@ def forward(self, x):
 # The `nn.TransformerEncoder <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html>`__
 # itself consists of ``nlayers`` of `nn.TransformerEncoderLayer <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html>`__.
 # As a result, our focus is on ``nn.TransformerEncoder`` and we split the model
-# such that half of the ``nn.TransformerEncoderLayer`` are in ``TransformerModelStage1``
-# and the other half are in ``TransformerModelStage2``.
+# such that half of the ``nn.TransformerEncoderLayer`` are on one GPU and the
+# other half are on another. To do this, we pull out the ``Encoder`` and
+# ``Decoder`` sections into separate modules and then build an ``nn.Sequential``
+# representing the original Transformer module.
+

 if sys.platform == 'win32':
     print('Windows platform is not supported for pipeline parallelism')
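
The split keeps the original forward pass intact: an ``Encoder``, the stack of ``nn.TransformerEncoderLayer`` blocks, and a ``Decoder`` chained in an ``nn.Sequential`` compute the same thing as the single monolithic module. A minimal single-device sketch of that composition, assuming the ``Encoder``/``Decoder`` classes introduced in this diff and purely illustrative hyperparameter values:

    # Illustrative sketch only; values are placeholders, not the tutorial's config.
    ntokens, emsize, nhead, nhid, nlayers, dropout = 1000, 200, 2, 200, 4, 0.2
    blocks = [TransformerEncoderLayer(emsize, nhead, nhid, dropout) for _ in range(nlayers)]
    seq_model = nn.Sequential(Encoder(ntokens, emsize, dropout), *blocks, Decoder(ntokens, emsize))
    logits = seq_model(torch.randint(ntokens, (35, 20)))  # (seq_len, batch) token ids -> vocab logits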
@@ -83,69 +86,46 @@ def forward(self, x):
     print('Need at least four GPU devices for this tutorial')
     sys.exit(0)

-class TransformerModelStage1(nn.Module):
-
-    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
-        super(TransformerModelStage1, self).__init__()
+class Encoder(nn.Module):
+    def __init__(self, ntoken, ninp, dropout=0.5):
+        super(Encoder, self).__init__()
         self.src_mask = None
         self.pos_encoder = PositionalEncoding(ninp, dropout)
-        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
-        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
         self.encoder = nn.Embedding(ntoken, ninp)
         self.ninp = ninp
-
         self.init_weights()

+    def init_weights(self):
+        initrange = 0.1
+        self.encoder.weight.data.uniform_(-initrange, initrange)
+
     def _generate_square_subsequent_mask(self, sz):
         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
         return mask

-    def init_weights(self):
-        initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-
     def forward(self, src):
         if self.src_mask is None or self.src_mask.size(0) != src.size(0):
             device = src.device
             mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
             self.src_mask = mask

         src = self.encoder(src) * math.sqrt(self.ninp)
-        src = self.pos_encoder(src)
-        output = self.transformer_encoder(src, self.src_mask)
-        return output
-
-class TransformerModelStage2(nn.Module):
+        return self.pos_encoder(src)

-    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
-        super(TransformerModelStage2, self).__init__()
-        self.src_mask = None
-        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
-        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
+class Decoder(nn.Module):
+    def __init__(self, ntoken, ninp):
+        super(Decoder, self).__init__()
         self.decoder = nn.Linear(ninp, ntoken)
-
         self.init_weights()

-    def _generate_square_subsequent_mask(self, sz):
-        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
-        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
-        return mask
-
     def init_weights(self):
         initrange = 0.1
         self.decoder.bias.data.zero_()
         self.decoder.weight.data.uniform_(-initrange, initrange)

-    def forward(self, src):
-        if self.src_mask is None or self.src_mask.size(0) != src.size(0):
-            device = src.device
-            mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
-            self.src_mask = mask
-
-        output = self.transformer_encoder(src, self.src_mask)
-        output = self.decoder(output)
-        return output
+    def forward(self, inp):
+        return self.decoder(inp)

 ######################################################################
 # Start multiple processes for training
@@ -306,20 +286,40 @@ def get_batch(source, i):
         world_size=1,
         rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
             init_method="file://{}".format(tmpfile.name),
+            # Specifying _transports and _channels is a workaround; it will
+            # no longer be necessary to specify them for PyTorch versions
+            # >= 1.8.1.
+            _transports=["ibv", "uv"],
+            _channels=["cuda_ipc", "cuda_basic"],
         )
     )

+    # Number of GPUs used for model parallelism.
+    num_gpus = 2
+    partition_len = ((nlayers - 1) // num_gpus) + 1
+
+    # Add the encoder at the beginning.
+    tmp_list = [Encoder(ntokens, emsize, dropout).cuda(2 * rank)]
+    module_list = []
+
+    # Add all the necessary transformer blocks.
+    for i in range(nlayers):
+        transformer_block = TransformerEncoderLayer(emsize, nhead, nhid, dropout)
+        if i != 0 and i % partition_len == 0:
+            module_list.append(nn.Sequential(*tmp_list))
+            tmp_list = []
+        device = i // partition_len
+        tmp_list.append(transformer_block.to(2 * rank + device))
+
+    # Add the decoder at the end.
+    tmp_list.append(Decoder(ntokens, emsize).cuda(2 * rank + num_gpus - 1))
+    module_list.append(nn.Sequential(*tmp_list))
+
     # Need to use 'checkpoint=never' since as of PyTorch 1.8, Pipe checkpointing
     # doesn't work with DDP.
     from torch.distributed.pipeline.sync import Pipe
-    model = Pipe(
-        torch.nn.Sequential(
-            TransformerModelStage1(ntokens, emsize, nhead, nhid, int(nlayers / 2), dropout).cuda(2 * rank),
-            TransformerModelStage2(ntokens, emsize, nhead, nhid, int(nlayers / 2), dropout).cuda(2 * rank + 1),
-        ),
-        chunks=8,
-        checkpoint="never"
-    )
+    model = Pipe(torch.nn.Sequential(
+        *module_list), chunks=8, checkpoint="never")

     # Initialize process group and wrap model in DDP.
     from torch.nn.parallel import DistributedDataParallel
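
The partitioning loop above uses ``partition_len`` to decide how many ``nn.TransformerEncoderLayer`` blocks each GPU receives. A small standalone check of that arithmetic, with ``nlayers = 12`` and ``num_gpus = 2`` as assumed values for illustration:

    # Illustrative check of the device-assignment arithmetic used in the loop above.
    nlayers, num_gpus = 12, 2                          # assumed values, not the tutorial's
    partition_len = ((nlayers - 1) // num_gpus) + 1    # -> 6 layers per partition
    print([i // partition_len for i in range(nlayers)])
    # [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]: first six blocks on one GPU, the rest on the other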
@@ -462,3 +462,62 @@ def evaluate(eval_model, data_source):
     world_size = 2
     mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)

+
+######################################################################
+# Output
+# ------
+#
+
+
+######################################################################
+# .. code-block:: py
+#
+#    [RANK 1]: Total parameters in model: 1,041,453,167
+#    [RANK 0]: Total parameters in model: 1,041,453,167
+#    [RANK 0]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 1414.18 | loss 48.70 | ppl 1406154472673147092992.00
+#    [RANK 1]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 1414.42 | loss 48.49 | ppl 1146707511057334927360.00
+#    [RANK 0]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 1260.76 | loss 42.74 | ppl 3648812398518492672.00
+#    [RANK 1]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 1260.76 | loss 41.51 | ppl 1064844757565813248.00
+#    [RANK 0]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 1246.80 | loss 41.85 | ppl 1497706388552644096.00
+#    [RANK 1]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 1246.80 | loss 40.46 | ppl 373830103285747072.00
+#    [RANK 0]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 1246.69 | loss 39.76 | ppl 185159839078666368.00
+#    [RANK 1]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 1246.69 | loss 39.89 | ppl 211756997625874912.00
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 0]: | end of epoch 1 | time: 69.37s | valid loss 2.92 | valid ppl 18.46
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: | end of epoch 1 | time: 69.39s | valid loss 2.92 | valid ppl 18.46
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 1373.91 | loss 39.77 | ppl 187532281612905856.00
+#    [RANK 0]: | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 1375.62 | loss 39.05 | ppl 91344349371016336.00
+#    [RANK 0]: | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 1250.33 | loss 30.62 | ppl 19917977906884.78
+#    [RANK 1]: | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 1250.33 | loss 30.48 | ppl 17250186491252.32
+#    [RANK 1]: | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 1250.73 | loss 29.14 | ppl 4534527326854.47
+#    [RANK 0]: | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 1250.73 | loss 29.43 | ppl 6035762659681.65
+#    [RANK 0]: | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 1249.54 | loss 23.11 | ppl 10869828323.89
+#    [RANK 1]: | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 1249.55 | loss 22.90 | ppl 8785318464.24
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 0]: | end of epoch 2 | time: 69.02s | valid loss 0.94 | valid ppl 2.55
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: | end of epoch 2 | time: 69.05s | valid loss 0.94 | valid ppl 2.55
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 0]: | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 1380.66 | loss 12.98 | ppl 434052.59
+#    [RANK 1]: | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 1376.47 | loss 12.92 | ppl 410203.33
+#    [RANK 1]: | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 1250.88 | loss 9.80 | ppl 18034.58
+#    [RANK 0]: | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 1250.88 | loss 9.78 | ppl 17741.88
+#    [RANK 0]: | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 1251.89 | loss 10.37 | ppl 32016.45
+#    [RANK 1]: | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 1251.90 | loss 10.46 | ppl 34735.08
+#    [RANK 0]: | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 1250.70 | loss 10.09 | ppl 24147.61
+#    [RANK 1]: | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 1250.71 | loss 10.08 | ppl 23748.31
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 0]: | end of epoch 3 | time: 69.12s | valid loss 0.69 | valid ppl 2.00
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: | end of epoch 3 | time: 69.12s | valid loss 0.69 | valid ppl 2.00
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 0]: =========================================================================================
+#    [RANK 0]: | End of training | test loss 0.60 | test ppl 1.83
+#    [RANK 0]: =========================================================================================
+#    [RANK 1]: =========================================================================================
+#    [RANK 1]: | End of training | test loss 0.60 | test ppl 1.83