
Edits to blog post #1443


Merged: 1 commit, Sep 1, 2023
6 changes: 5 additions & 1 deletion _posts/2023-08-31-pytorch-xla-spmd.md
@@ -41,8 +41,12 @@ We abstract logical mesh with [Mesh API](https://github.com/pytorch/xla/blob/028
```
import numpy as np
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh

# Enable XLA SPMD execution mode.
xr.use_spmd()

# Assuming you are running on a TPU host that has 8 devices attached
num_devices = xr.global_runtime_device_count()
# mesh shape will be (4,2) in this example
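# The diff hunk is cut off here; below is a minimal sketch (not part of the
# diff) of how the mesh construction typically continues, reusing the names
# above. The axis names and the exact lines in the post may differ.
mesh_shape = (num_devices // 2, 2)          # (4, 2) on an 8-device host
device_ids = np.array(range(num_devices))   # logical device ids 0..7
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))  # axis names are assumed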
```

@@ -122,7 +126,7 @@ assert isinstance(m1_sharded, XLAShardedTensor) == True
We can annotate different tensors in the PyTorch program to enable different parallelism techniques, as described in the comment below:

```
-# Sharding annotate the linear layer weights.
+# Sharding annotate the linear layer weights. SimpleLinear() is a nn.Module.
model = SimpleLinear().to(xm.xla_device())
xs.mark_sharding(model.fc1.weight, mesh, partition_spec)

```
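For context on the annotation in this hunk, here is a minimal, hedged sketch of how a `partition_spec` can be built against the mesh from the first hunk and applied to an input tensor for data-parallel sharding. The specific specs, tensor shape, and the `torch`/`xm` usage below are illustrative assumptions, not lines from the post.

```
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

# A partition spec has one entry per tensor dimension: a mesh axis to shard
# along, or None to replicate that dimension.
partition_spec = (0, 1)   # shard dim 0 over mesh axis 0, dim 1 over mesh axis 1

# Data-parallel style input annotation: shard the batch dimension over mesh
# axis 0 and replicate the feature dimension ('mesh' is the Mesh built above).
t = torch.randn(16, 128).to(xm.xla_device())
xs.mark_sharding(t, mesh, (0, None))
```

XLA's SPMD partitioner then propagates these shardings through the rest of the computation, so only a few tensors need explicit annotations.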