Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 4fcca77

Browse files
Remove apex.transformer (#1968)
* removed unused functions * apex.transformer no logner mentioned outside of tests and apex/transformer * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove transformer folder * remove transformer based tests in Apex * bring back files under contrib * add distributed folder w/ old definitions * changed structure * proper path for distributed test base * change paths, add __init__ * edit __init__.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 3c98f93 commit 4fcca77

69 files changed

Lines changed: 20 additions & 15837 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

apex/__init__.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,8 @@
1414
from . import optimizers
1515
from . import normalization
1616

17-
if torch.distributed.is_available():
18-
from . import transformer
1917

20-
__all__ = ["optimizers", "normalization", "transformer"]
21-
22-
# Logging utilities for apex.transformer module
23-
class RankInfoFormatter(logging.Formatter):
24-
def format(self, record):
25-
from apex.transformer.parallel_state import get_rank_info
26-
27-
record.rank_info = get_rank_info()
28-
return super().format(record)
29-
30-
_library_root_logger = logging.getLogger(__name__)
31-
handler = logging.StreamHandler()
32-
handler.setFormatter(
33-
RankInfoFormatter(
34-
"%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s",
35-
"%y-%m-%d %H:%M:%S",
36-
)
37-
)
38-
_library_root_logger.addHandler(handler)
39-
_library_root_logger.propagate = False
40-
else:
41-
# Transformers require PyTorch built with distributed support
42-
__all__ = ["optimizers", "normalization"]
18+
__all__ = ["optimizers", "normalization"]
4319

4420

4521
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:

apex/contrib/test/bottleneck/test_bottleneck_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from torch.testing._internal import common_utils
55

6-
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
6+
from apex.distributed_testing.distributed_test_base import NcclDistributedTestBase
77

88
SKIP_TEST = None
99
try:

apex/contrib/test/cudnn_gbn/test_cudnn_gbn_with_two_gpus.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torch.testing._internal import common_utils
88

99
SKIP_TEST = None
10-
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
10+
from apex.distributed_testing.distributed_test_base import NcclDistributedTestBase
1111

1212
try:
1313
from apex.contrib.cudnn_gbn import GroupBatchNorm2d as GBN

apex/contrib/test/optimizers/test_dist_adam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from apex.contrib.optimizers.distributed_fused_adam import DistributedFusedAdam
1414
except ImportError as e:
1515
SKIP_TEST = e
16-
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
16+
from apex.distributed_testing.distributed_test_base import NcclDistributedTestBase
1717

1818

1919
class SimpleModel(torch.nn.Module):

apex/contrib/test/optimizers/test_distributed_fused_lamb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from torch.distributed.distributed_c10d import _coalescing_manager
77

88
from apex.contrib.optimizers.distributed_fused_lamb import DistributedFusedLAMB
9-
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
9+
from apex.distributed_testing.distributed_test_base import NcclDistributedTestBase
1010

1111

1212
def flat_dist_call(param_list: list[torch.Tensor], op, args):

apex/contrib/test/peer_memory/test_peer_halo_exchange_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from torch.testing._internal import common_utils
55

66
SKIP_TEST = None
7-
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
7+
from apex.distributed_testing.distributed_test_base import NcclDistributedTestBase
88

99
try:
1010
from apex.contrib.peer_memory import PeerMemoryPool, PeerHaloExchanger1d
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Distributed testing utilities."""
2+
3+
from apex.distributed_testing.distributed_test_base import (
4+
DistributedTestBase,
5+
NcclDistributedTestBase,
6+
UccDistributedTestBase,
7+
)
8+
9+
__all__ = [
10+
"DistributedTestBase",
11+
"NcclDistributedTestBase",
12+
"UccDistributedTestBase",
13+
]

apex/transformer/testing/distributed_test_base.py renamed to apex/distributed_testing/distributed_test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch.testing._internal import common_utils
1010
from torch.testing._internal import common_distributed
1111

12-
from apex.transformer._ucc_util import HAS_UCC
12+
from apex.distributed_testing._ucc_util import HAS_UCC
1313

1414
# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
1515
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01")

apex/transformer/README.md

Lines changed: 0 additions & 81 deletions
This file was deleted.

0 commit comments

Comments
 (0)