Commit 51de9f9

pritamdamania87 and brianjo authored
Updates to ddp pipeline tutorial (#1419)
* Updates to ddp pipeline tutorial

  Summary: Bring the tutorial more on par with the pipeline tutorial.

* Add tensorpipe options
* Update build.sh

Co-authored-by: pritam <[email protected]>
Co-authored-by: Brian Johnson <[email protected]>
1 parent 002409a commit 51de9f9

1 file changed

Lines changed: 105 additions & 46 deletions

File tree

advanced_source/ddp_pipeline_tutorial.py

@@ -73,8 +73,11 @@ def forward(self, x):
 # The `nn.TransformerEncoder <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html>`__
 # itself consists of ``nlayers`` of `nn.TransformerEncoderLayer <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html>`__.
 # As a result, our focus is on ``nn.TransformerEncoder`` and we split the model
-# such that half of the ``nn.TransformerEncoderLayer`` are in ``TransformerModelStage1``
-# and the other half are in ``TransformerModelStage2``.
+# such that half of the ``nn.TransformerEncoderLayer`` are on one GPU and the
+# other half are on another. To do this, we pull out the ``Encoder`` and
+# ``Decoder`` sections into separate modules and then build an ``nn.Sequential``
+# representing the original Transformer module.
+
 
 if sys.platform == 'win32':
     print('Windows platform is not supported for pipeline parallelism')
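
Note: the diff elides a few lines between this hunk and the next, so the full early-exit guard is not shown. For orientation only, a minimal sketch of what such a platform/GPU guard looks like (``torch.cuda.device_count()`` is the standard device-count query; the file's verbatim lines may differ):

    import sys
    import torch

    # Pipeline parallelism (torch.distributed.pipeline) is not supported on Windows.
    if sys.platform == 'win32':
        print('Windows platform is not supported for pipeline parallelism')
        sys.exit(0)

    # Two pipeline stages per rank and two DDP ranks -> four GPUs in total.
    if torch.cuda.device_count() < 4:
        print('Need at least four GPU devices for this tutorial')
        sys.exit(0)
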
@@ -83,69 +86,46 @@ def forward(self, x):
     print('Need at least four GPU devices for this tutorial')
     sys.exit(0)
 
-class TransformerModelStage1(nn.Module):
-
-    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
-        super(TransformerModelStage1, self).__init__()
+class Encoder(nn.Module):
+    def __init__(self, ntoken, ninp, dropout=0.5):
+        super(Encoder, self).__init__()
         self.src_mask = None
         self.pos_encoder = PositionalEncoding(ninp, dropout)
-        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
-        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
         self.encoder = nn.Embedding(ntoken, ninp)
         self.ninp = ninp
-
         self.init_weights()
 
+    def init_weights(self):
+        initrange = 0.1
+        self.encoder.weight.data.uniform_(-initrange, initrange)
+
     def _generate_square_subsequent_mask(self, sz):
         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
         return mask
 
-    def init_weights(self):
-        initrange = 0.1
-        self.encoder.weight.data.uniform_(-initrange, initrange)
-
     def forward(self, src):
         if self.src_mask is None or self.src_mask.size(0) != src.size(0):
             device = src.device
             mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
             self.src_mask = mask
 
         src = self.encoder(src) * math.sqrt(self.ninp)
-        src = self.pos_encoder(src)
-        output = self.transformer_encoder(src, self.src_mask)
-        return output
-
-class TransformerModelStage2(nn.Module):
+        return self.pos_encoder(src)
 
-    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
-        super(TransformerModelStage2, self).__init__()
-        self.src_mask = None
-        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
-        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
+class Decoder(nn.Module):
+    def __init__(self, ntoken, ninp):
+        super(Decoder, self).__init__()
         self.decoder = nn.Linear(ninp, ntoken)
-
         self.init_weights()
 
-    def _generate_square_subsequent_mask(self, sz):
-        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
-        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
-        return mask
-
     def init_weights(self):
         initrange = 0.1
         self.decoder.bias.data.zero_()
         self.decoder.weight.data.uniform_(-initrange, initrange)
 
-    def forward(self, src):
-        if self.src_mask is None or self.src_mask.size(0) != src.size(0):
-            device = src.device
-            mask = self._generate_square_subsequent_mask(src.size(0)).to(device)
-            self.src_mask = mask
-
-        output = self.transformer_encoder(src, self.src_mask)
-        output = self.decoder(output)
-        return output
+    def forward(self, inp):
+        return self.decoder(inp)
 
 ######################################################################
 # Start multiple processes for training
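
Note: taken together, the new ``Encoder`` and ``Decoder`` classes let the original Transformer be rebuilt as a flat ``nn.Sequential`` of an embedding stage, ``nlayers`` encoder layers, and a linear decoder. A minimal CPU-only sketch of that composition, with toy hyperparameter values assumed purely for illustration (``Encoder``, ``Decoder``, and ``PositionalEncoding`` as defined in the tutorial file):

    import torch
    import torch.nn as nn
    from torch.nn import TransformerEncoderLayer

    # Toy sizes, assumed for illustration; the tutorial sets its own values.
    ntokens, emsize, nhead, nhid, nlayers, dropout = 1000, 64, 4, 128, 4, 0.2

    # Embedding + positional encoding, then nlayers blocks, then the decoder.
    model = nn.Sequential(
        Encoder(ntokens, emsize, dropout),
        *[TransformerEncoderLayer(emsize, nhead, nhid, dropout) for _ in range(nlayers)],
        Decoder(ntokens, emsize),
    )

    src = torch.randint(0, ntokens, (35, 20))  # (seq_len, batch) of token ids
    out = model(src)                           # -> (35, 20, ntokens) logits
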
@@ -306,20 +286,40 @@ def get_batch(source, i):
         world_size=1,
         rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
             init_method="file://{}".format(tmpfile.name),
+            # Specifying _transports and _channels is a workaround; we will
+            # no longer need to specify _transports and _channels for
+            # PyTorch versions >= 1.8.1.
+            _transports=["ibv", "uv"],
+            _channels=["cuda_ipc", "cuda_basic"],
         )
     )
 
+    # Num gpus for model parallelism.
+    num_gpus = 2
+    partition_len = ((nlayers - 1) // num_gpus) + 1
+
+    # Add encoder in the beginning.
+    tmp_list = [Encoder(ntokens, emsize, dropout).cuda(2 * rank)]
+    module_list = []
+
+    # Add all the necessary transformer blocks.
+    for i in range(nlayers):
+        transformer_block = TransformerEncoderLayer(emsize, nhead, nhid, dropout)
+        if i != 0 and i % (partition_len) == 0:
+            module_list.append(nn.Sequential(*tmp_list))
+            tmp_list = []
+        device = i // (partition_len)
+        tmp_list.append(transformer_block.to(2 * rank + device))
+
+    # Add decoder in the end.
+    tmp_list.append(Decoder(ntokens, emsize).cuda(2 * rank + num_gpus - 1))
+    module_list.append(nn.Sequential(*tmp_list))
+
     # Need to use 'checkpoint=never' since as of PyTorch 1.8, Pipe checkpointing
     # doesn't work with DDP.
     from torch.distributed.pipeline.sync import Pipe
-    model = Pipe(
-        torch.nn.Sequential(
-            TransformerModelStage1(ntokens, emsize, nhead, nhid, int(nlayers/2), dropout).cuda(2 * rank),
-            TransformerModelStage2(ntokens, emsize, nhead, nhid, int(nlayers/2), dropout).cuda(2 * rank + 1),
-        ),
-        chunks = 8,
-        checkpoint = "never"
-    )
+    model = Pipe(torch.nn.Sequential(
+        *module_list), chunks = 8, checkpoint="never")
 
     # Initialize process group and wrap model in DDP.
     from torch.nn.parallel import DistributedDataParallel
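
Note: ``partition_len`` above is a ceiling division, so the ``nlayers`` transformer blocks are spread as evenly as possible over ``num_gpus`` devices, and rank 1 uses the next pair of GPUs because of the ``2 * rank`` offset. A standalone check of the assignment arithmetic, assuming ``nlayers = 8`` for illustration (the tutorial sets ``nlayers`` elsewhere in the file):

    nlayers, num_gpus, rank = 8, 2, 0
    partition_len = ((nlayers - 1) // num_gpus) + 1  # ceil(8 / 2) == 4

    for i in range(nlayers):
        device = i // partition_len
        print("block {} -> cuda:{}".format(i, 2 * rank + device))
    # blocks 0-3 -> cuda:0, blocks 4-7 -> cuda:1; rank 1 maps to cuda:2/cuda:3
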
@@ -462,3 +462,62 @@ def evaluate(eval_model, data_source):
     world_size = 2
     mp.spawn(run_worker, args=(world_size, ), nprocs=world_size, join=True)
 
+
+######################################################################
+# Output
+# ------
+#
+
+
+######################################################################
+# .. code-block:: py
+#
+#    [RANK 1]: Total parameters in model: 1,041,453,167
+#    [RANK 0]: Total parameters in model: 1,041,453,167
+#    [RANK 0]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 1414.18 | loss 48.70 | ppl 1406154472673147092992.00
+#    [RANK 1]: | epoch 1 | 10/ 50 batches | lr 5.00 | ms/batch 1414.42 | loss 48.49 | ppl 1146707511057334927360.00
+#    [RANK 0]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 1260.76 | loss 42.74 | ppl 3648812398518492672.00
+#    [RANK 1]: | epoch 1 | 20/ 50 batches | lr 5.00 | ms/batch 1260.76 | loss 41.51 | ppl 1064844757565813248.00
+#    [RANK 0]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 1246.80 | loss 41.85 | ppl 1497706388552644096.00
+#    [RANK 1]: | epoch 1 | 30/ 50 batches | lr 5.00 | ms/batch 1246.80 | loss 40.46 | ppl 373830103285747072.00
+#    [RANK 0]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 1246.69 | loss 39.76 | ppl 185159839078666368.00
+#    [RANK 1]: | epoch 1 | 40/ 50 batches | lr 5.00 | ms/batch 1246.69 | loss 39.89 | ppl 211756997625874912.00
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 0]: | end of epoch 1 | time: 69.37s | valid loss 2.92 | valid ppl 18.46
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: | end of epoch 1 | time: 69.39s | valid loss 2.92 | valid ppl 18.46
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 1373.91 | loss 39.77 | ppl 187532281612905856.00
+#    [RANK 0]: | epoch 2 | 10/ 50 batches | lr 4.51 | ms/batch 1375.62 | loss 39.05 | ppl 91344349371016336.00
+#    [RANK 0]: | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 1250.33 | loss 30.62 | ppl 19917977906884.78
+#    [RANK 1]: | epoch 2 | 20/ 50 batches | lr 4.51 | ms/batch 1250.33 | loss 30.48 | ppl 17250186491252.32
+#    [RANK 1]: | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 1250.73 | loss 29.14 | ppl 4534527326854.47
+#    [RANK 0]: | epoch 2 | 30/ 50 batches | lr 4.51 | ms/batch 1250.73 | loss 29.43 | ppl 6035762659681.65
+#    [RANK 0]: | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 1249.54 | loss 23.11 | ppl 10869828323.89
+#    [RANK 1]: | epoch 2 | 40/ 50 batches | lr 4.51 | ms/batch 1249.55 | loss 22.90 | ppl 8785318464.24
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 0]: | end of epoch 2 | time: 69.02s | valid loss 0.94 | valid ppl 2.55
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: | end of epoch 2 | time: 69.05s | valid loss 0.94 | valid ppl 2.55
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 0]: | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 1380.66 | loss 12.98 | ppl 434052.59
+#    [RANK 1]: | epoch 3 | 10/ 50 batches | lr 4.29 | ms/batch 1376.47 | loss 12.92 | ppl 410203.33
+#    [RANK 1]: | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 1250.88 | loss 9.80 | ppl 18034.58
+#    [RANK 0]: | epoch 3 | 20/ 50 batches | lr 4.29 | ms/batch 1250.88 | loss 9.78 | ppl 17741.88
+#    [RANK 0]: | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 1251.89 | loss 10.37 | ppl 32016.45
+#    [RANK 1]: | epoch 3 | 30/ 50 batches | lr 4.29 | ms/batch 1251.90 | loss 10.46 | ppl 34735.08
+#    [RANK 0]: | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 1250.70 | loss 10.09 | ppl 24147.61
+#    [RANK 1]: | epoch 3 | 40/ 50 batches | lr 4.29 | ms/batch 1250.71 | loss 10.08 | ppl 23748.31
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 0]: | end of epoch 3 | time: 69.12s | valid loss 0.69 | valid ppl 2.00
+#    [RANK 0]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 1]: | end of epoch 3 | time: 69.12s | valid loss 0.69 | valid ppl 2.00
+#    [RANK 1]: -----------------------------------------------------------------------------------------
+#    [RANK 0]: =========================================================================================
+#    [RANK 0]: | End of training | test loss 0.60 | test ppl 1.83
+#    [RANK 0]: =========================================================================================
+#    [RANK 1]: =========================================================================================
+#    [RANK 1]: | End of training | test loss 0.60 | test ppl 1.83
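
Note: the ppl column in this output is the exponential of the reported loss, so the two columns can be cross-checked directly (up to rounding of the displayed loss):

    import math
    print(math.exp(0.60))  # ~1.822, consistent with the final test ppl of 1.83
    print(math.exp(2.92))  # ~18.54, close to the epoch-1 valid ppl of 18.46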
