@@ -2195,31 +2195,46 @@ def _local_step_with_param_remainders(
def state_dict(
    self,
    *,
    state_dict_format: Optional[int] = None,
    gather_on_root: Optional[bool] = None,
) -> Optional[dict]:
    """Get dictionary containing optimizer state

    Gathers optimizer state on the process group's root rank and
    returns None on non-root ranks. When a dictionary is returned,
    its "format" entry holds the format tag that was used, so that
    the format can be recognized again when the state is loaded.

    Arguments:
        state_dict_format (int, optional): Tag for custom or
            deprecated state dict format. 1 selects the deprecated
            v1 format and 2 selects the v2 format (default: 2).
        gather_on_root (bool, optional): Option for deprecated v1
            format. It is only forwarded when the v1 format is
            selected and it is not None.

    Returns:
        The optimizer state dictionary, or None if the selected
        format produced no dictionary on this rank.

    Raises:
        ValueError: If state_dict_format is not a recognized tag.

    """

    # Default state dict format
    if state_dict_format is None:
        state_dict_format = 2

    # Construct state dict
    state_dict = None
    if state_dict_format == 1:
        # Deprecated v1 format. Only pass gather_on_root through
        # when the caller set it, so that the v1 implementation
        # keeps its own default otherwise.
        kwargs = {}
        if gather_on_root is not None:
            kwargs["gather_on_root"] = gather_on_root
        state_dict = self._state_dict_v1(**kwargs)
    elif state_dict_format == 2:
        # Default v2 format
        state_dict = self._state_dict_v2()
    else:
        # Unrecognized format
        raise ValueError(f"Unrecognized state dict format ({state_dict_format})")

    # Add format tag to state dict. Ranks that received None have
    # nothing to tag.
    if state_dict is not None:
        state_dict["format"] = state_dict_format

    return state_dict
22232238
22242239 def _state_dict_v1 (self , gather_on_root : bool = True ) -> Optional [dict ]:
22252240 """Get dictionary containing optimizer state (deprecated v1 format)
@@ -2381,11 +2396,11 @@ def _state_dict_v1(self, gather_on_root: bool = True) -> Optional[dict]:
23812396 return None
23822397
23832398 @torch .no_grad ()
2384- def _state_dict_v2 (self ) -> dict :
2399+ def _state_dict_v2 (self ) -> Optional [ dict ] :
23852400 """Get dictionary containing optimizer state (default v2 format)
23862401
23872402 Gathers optimizer state on the process group's root rank and
2388- returns empty dictionaries on non-root ranks.
2403+ returns None on non-root ranks.
23892404
23902405 """
23912406
@@ -2399,7 +2414,7 @@ def _state_dict_v2(self) -> dict:
23992414 # Return immediately on ranks with redundant data
24002415 if self .redundant_size > 1 :
24012416 if torch .distributed .get_rank (self .redundant_process_group ) > 0 :
2402- return {}
2417+ return None
24032418
24042419 # Initialize state dict on root rank
24052420 if self .distributed_rank == 0 :
@@ -2644,18 +2659,29 @@ def finish_gather(bucket_id: int, state_dict_key: str) -> None:
26442659 if self .distributed_rank == 0 :
26452660 return state_dict
26462661 else :
2647- return {}
2662+ return None
26482663
def load_state_dict(self, state_dict: dict) -> None:
    """Load optimizer state

    The format is taken from the "format" entry of the state dict
    when present. Without that entry, the state dict is treated as
    the deprecated v1 format if it has a "buckets" or
    "gathered_states" key, and as the v2 format otherwise.

    Arguments:
        state_dict (dict): Optimizer state to load. The dictionary
            passed in is not modified by the format detection.

    Raises:
        ValueError: If the format tag is not a recognized value.

    """

    # Work on a shallow copy so that removing the format tag does
    # not strip it from the dictionary the caller still holds.
    state_dict = state_dict.copy()

    # Figure out state dict format
    state_dict_format = state_dict.pop("format", None)
    if state_dict_format is None:
        # NOTE(review): an earlier revision looked for these keys
        # under state_dict["state"] rather than at the top level;
        # confirm the top-level check matches untagged v1 state
        # dicts.
        if "buckets" in state_dict or "gathered_states" in state_dict:
            state_dict_format = 1
        else:
            state_dict_format = 2

    # Load state dict
    if state_dict_format == 1:
        # Deprecated v1 format
        self._load_state_dict_v1(state_dict)
    elif state_dict_format == 2:
        # Default v2 format
        self._load_state_dict_v2(state_dict)
    else:
        # Unrecognized format
        raise ValueError(f"Unrecognized state dict format ({state_dict_format})")
26602686 def _load_state_dict_v1 (self , state_dict : dict ) -> None :
26612687 """Load optimizer state (deprecated v1 format)
0 commit comments