Commit 335356d

Edits to blog post (pytorch#1443)
Signed-off-by: Chris Abraham <[email protected]>
1 parent f598377 commit 335356d

1 file changed, +5 -1 lines changed

_posts/2023-08-31-pytorch-xla-spmd.md (+5 -1)
@@ -41,8 +41,12 @@ We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/028
 ```
 import numpy as np
 import torch_xla.runtime as xr
+import torch_xla.experimental.xla_sharding as xs
 from torch_xla.experimental.xla_sharding import Mesh
 
+# Enable XLA SPMD execution mode.
+xr.use_spmd()
+
 # Assuming you are running on a TPU host that has 8 devices attached
 num_devices = xr.global_runtime_device_count()
 # mesh shape will be (4,2) in this example
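For reference, the snippet being edited goes on to build the logical mesh from num_devices. A minimal sketch of that continuation, assuming the (4, 2) mesh shape mentioned in the context lines; the axis names ('x', 'y') are illustrative and not necessarily the post's exact code:

```
# Construct a (4, 2) logical device mesh over the 8 attached TPU devices.
# The axis names ('x', 'y') are illustrative labels for the two mesh axes.
mesh_shape = (num_devices // 2, 2)   # (4, 2) when num_devices == 8
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
```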
@@ -122,7 +126,7 @@ assert isinstance(m1_sharded, XLAShardedTensor) == True
 We can annotate different tensors in the PyTorch program to enable different parallelism techniques, as described in the comment below:
 
 ```
-# Sharding annotate the linear layer weights.
+# Sharding annotate the linear layer weights. SimpleLinear() is a nn.Module.
 model = SimpleLinear().to(xm.xla_device())
 xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
 
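The second hunk references SimpleLinear, mesh, and partition_spec, which the post defines elsewhere. A minimal, self-contained sketch of how those pieces could fit together; the SimpleLinear definition and the partition_spec value below are assumptions for illustration, not the post's exact code:

```
import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

# SPMD execution mode must be enabled before any sharding annotation.
xr.use_spmd()

class SimpleLinear(nn.Module):
    # Hypothetical stand-in for the post's SimpleLinear: a single linear layer.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(128, 64)

    def forward(self, x):
        return self.fc1(x)

# Two-axis logical mesh over all attached devices, e.g. (4, 2) on an 8-device TPU host.
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.array(range(num_devices)), (num_devices // 2, 2))

# Assumed partition spec: shard dim 0 of the weight across mesh axis 0
# and dim 1 across mesh axis 1, matching the weight's rank of 2.
partition_spec = (0, 1)

# Sharding annotate the linear layer weights.
model = SimpleLinear().to(xm.xla_device())
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
```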