Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit e74a67b

Browse files
Aidyn-A and crcrpar authored
[Contrib][Openfold Triton] Use json instead of pickle (#1900)
* Use json instead of pickle
* Remove space
* Add backward call
* Fix function arguments
* Add run_tests()
* Reorder imports
* Apply nit-picks

Co-authored-by: Masaki Kozuki <[email protected]>
1 parent 8ad740d commit e74a67b

2 files changed

Lines changed: 117 additions & 16 deletions

File tree

apex/contrib/openfold_triton/__init__.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# © 2023 NVIDIA CORPORATION & AFFILIATES
22

3-
import pickle
3+
import json
4+
import warnings
45
from collections import OrderedDict
56
from copy import deepcopy
67
from io import BytesIO
78
from typing import BinaryIO, Union
89

910
import torch
10-
from triton.runtime.autotuner import Autotuner, Heuristics
11+
from triton.runtime.autotuner import Autotuner, Config, Heuristics
1112
from triton.runtime.jit import JITFunction
1213

1314
from apex.contrib.openfold_triton._layer_norm_backward_kernels import (
@@ -58,23 +59,27 @@ def _get_tuneable_triton_func_name(f: Union[Autotuner, Heuristics, JITFunction])
5859
)
5960

6061

61-
def _save_triton_auto_tune_cache(strict: bool = True, verbose: bool = False) -> BytesIO:
    """Serialize the auto-tuning caches of the tuneable Triton kernels as JSON.

    Each kernel registered in ``_tuneable_triton_kernels`` that has a
    non-empty ``cache`` contributes a list of ``(key, kwargs)`` pairs, where
    ``kwargs`` is the result of calling ``all_kwargs()`` on the cached entry.

    Args:
        strict: When True, a kernel with an empty tuning cache raises
            ``ValueError``. When False, a warning is emitted instead and the
            kernel is left out of the serialized result.
        verbose: When True, print a message once the buffer has been built.

    Returns:
        A ``BytesIO`` containing the UTF-8 encoded JSON document.

    Raises:
        ValueError: If ``strict`` is True and a kernel has no tuning cache.
    """
    caches = OrderedDict()
    for func_name, func in _tuneable_triton_kernels.items():
        if len(func.cache) < 1:
            msg = f"Triton JIT kernel {func_name} didn't have tuning cache"
            if strict:
                raise ValueError(msg)
            warnings.warn(msg)
            continue
        # JSON has no tuple type, so each cache key is written out as an array;
        # the loading side is responsible for turning it back into a tuple.
        caches[func_name] = [
            (key, config.all_kwargs()) for key, config in func.cache.items()
        ]
    f = BytesIO(json.dumps(caches).encode('utf-8'))
    if verbose:
        print(f"Triton kernel auto-tuning caches written to {f}")
    return f
7277

7378

7479
def _load_triton_auto_tune_cache(
7580
f: BinaryIO, strict: bool = True, verbose: bool = False
7681
) -> None:
77-
caches = pickle.load(f)
82+
caches = json.load(f)
7883
if strict:
7984
loaded_func_name = set(caches.keys())
8085
tuneable_func_name = set(_tuneable_triton_kernels.keys())
@@ -84,23 +89,23 @@ def _load_triton_auto_tune_cache(
8489
f"Missing kernel caches: {tuneable_func_name - loaded_func_name}\n"
8590
f"Unexpected kernel caches: {loaded_func_name - tuneable_func_name}"
8691
)
87-
for func_name, cache in caches.items():
92+
for func_name, func_cache in caches.items():
8893
if func_name not in _tuneable_triton_kernels:
8994
raise ValueError(
9095
f"{func_name} from {f} doesn't match any tuneable Triton kernels"
9196
)
92-
_tuneable_triton_kernels[func_name].cache = cache
97+
for key, val in func_cache:
98+
_tuneable_triton_kernels[func_name].cache[tuple(key)] = Config(val)
9399
if verbose:
94100
print(f"Triton kernel auto-tuning caches loaded from {f}")
95101

96102

97-
def sync_triton_auto_tune_cache_across_gpus() -> None:
103+
def sync_triton_auto_tune_cache_across_gpus(strict: bool = True, verbose: bool = False) -> None:
98104
if not torch.distributed.is_initialized():
99105
return
100106
if torch.distributed.get_rank() == 0:
101107
print("Broadcasting Triton auto-tuning cache from rank 0 to other ranks...")
102-
cache = BytesIO()
103-
_save_triton_auto_tune_cache(cache)
108+
cache = _save_triton_auto_tune_cache(strict=strict, verbose=verbose)
104109
cache.seek(0)
105110
cache_list = [
106111
cache,
@@ -113,6 +118,5 @@ def sync_triton_auto_tune_cache_across_gpus() -> None:
113118
None,
114119
]
115120
torch.distributed.broadcast_object_list(cache_list)
116-
cache = cache_list[0]
117-
_load_triton_auto_tune_cache(cache)
121+
_load_triton_auto_tune_cache(cache_list[0], strict=strict, verbose=verbose)
118122
print("Succeed!")
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import os
2+
import torch
3+
import torch.distributed as dist
4+
from torch.testing._internal.common_distributed import (
5+
MultiProcessTestCase,
6+
requires_nccl,
7+
skip_if_lt_x_gpu,
8+
run_tests,
9+
)
10+
from apex.contrib.openfold_triton import (
11+
LayerNormSmallShapeOptImpl,
12+
sync_triton_auto_tune_cache_across_gpus,
13+
_tuneable_triton_kernels,
14+
)
15+
16+
class SyncTritonAutoTuneCacheTest(MultiProcessTestCase):
    """Multi-process check of ``sync_triton_auto_tune_cache_across_gpus``.

    Rank 0 runs one forward and backward pass through
    ``LayerNormSmallShapeOptImpl`` (presumably what fills the kernels'
    auto-tuning caches), every rank then calls the sync function, and each
    rank asserts that at least one tuneable kernel ends up with a non-empty
    cache.
    """

    device_type = "cuda"

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def setUp(self) -> None:
        """Run the base set-up, then spawn the per-rank worker processes."""
        super().setUp()
        self._spawn_processes()

    def tearDown(self) -> None:
        """Wait for pending CUDA work and release cached memory before teardown."""
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        super().tearDown()

    @property
    def world_size(self) -> int:
        """Number of ranks: the visible CUDA device count, capped at two."""
        return min(torch.cuda.device_count(), 2)

    @property
    def init_method(self):
        """Init-method URL built from the file schema prefix and ``self.file_name``."""
        # Imported locally: the module never imports ``common_utils``, so the
        # previous ``common_utils.FILE_SCHEMA`` reference raised NameError.
        from torch.testing._internal.common_utils import FILE_SCHEMA

        return f"{FILE_SCHEMA}{self.file_name}"

    @property
    def destroy_pg_upon_exit(self) -> bool:
        return True

    def _create_process_group_nccl(self):
        """Initialize the default NCCL process group and return it.

        ``MASTER_PORT`` and ``MASTER_ADDR`` are exported with default values
        only when they are not already present in the environment.
        """

        def maybe_export(env, val):
            # Set an environment variable without overriding an existing value.
            if not isinstance(env, str):
                raise ValueError(f"Type of env is expected to be str, but got {type(env)}")
            if not isinstance(val, str):
                raise ValueError(f"Type of val is expected to be str, but got {type(val)}")
            if os.getenv(env) is None:
                os.environ[env] = val

        maybe_export("MASTER_PORT", "29500")
        maybe_export("MASTER_ADDR", "localhost")

        # Create the NCCL process group spanning ``world_size`` ranks.
        dist.init_process_group(
            "nccl",
            world_size=self.world_size,
            rank=self.rank,
        )
        return dist.distributed_c10d._get_default_group()

    @requires_nccl()
    @skip_if_lt_x_gpu(1)
    def test_sync_triton_auto_tune_cache_across_gpus(self):
        """Caches produced on rank 0 must be visible on every rank after syncing."""
        self._create_process_group_nccl()
        device = torch.device(f"cuda:{self.rank % torch.cuda.device_count()}")
        torch.cuda.set_device(device)

        if self.rank == 0:
            # Only rank 0 runs the kernels; other ranks start with whatever
            # caches they already have and rely on the broadcast below.
            eps = 1e-5
            normalized_shape = (128, 64,)

            weight = torch.ones(normalized_shape, device=device, requires_grad=True)
            bias = torch.zeros(normalized_shape, device=device, requires_grad=True)

            x = torch.randn((2, 2,) + normalized_shape, device=device)
            y = LayerNormSmallShapeOptImpl.apply(
                x, normalized_shape, weight, bias, eps
            )
            loss = torch.sum(y)
            loss.backward()

        # strict=False: kernels that were never run have empty caches and
        # should only produce a warning rather than abort the sync.
        sync_triton_auto_tune_cache_across_gpus(strict=False, verbose=True)

        caches_synced = 0
        for func_name, func in _tuneable_triton_kernels.items():
            if len(func.cache) > 0:
                caches_synced = caches_synced + 1
                print(f"caches were synchronized for {func_name} at rank = {self.rank}:", func.cache)

        self.assertGreater(caches_synced, 0)
94+
95+
96+
# Script entry point: hand control to the test runner imported by this module.
if __name__ == '__main__':
    run_tests()

0 commit comments

Comments
 (0)