Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 1d01e5c

Browse files
shjwudp and timmoon10 authored
Add the warning of distributed_fused_adam low bucket usage (#1714)
* feat: Add the warning of distributed_fused_adam low bucket usage.
* correct unittest
* Update apex/contrib/optimizers/distributed_fused_adam.py

Co-authored-by: Tim Moon <[email protected]>
---------
Co-authored-by: Tim Moon <[email protected]>
1 parent bc4be41 commit 1d01e5c

2 files changed

Lines changed: 51 additions & 4 deletions

File tree

apex/contrib/optimizers/distributed_fused_adam.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,8 @@ def __init__(
393393
self.shard_size: int = shard_size
394394
# Size of the filled region in the bucket
395395
self.filled_size: int = 0
396+
# Is it able to continue filling
397+
self.able_to_fill: bool = True
396398
# Offset to bucket in contiguous buffers
397399
self.contiguous_buffer_offset: int = contiguous_buffer_offset
398400
# Buffer ranges corresponding to parameter fragments
@@ -1037,6 +1039,21 @@ def init_params(
10371039
param_group_id, param_id = id_map[param]
10381040
self._init_param_state(param, param_group_id, param_id)
10391041

1042+
num_params = sum(1 for param in self.parameters())
1043+
num_initialized_params = sum(
1044+
1 for param in self.parameters()
1045+
if "fragments" in self.state[param]
1046+
)
1047+
if num_initialized_params == num_params:
1048+
bucket_size = sum(bucket.bucket_size for bucket in self.state["buckets"])
1049+
filled_size = sum(bucket.filled_size for bucket in self.state["buckets"])
1050+
buckets_utilization = filled_size / bucket_size
1051+
if buckets_utilization < 0.7:
1052+
warnings.warn(
1053+
f"Only {buckets_utilization:.1%} of buckets are used. "
1054+
"Consider decreasing the bucket_cap_mb argument."
1055+
)
1056+
10401057
def init_params_bucket(self, params: Iterable[torch.nn.Parameter]) -> None:
10411058
"""Initialize optimizer state for parameters in one effective bucket
10421059
@@ -1065,7 +1082,7 @@ def init_params_bucket(self, params: Iterable[torch.nn.Parameter]) -> None:
10651082

10661083
# Mark existings bucket as fully filled
10671084
for bucket in self.state["buckets"]:
1068-
bucket.filled_size = bucket.bucket_size
1085+
bucket.able_to_fill = False
10691086

10701087
# Initialize optimizer state for parameters
10711088
start_bucket_id = len(self.state["buckets"])
@@ -1076,7 +1093,7 @@ def init_params_bucket(self, params: Iterable[torch.nn.Parameter]) -> None:
10761093
for bucket_id in range(start_bucket_id, end_bucket_id):
10771094
bucket = self.state["buckets"][bucket_id]
10781095
bucket_size = bucket.bucket_size
1079-
bucket.filled_size = bucket_size
1096+
bucket.able_to_fill = False
10801097
ids_in_bucket = set(
10811098
(fragment.param_group_id, fragment.param_id)
10821099
for fragment in bucket.fragments
@@ -1151,7 +1168,7 @@ def _init_param_state(
11511168
bucket_end = bucket_start + fragment_size
11521169

11531170
# Create new bucket if current one is full
1154-
if fragment_size <= 0:
1171+
if fragment_size <= 0 or not bucket.able_to_fill:
11551172
shard_size = self.default_shard_size
11561173
bucket_size = shard_size * self.distributed_size
11571174
buffer_offset = bucket.contiguous_buffer_offset + bucket.bucket_size

apex/contrib/test/optimizers/test_dist_adam.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import io
33
from typing import Optional, Tuple
44
import unittest
5+
import warnings
56

67
import torch
78
from torch.testing._internal import common_utils
@@ -43,6 +44,7 @@ def make_models(
4344
contiguous_buffers=False,
4445
store_params=False,
4546
store_param_remainders=False,
47+
bucket_cap_mb=71/(4*1024*1024),
4648
):
4749

4850
# Construct models with same parameters
@@ -82,7 +84,7 @@ def make_models(
8284
adam_w_mode=adam_w_mode,
8385
overlap_grad_sync=overlap_communication,
8486
overlap_param_sync=overlap_communication,
85-
bucket_cap_mb=71/(4*1024*1024),
87+
bucket_cap_mb=bucket_cap_mb,
8688
dtype=optim_dtype,
8789
grad_sync_dtype=grad_sync_dtype,
8890
param_sync_dtype=param_sync_dtype,
@@ -131,6 +133,7 @@ def test_matches_pytorch(
131133
contiguous_buffers=False,
132134
store_params=False,
133135
store_param_remainders=False,
136+
bucket_cap_mb=71/(4*1024*1024),
134137
):
135138

136139
torch.manual_seed(self.seed + self.rank)
@@ -149,6 +152,7 @@ def test_matches_pytorch(
149152
contiguous_buffers=contiguous_buffers,
150153
store_params=store_params,
151154
store_param_remainders=store_param_remainders,
155+
bucket_cap_mb=bucket_cap_mb,
152156
)
153157

154158
# Training loop
@@ -678,6 +682,32 @@ def test_checkpoint_bf16(self):
678682
),
679683
)
680684

685+
def test_bucket_low_utilization_warning(self):
686+
"""Test warning when bucket utilization is low"""
687+
layer_size = 2*1024*1024
688+
num_layers = 4
689+
fairish_bucket_cap_mb = 4*num_layers*layer_size/(1024*1024)
690+
691+
# Check that warning is raised when bucket utilization is low
692+
with self.assertWarnsRegex(Warning, ".*Consider decreasing the bucket_cap_mb argument."):
693+
self.test_matches_pytorch(
694+
num_layers=num_layers,
695+
layer_size=layer_size,
696+
bucket_cap_mb=fairish_bucket_cap_mb * 2,
697+
contiguous_buffers=True,
698+
)
699+
700+
# Check that warning is not raised when bucket utilization is high
701+
with warnings.catch_warnings(record=True) as warns:
702+
self.test_matches_pytorch(
703+
num_layers=num_layers,
704+
layer_size=layer_size,
705+
bucket_cap_mb=fairish_bucket_cap_mb,
706+
contiguous_buffers=True,
707+
)
708+
for w in warns:
709+
self.assertNotRegex(str(w.message), ".*Consider decreasing the bucket_cap_mb argument.")
710+
681711

682712
if __name__ == "__main__":
683713
# Assume script has been run with torchrun

0 commit comments

Comments
 (0)