
Commit b6df043

bowangbj authored and facebook-github-bot committed
Add torch.nn.init.uniform_ operator to ShardedTensor. (#63997)
Summary: Pull Request resolved: #63997

Use __torch_function__ to extend torch.nn.init.uniform_ to ShardedTensor. The init is done in SPMD fashion on each rank's local shards. Ideally we would aggregate the sharded tensors into a global tensor, init it, and reshard, but running the init SPMD is fine here since the uniform distribution is i.i.d. (independent and identically distributed). Also enables the unit test in test_linear.py for the OSS test suite.

Test Plan:
a) Unit tests:
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_init.py TestShardedTensorNNInit --v
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_linear.py --v (before this change, running this command was a no-op)
or
b) Manual run; instructions here: https://docs.google.com/document/d/1_m1Hdo5w51-hhPlZ_F8Y6PIWrN7UgJZqiSpARYvhsaE/edit#

Imported from OSS

Reviewed By: pritamdamania87, anjali411

Differential Revision: D30563017

fbshipit-source-id: d1859f7682235bcb44515efc69ca92bc5e34fce1
1 parent bdb889a commit b6df043
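
For orientation, a minimal usage sketch of what this change enables (not part of the commit itself; it assumes torch.distributed is already initialized with a 4-rank NCCL process group, one GPU per rank, mirroring the placements used in the new unit test):

import torch
from torch.distributed import _sharded_tensor
from torch.distributed._sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(
    dim=0,
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ],
)

# Each rank ends up holding one local shard of the 8x2 sharded tensor.
st = _sharded_tensor.empty(spec, 8, 2)

# Routed through ShardedTensor.__torch_function__: every rank fills its own
# local shard in place (SPMD style), with no cross-rank communication.
torch.nn.init.uniform_(st, a=10, b=20)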

File tree

7 files changed, +107 -2 lines changed

test/distributed/_sharded_tensor/ops/test_init.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
import sys
import torch

from torch.distributed import _sharded_tensor
from torch.distributed._sharding_spec import (
    ChunkShardingSpec,
)
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.distributed._sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)

if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)

class TestShardedTensorNNInit(ShardedTensorTestBase):
    """ Testing torch.nn.init functions for ShardedTensor """

    @with_comms
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_init_sharded_tensor_with_uniform(self):
        """ Test torch.nn.init.uniform_(ShardedTensor, a, b) """

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        h, w = 8, 2
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        a, b = 10, 20

        seed = 1234
        dtype = torch.double

        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
        torch.manual_seed(seed)
        torch.nn.init.uniform_(sharded_tensor, a=a, b=b)

        torch.manual_seed(seed)
        torch.nn.init.uniform_(local_tensor_clone, a=a, b=b)
        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)


if __name__ == '__main__':
    run_tests()

test/distributed/_sharded_tensor/ops/test_linear.py

Lines changed: 4 additions & 0 deletions
@@ -12,6 +12,7 @@
 )
 from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
+    run_tests,
 )
 from torch.testing._internal.distributed._sharded_tensor import (
     ShardedTensorTestBase,
@@ -85,3 +86,6 @@ def test_sharded_linear_rowwise(self):
         # Test uneven split.
         self._run_sharded_linear(spec, [5, 19], [19, 11], 1)
         self._run_sharded_linear(spec, [5, 21], [21, 11], 1)
+
+if __name__ == '__main__':
+    run_tests()

test/run_test.py

Lines changed: 3 additions & 0 deletions
@@ -199,6 +199,7 @@ def skip_test_p(name: str) -> bool:
     "distributed/elastic/multiprocessing/api_test",
     "distributed/_sharded_tensor/test_sharded_tensor",
     "distributed/_sharded_tensor/ops/test_embedding",
+    "distributed/_sharded_tensor/ops/test_init",
     "distributed/_sharded_tensor/ops/test_linear",
 ] + FSDP_TEST
@@ -209,6 +210,7 @@ def skip_test_p(name: str) -> bool:
     "distributed/rpc/cuda/test_tensorpipe_agent",
     "distributed/_sharded_tensor/test_sharded_tensor",
     "distributed/_sharded_tensor/ops/test_embedding",
+    "distributed/_sharded_tensor/ops/test_init",
     "distributed/_sharded_tensor/ops/test_linear",
     "test_determination",
     "test_multiprocessing",
@@ -345,6 +347,7 @@ def skip_test_p(name: str) -> bool:
     "distributed/_sharding_spec/test_sharding_spec",
     "distributed/_sharded_tensor/test_sharded_tensor",
     "distributed/_sharded_tensor/ops/test_embedding",
+    "distributed/_sharded_tensor/ops/test_init",
     "distributed/_sharded_tensor/ops/test_linear",
 ] + [test for test in TESTS if test.startswith("distributed/fsdp")]

torch/distributed/_sharded_tensor/api.py

Lines changed: 3 additions & 2 deletions
@@ -27,7 +27,7 @@
     get_chunked_dim_size,
 )
 from torch.types import Number
-from .ops import sharded_embedding, sharded_linear
+from .ops import sharded_embedding, sharded_linear, uniform_
 
 # Tracking for sharded tensor objects.
 _sharded_tensor_lock = threading.Lock()
@@ -638,7 +638,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
             return sharded_linear(types, args, kwargs, self._process_group)
         if func == torch.nn.functional.embedding:
             return sharded_embedding(types, args, kwargs, self._process_group)
-
+        elif func == torch.nn.init.uniform_:
+            return uniform_(types, args, kwargs)
         raise RuntimeError(
             f"torch function '{func.__name__}', with args: {args} and "
             f"kwargs: {kwargs} not supported for ShardedTensor!")

torch/distributed/_sharded_tensor/ops/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,2 +1,3 @@
+from .init import uniform_
 from .linear import sharded_linear
 from .embedding import sharded_embedding

torch/distributed/_sharded_tensor/ops/init.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
import torch

def validate_param(param, param_name):
    if param is None:
        raise ValueError(f"param: {param_name} shouldn't be None!")

def uniform_(types, args=(), kwargs=None):
    r"""
    Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.
    Args:
        sharded_tensor: tensor sharded across devices
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
    """
    validate_param(kwargs, "kwargs")
    sharded_tensor = kwargs["tensor"]
    validate_param(sharded_tensor, "sharded_tensor")
    a = kwargs['a']
    validate_param(a, "a")
    b = kwargs['b']
    validate_param(b, "b")

    for shard in sharded_tensor.local_shards():
        torch.nn.init.uniform_(shard.tensor, a=a, b=b)
    return sharded_tensor
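
For reference, a rough sketch of how a call flows through the pieces added in this commit (simplified, not exact internals); this is why the op above reads "tensor", "a" and "b" out of kwargs rather than positional args:

    torch.nn.init.uniform_(st, a=10, b=20)
      -> has_torch_function_variadic(st, 10, 20) is True, since st is a ShardedTensor
      -> handle_torch_function(uniform_, (st, 10, 20), tensor=st, a=10, b=20)
      -> ShardedTensor.__torch_function__(torch.nn.init.uniform_, types, args=(), kwargs={"tensor": st, "a": 10, "b": 20})
      -> uniform_(types, args, kwargs)   # the op defined above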

torch/nn/init.py

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,9 @@
 from torch import Tensor
 import torch
 
+from ..overrides import (
+    has_torch_function_variadic,
+    handle_torch_function)
 
 # These no_grad_* functions are necessary as wrappers around the parts of these
 # functions that use `with torch.no_grad()`. The JIT doesn't support context
@@ -132,6 +135,8 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
         >>> w = torch.empty(3, 5)
         >>> nn.init.uniform_(w)
     """
+    if has_torch_function_variadic(tensor, a, b):
+        return handle_torch_function(uniform_, (tensor, a, b), tensor=tensor, a=a, b=b)
     return _no_grad_uniform_(tensor, a, b)
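
As a side note, the hook above makes torch.nn.init.uniform_ overridable by any tensor-like class that implements __torch_function__, not just ShardedTensor. A minimal self-contained sketch (the TensorBundle class is hypothetical and purely illustrative; it mirrors the kwargs-based dispatch that ShardedTensor uses):

import torch

class TensorBundle:
    """Toy container of several tensors; hypothetical, not part of this commit."""
    def __init__(self, tensors):
        self.tensors = tensors

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.nn.init.uniform_:
            # handle_torch_function forwards tensor, a and b as keyword arguments.
            bundle = kwargs["tensor"]
            for t in bundle.tensors:
                torch.nn.init.uniform_(t, a=kwargs["a"], b=kwargs["b"])
            return bundle
        raise NotImplementedError(f"{func.__name__} is not supported by TensorBundle")

bundle = TensorBundle([torch.empty(2, 3), torch.empty(4, 5)])
torch.nn.init.uniform_(bundle, a=10, b=20)  # routed through TensorBundle.__torch_function__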