Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit bf3c008

Browse files
authored
[UCC][TORCH_UCC]Do integer driver version comparison for UCC (#1411)
* Integer driver number comparison * packaging
1 parent 3ff1a10 commit bf3c008

2 files changed

Lines changed: 4 additions & 2 deletions

File tree

apex/transformer/testing/distributed_test_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
import sys
33
import unittest
4+
from packaging.version import Version, parse
45

56
import torch
67
from torch import distributed as dist
@@ -16,10 +17,10 @@
1617
HAS_TORCH_UCC = False
1718

1819
# NOTE(mkozuki): Version guard for ucc. ref: https://github.com/openucx/ucc/issues/496
19-
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = "470.42.01"
20+
_TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION = Version("470.42.01")
2021
_driver_version = None
2122
if torch.cuda.is_available():
22-
_driver_version = collect_env.get_nvidia_driver_version(collect_env.run)
23+
_driver_version = parse(collect_env.get_nvidia_driver_version(collect_env.run))
2324
HAS_TORCH_UCC_COMPAT_NVIDIA_DRIVER = _driver_version is not None and _driver_version >= _TORCH_UCC_COMPAT_NVIDIA_DRIVER_VERSION
2425

2526

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ tqdm>=4.28.1
33
numpy>=1.15.3
44
PyYAML>=5.1
55
pytest>=3.5.1
6+
packaging>=14.0

0 commit comments

Comments
 (0)