|
import sys
import torch

from torch.distributed import _sharded_tensor
from torch.distributed._sharding_spec import (
    ChunkShardingSpec,
)
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.distributed._sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)

# Bail out of the whole module under dev-asan builds: torch + multiprocessing
# spawn (used by the distributed test harness) is known to misbehave there.
# Exit code 0 so the test runner treats this as a skip, not a failure.
if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)
| 24 | + |
class TestShardedTensorNNInit(ShardedTensorTestBase):
    """Testing torch.nn.init functions for ShardedTensor."""

    @with_comms
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_init_sharded_tensor_with_uniform(self):
        """Test torch.nn.init.uniform_(ShardedTensor, a, b).

        Verifies that initializing a ShardedTensor in-place produces, on each
        rank, exactly the values that torch.nn.init.uniform_ would produce on
        the rank's local shard under the same RNG seed.
        """

        # Shard along dim 0, one contiguous chunk per rank across 4 GPUs.
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        h, w = 8, 2
        # 8 rows chunked over 4 ranks -> each local shard is 2 x w.
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        a, b = 10, 20

        seed = 1234
        dtype = torch.double

        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # Sanity-check the local shard's shape, placement and dtype before
        # exercising the initializer. (These locals were previously computed
        # but never asserted.)
        local_shard = sharded_tensor.local_shards()[0].tensor
        self.assertEqual(torch.Size([expected_h, w]), local_shard.size())
        self.assertEqual(expected_device, local_shard.device)
        self.assertEqual(dtype, local_shard.dtype)

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(local_shard)
        torch.manual_seed(seed)
        torch.nn.init.uniform_(sharded_tensor, a=a, b=b)

        # Re-seed and initialize the dense clone: with identical RNG state the
        # sharded path must produce identical local values.
        torch.manual_seed(seed)
        torch.nn.init.uniform_(local_tensor_clone, a=a, b=b)
        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)
| 62 | + |
| 63 | + |
# Standard entry point for the PyTorch internal test runner.
if __name__ == '__main__':
    run_tests()