From bd272088c33e3042ad7135e655229d615e10c46e Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 14 Nov 2025 09:14:00 -0800
Subject: [PATCH] Create own base class for Feature*Net extraction wrappers to
 avoid FSDP issues.

---
 timm/models/_features.py | 74 +++++++++++++++++++++++++++++++++-------
 1 file changed, 61 insertions(+), 13 deletions(-)

diff --git a/timm/models/_features.py b/timm/models/_features.py
index 3814869175..c398866e3d 100644
--- a/timm/models/_features.py
+++ b/timm/models/_features.py
@@ -20,8 +20,8 @@ from ._manipulate import checkpoint
 
 
 __all__ = [
-    'FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet', 'FeatureGetterNet',
-    'feature_take_indices'
+    'FeatureInfo', 'FeatureHooks', 'FeatureBase', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet',
+    'FeatureGetterNet', 'feature_take_indices'
 ]
 
 
@@ -227,7 +227,59 @@ def _get_return_layers(feature_info, out_map):
     return return_layers
 
 
-class FeatureDictNet(nn.ModuleDict):
+class FeatureBase(nn.Module):
+    """ Base class for feature extraction wrappers
+
+    Provides dict-like interface without inheriting from nn.ModuleDict to avoid FSDP2 issues.
+    FSDP2's fully_shard has isinstance checks for (ModuleDict, ModuleList) that cause problems.
+
+    This class delegates dict operations to the underlying _modules OrderedDict.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.feature_info: Optional[FeatureInfo] = None
+        self.output_fmt: Optional[Format] = None
+        self.grad_checkpointing = False
+
+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+
+    # Dict-like interface methods
+    def __getitem__(self, key: str) -> nn.Module:
+        return self._modules[key]
+
+    def __setitem__(self, key: str, module: nn.Module) -> None:
+        self.add_module(key, module)
+
+    def __delitem__(self, key: str) -> None:
+        del self._modules[key]
+
+    def __len__(self) -> int:
+        return len(self._modules)
+
+    def __iter__(self):
+        return iter(self._modules)
+
+    def __contains__(self, key: str) -> bool:
+        return key in self._modules
+
+    def keys(self):
+        return self._modules.keys()
+
+    def values(self):
+        return self._modules.values()
+
+    def items(self):
+        return self._modules.items()
+
+    def update(self, modules: Dict[str, nn.Module]) -> None:
+        """Update _modules with new modules."""
+        for key, module in modules.items():
+            self.add_module(key, module)
+
+
+class FeatureDictNet(FeatureBase):
     """ Feature extractor with OrderedDict return
 
     Wrap a model and extract features as specified by the out indices, the network is
@@ -264,7 +316,6 @@ def __init__(
         self.feature_info = _get_feature_info(model, out_indices)
         self.output_fmt = Format(output_fmt)
         self.concat = feature_concat
-        self.grad_checkpointing = False
         self.return_layers = {}
 
         return_layers = _get_return_layers(self.feature_info, out_map)
@@ -283,9 +334,6 @@ def __init__(
             f'Return layers ({remaining}) are not present in model'
         self.update(layers)
 
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
-
     def _collect(self, x) -> (Dict[str, torch.Tensor]):
         out = OrderedDict()
         for i, (name, module) in enumerate(self.items()):
@@ -345,7 +393,7 @@ def forward(self, x) -> (List[torch.Tensor]):
         return list(self._collect(x).values())
 
 
-class FeatureHookNet(nn.ModuleDict):
+class FeatureHookNet(FeatureBase):
     """ FeatureHookNet
 
     Wrap a model and extract features specified by the out indices using forward/forward-pre hooks.
@@ -386,7 +434,6 @@ def __init__(
         self.feature_info = _get_feature_info(model, out_indices)
         self.return_dict = return_dict
         self.output_fmt = Format(output_fmt)
-        self.grad_checkpointing = False
         if no_rewrite is None:
             no_rewrite = not flatten_sequential
         layers = OrderedDict()
@@ -415,9 +462,6 @@ def __init__(
         self.update(layers)
         self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map)
 
-    def set_grad_checkpointing(self, enable: bool = True):
-        self.grad_checkpointing = enable
-
     def forward(self, x):
         for i, (name, module) in enumerate(self.items()):
             if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -432,7 +476,7 @@ def forward(self, x):
         return out if self.return_dict else list(out.values())
 
 
-class FeatureGetterNet(nn.ModuleDict):
+class FeatureGetterNet(FeatureBase):
     """ FeatureGetterNet
 
     Wrap models with a feature getter method, like 'get_intermediate_layers'
@@ -472,6 +516,10 @@ def __init__(
         self.output_fmt = Format(output_fmt)
         self.norm = norm
 
+    def set_grad_checkpointing(self, enable: bool = True):
+        self.grad_checkpointing = enable
+        self.model.set_grad_checkpointing(enable=enable)
+
     def forward(self, x):
         features = self.model.forward_intermediates(
             x,
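
Minimal usage sketch (not part of the patch) of what the change is intended to enable:
passing a features-only timm wrapper to FSDP2's fully_shard, which rejects
nn.ModuleDict/nn.ModuleList containers via isinstance checks. Assumptions: the
fully_shard import path is the PyTorch >= 2.6 location (earlier 2.x releases expose it
from torch.distributed._composable.fsdp), 'resnet50' is just an example model, and the
process group / device mesh setup needed to actually run this (e.g. under torchrun) is
omitted.

import timm
from torch.distributed.fsdp import fully_shard  # assumed import path, PyTorch >= 2.6


def shard_feature_extractor():
    # features_only=True returns a FeatureListNet; with this patch it derives from
    # FeatureBase (a plain nn.Module) rather than nn.ModuleDict, so fully_shard's
    # isinstance check for (ModuleDict, ModuleList) no longer gets in the way.
    model = timm.create_model('resnet50', features_only=True)

    # Shard each wrapped stage that has parameters, then the wrapper itself. The
    # dict-like access (.items(), .keys(), indexing) is supplied by FeatureBase.
    for _, stage in model.items():
        if any(p.numel() for p in stage.parameters()):
            fully_shard(stage)
    fully_shard(model)
    return model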