Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 5b04afc

Browse files
timmoon10 and crcrpar authored
Include format version in distopt checkpoints (#1716)
* Include format version in distopt checkpoints

  Fix bug when loading old v1 checkpoints.

  Signed-off-by: Tim Moon <[email protected]>

* Handle distopt v2 checkpoints without format info

* Tweak documentation

  Signed-off-by: Tim Moon <[email protected]>

* Update apex/contrib/optimizers/distributed_fused_adam.py

  Co-authored-by: Masaki Kozuki <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Masaki Kozuki <[email protected]>
1 parent 868f996 commit 5b04afc

1 file changed

Lines changed: 45 additions & 19 deletions

File tree

apex/contrib/optimizers/distributed_fused_adam.py

Lines changed: 45 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -2195,31 +2195,46 @@ def _local_step_with_param_remainders(
21952195
def state_dict(
21962196
self,
21972197
*,
2198-
v1_format: bool = False,
2198+
state_dict_format: Optional[int] = None,
21992199
gather_on_root: Optional[bool] = None,
22002200
) -> Optional[dict]:
22012201
"""Get dictionary containing optimizer state
22022202
22032203
Gathers optimizer state on the process group's root rank and
2204-
returns empty dictionaries on non-root ranks.
2204+
returns None on non-root ranks.
22052205
22062206
Arguments:
2207-
v1_format (bool, optional): Use deprecated v1 format
2208-
(default: False).
2207+
state_dict_format (optional): Tag for custom or deprecated
2208+
state dict format.
22092209
gather_on_root (bool, optional): Option for deprecated v1
22102210
format.
22112211
22122212
"""
22132213

2214-
# Deprecated v1 format
2215-
if v1_format:
2214+
# Default state dict format
2215+
if state_dict_format is None:
2216+
state_dict_format = 2
2217+
2218+
# Construct state dict
2219+
state_dict = None
2220+
if state_dict_format == 1:
2221+
# Deprecated v1 format
22162222
kwargs = {}
22172223
if gather_on_root is not None:
22182224
kwargs["gather_on_root"] = gather_on_root
2219-
return self._state_dict_v1(**kwargs)
2225+
state_dict = self._state_dict_v1(**kwargs)
2226+
elif state_dict_format == 2:
2227+
# Default v2 format
2228+
state_dict = self._state_dict_v2()
2229+
else:
2230+
# Unrecognized format
2231+
raise ValueError(f"Unrecognized state dict format ({state_dict_format})")
2232+
2233+
# Add format tag to state dict
2234+
if state_dict is not None:
2235+
state_dict["format"] = state_dict_format
22202236

2221-
# Default v2 format
2222-
return self._state_dict_v2()
2237+
return state_dict
22232238

22242239
def _state_dict_v1(self, gather_on_root: bool = True) -> Optional[dict]:
22252240
"""Get dictionary containing optimizer state (deprecated v1 format)
@@ -2381,11 +2396,11 @@ def _state_dict_v1(self, gather_on_root: bool = True) -> Optional[dict]:
23812396
return None
23822397

23832398
@torch.no_grad()
2384-
def _state_dict_v2(self) -> dict:
2399+
def _state_dict_v2(self) -> Optional[dict]:
23852400
"""Get dictionary containing optimizer state (default v2 format)
23862401
23872402
Gathers optimizer state on the process group's root rank and
2388-
returns empty dictionaries on non-root ranks.
2403+
returns None on non-root ranks.
23892404
23902405
"""
23912406

@@ -2399,7 +2414,7 @@ def _state_dict_v2(self) -> dict:
23992414
# Return immediately on ranks with redundant data
24002415
if self.redundant_size > 1:
24012416
if torch.distributed.get_rank(self.redundant_process_group) > 0:
2402-
return {}
2417+
return None
24032418

24042419
# Initialize state dict on root rank
24052420
if self.distributed_rank == 0:
@@ -2644,18 +2659,29 @@ def finish_gather(bucket_id: int, state_dict_key: str) -> None:
26442659
if self.distributed_rank == 0:
26452660
return state_dict
26462661
else:
2647-
return {}
2662+
return None
26482663

26492664
def load_state_dict(self, state_dict: dict) -> None:
26502665
"""Load optimizer state"""
26512666

2652-
# Deprecated v1 format
2653-
if "buckets" in state_dict["state"] or "gathered_states" in state_dict["state"]:
2654-
self._load_state_dict_v1(state_dict)
2655-
return
2667+
# Figure out state dict format
2668+
state_dict_format = state_dict.pop("format", None)
2669+
if state_dict_format is None:
2670+
if "buckets" in state_dict or "gathered_states" in state_dict:
2671+
state_dict_format = 1
2672+
else:
2673+
state_dict_format = 2
26562674

2657-
# Default v2 format
2658-
self._load_state_dict_v2(state_dict)
2675+
# Load state dict
2676+
if state_dict_format == 1:
2677+
# Deprecated v1 format
2678+
self._load_state_dict_v1(state_dict)
2679+
elif state_dict_format == 2:
2680+
# Default v2 format
2681+
self._load_state_dict_v2(state_dict)
2682+
else:
2683+
# Unrecognized format
2684+
raise ValueError(f"Unrecognized state dict format ({state_dict_format})")
26592685

26602686
def _load_state_dict_v1(self, state_dict: dict) -> None:
26612687
"""Load optimizer state (deprecated v1 format)

0 commit comments

Comments (0)