diff --git a/_posts/2023-08-31-pytorch-xla-spmd.md b/_posts/2023-08-31-pytorch-xla-spmd.md
index 8c5607ee087f..715a1dc6ff4b 100644
--- a/_posts/2023-08-31-pytorch-xla-spmd.md
+++ b/_posts/2023-08-31-pytorch-xla-spmd.md
@@ -41,8 +41,12 @@ We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/028
 ```
 import numpy as np
 import torch_xla.runtime as xr
+import torch_xla.experimental.xla_sharding as xs
 from torch_xla.experimental.xla_sharding import Mesh
 
+# Enable XLA SPMD execution mode.
+xr.use_spmd()
+
 # Assuming you are running on a TPU host that has 8 devices attached
 num_devices = xr.global_runtime_device_count()
 # mesh shape will be (4,2) in this example
@@ -122,7 +126,7 @@ assert isinstance(m1_sharded, XLAShardedTensor) == True
 We can annotate different tensors in the PyTorch program to enable different parallelism techniques, as described in the comment below:
 
 ```
-# Sharding annotate the linear layer weights.
+# Sharding annotate the linear layer weights. SimpleLinear() is a nn.Module.
 model = SimpleLinear().to(xm.xla_device())
 xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
 
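
For reference, a minimal sketch of how the snippet reads once both hunks above are applied. It follows the API usage shown in the surrounding blog post; the mesh shape, axis names, layer dimensions, and the `SimpleLinear` definition are illustrative stand-ins rather than verbatim content from the post.

```
import numpy as np
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

# Enable XLA SPMD execution mode before creating or sharding any XLA tensors.
xr.use_spmd()

# Build a logical (4, 2) device mesh over the 8 attached TPU devices.
num_devices = xr.global_runtime_device_count()
mesh_shape = (4, 2)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

# Illustrative stand-in: any nn.Module with a linear layer works the same way.
class SimpleLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(128, 64)

    def forward(self, x):
        return self.fc1(x)

# Sharding annotate the linear layer weights:
# tensor dim 0 maps to mesh axis 0, tensor dim 1 maps to mesh axis 1.
partition_spec = (0, 1)
model = SimpleLinear().to(xm.xla_device())
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)
```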