-
Notifications
You must be signed in to change notification settings - Fork 1.5k
Expand file tree
/
Copy path__init__.py
More file actions
43 lines (33 loc) · 1.64 KB
/
__init__.py
File metadata and controls
43 lines (33 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import logging
import warnings
# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch
# For optimizers and normalization there is no Python fallback.
# Absence of cuda backend is a hard error.
# I would like the errors from importing fused_adam_cuda or fused_layer_norm_cuda
# to be triggered lazily, so that if someone has installed with --cpp_ext and --cuda_ext
# (and therefore expects those backends to be available) but for some reason they
# actually aren't available (for example, because they were built improperly in a way
# that isn't revealed until load time), the error message is timely and visible.
from . import optimizers
from . import normalization
# Public API of this package: the names exported by ``from <package> import *``.
# Both entries are the submodules imported just above.
__all__ = ["optimizers", "normalization"]
def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
    """Check that cuDNN ``required_cudnn_version`` or newer is usable, warning if not.

    Args:
        global_option: Name of the option/feature that needs cuDNN. It is only
            used to build the warning text.
        required_cudnn_version: Minimum cuDNN version, compared against the
            integer returned by ``torch.backends.cudnn.version()``.

    Returns:
        ``True`` if cuDNN is available and its version is at least
        ``required_cudnn_version``. Otherwise a ``UserWarning`` is emitted
        and ``False`` is returned.
    """
    cudnn_available = torch.backends.cudnn.is_available()
    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None

    # ``version()`` may return None even when ``is_available()`` is True
    # (e.g. if cuDNN fails to initialize). Treat that as "not available"
    # instead of letting ``None >= int`` raise a TypeError.
    if cudnn_version is not None and cudnn_version >= required_cudnn_version:
        return True

    if cudnn_version is None:
        reason = "cuDNN is not available"
    else:
        reason = f"found cuDNN {cudnn_version}"
    warnings.warn(
        f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, "
        f"but {reason}"
    )
    return False
class DeprecatedFeatureWarning(FutureWarning):
    """Warning category for deprecated features of this package.

    ``deprecated_warning`` in this module issues warnings of this category.
    It subclasses ``FutureWarning``, so callers can filter it either
    specifically or together with other ``FutureWarning``s.
    """
def deprecated_warning(msg: str) -> None:
    """Emit ``msg`` as a ``DeprecatedFeatureWarning``, avoiding duplicates across ranks.

    The warning is issued when ``torch.distributed`` is unavailable, when it is
    available but no process group has been initialized, or on rank 0 of an
    initialized process group. Other ranks stay silent, so a multi-process run
    prints the message once rather than once per process.

    Args:
        msg: Human-readable deprecation message.
    """
    # ``is_available`` must be *called*: the bare function object is always
    # truthy, so the guard would never short-circuit. On builds without
    # distributed support, the other ``torch.distributed`` query functions
    # may not be defined at all.
    if (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        # Reaching here implies the process group is initialized.
        or torch.distributed.get_rank() == 0
    ):
        warnings.warn(msg, DeprecatedFeatureWarning)