Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Update TNT-(S/B) model weights and add feature extraction support #2480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

brianhou0208
Copy link
Contributor

@brianhou0208 brianhou0208 commented May 3, 2025

The main updates are as follows:

  1. Update the weight TNT-(S/B) implemented with the official PyTorch
  2. The original implementation is different from the official one (TNT Block & PixelEmbed), and the legacy parameter is used to maintain backward compatibility.
  3. Support features_only parameters and forward_intermediates function

Example

model = timm.create_model(f'tnt_s_patch16_224', pretrained=True)
model = timm.create_model(f'tnt_b_patch16_224', pretrained=True)

model = timm.create_model(f'tnt_s_patch16_224', pretrained=False, legacy=True)
ckpt = torch.load('/path/to/original_weight.pth', map_location='cpu')
model.load_state_dict(ckpt)

Example(forward_intermediates)

model = timm.create_model(f'tnt_s_patch16_224', pretrained=True).eval()
output, intermediates = model.forward_intermediates(torch.randn(2,3,224,224))
for i, o in enumerate(intermediates):
    print(f'Feat index: {i}, shape: {o.shape}')
Feat index: 0, shape: torch.Size([2, 384, 14, 14])
Feat index: 1, shape: torch.Size([2, 384, 14, 14])
Feat index: 2, shape: torch.Size([2, 384, 14, 14])
Feat index: 3, shape: torch.Size([2, 384, 14, 14])
Feat index: 4, shape: torch.Size([2, 384, 14, 14])
Feat index: 5, shape: torch.Size([2, 384, 14, 14])
Feat index: 6, shape: torch.Size([2, 384, 14, 14])
Feat index: 7, shape: torch.Size([2, 384, 14, 14])
Feat index: 8, shape: torch.Size([2, 384, 14, 14])
Feat index: 9, shape: torch.Size([2, 384, 14, 14])
Feat index: 10, shape: torch.Size([2, 384, 14, 14])
Feat index: 11, shape: torch.Size([2, 384, 14, 14])

Example(features_only)

model = timm.create_model('tnt_s_patch16_224', features_only=True, pretrained=True)
print(f'Feature channels: {model.feature_info.channels()}')
print(f'Feature reduction: {model.feature_info.reduction()}')
output = model(torch.randn(2, 3, 224, 224))
for x in output:
    print(x.shape)
Feature channels: [384, 384, 384]
Feature reduction: [16, 16, 16]
torch.Size([2, 384, 14, 14])

Result

Model FLOPs MACs Params ACC@1 ACC@5 ckpt
tnt_b_patch16_224 28.11G 14.06G 65.43M 82.872 96.224 link
##original 65.41M no weight link
tnt_s_patch16_224 10.44G 5.22G 23.77M 81.526 95.760 link
##original 23.76M 81.528 95.734 link
test code
from typing import Any, Dict, Union, List
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import timm
from timm.utils.metrics import AverageMeter, accuracy

device = torch.device('mps')
torch.mps.empty_cache()

def auto_unit(x: float, unit: str = '') -> str:
    if x >= 1e9:
        return f"{x / 1e9:.2f}G {unit}"
    elif x >= 1e6:
        return f"{x / 1e6:.2f}M {unit}"
    elif x >= 1e3:
        return f"{x / 1e3:.2f}K {unit}"
    else:
        return f"{x:.2f} {unit}"
 
 
 ef get_model_info(model: torch.nn.Module, imgsz: Union[int, List[int]] = 224) -> Dict[str, str]:
    """
    Compute model FLOPs, MACs, and Params using torch profiler.

    Args:
        model (nn.Module): The model to calculate for.
        imgsz (int | List[int], optional): Input image size. Defaults to 224.

    Returns:
        dict: Dictionary containing FLOPs, MACs, and Params with auto units.
    """
    p = next(model.parameters())
    if not isinstance(imgsz, list):
        imgsz = [imgsz, imgsz]

    im = torch.empty((1, 3, *imgsz), device=p.device)

    with torch.profiler.profile(with_flops=True) as prof:
        model(im)

    flops = sum(e.flops for e in prof.key_averages())
    macs = flops / 2
    params = sum(p.numel() for p in model.parameters())

    return {
        "FLOPs": auto_unit(flops, ""),
        "MACs": auto_unit(macs, ""),
        "Params": auto_unit(params, ""),
    }


def get_model_acc(model: torch.nn.Module):
    cfg: Dict[str, Any]= model.default_cfg
    _, height, width = cfg['input_size'] if 'test_input_size' not in cfg else cfg['test_input_size']
    imgsz = height if height == width else (height, width)

    interp_mode = {
        "nearest": 0,
        "bilinear": 2,
        "bicubic": 3,
    }

    val_dataset = datasets.ImageFolder(
        '/Users/ryanhou/Downloads/imagenet/val',
        transforms.Compose([
            transforms.Resize(int(imgsz / cfg['crop_pct']), interpolation=interp_mode[cfg['interpolation']]),
            transforms.CenterCrop(imgsz),
            transforms.ToTensor(),
            transforms.Normalize(cfg['mean'], cfg['std'])])
    )
    val_loader = DataLoader(
        val_dataset, batch_size=64, shuffle=False, pin_memory=False, prefetch_factor=4, num_workers=4,
        persistent_workers=True#, pin_memory_device='mps'
    )

    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    model.to(device)
    torch.mps.synchronize()
    with torch.no_grad():
        for images, target in tqdm(val_loader):
            images = images.to(device)
            target = target.to(device)
            output = model(images)
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            top1.update(acc1, images.size(0))
            top5.update(acc5, images.size(0))
    torch.mps.synchronize()
    return {"ACC@1": round(top1.avg.item(), 4), "ACC@5": round(top5.avg.item(), 4)}
 
 
 if __name__ == "__main__":
    model = timm.create_model(f'tnt_s_patch16_224', pretrained=True).eval()
    result = get_model_acc(model)
    print(result)
    model = timm.create_model(f'tnt_b_patch16_224', pretrained=True).eval()
    result = get_model_acc(model)
    print(result)
>>{'ACC@1': 81.526, 'ACC@5': 95.76}
>>{'ACC@1': 82.872, 'ACC@5': 96.224}

Reference

official PyTorch implement: https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/tnt_pytorch

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

if intermediates_only:
return intermediates

patch_embed = self.norm(patch_embed)
Copy link
Contributor Author

@brianhou0208 brianhou0208 May 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @rwightman,

I'm not sure if my understanding is correct — but I think we may need a last_idx check in forward_intermediates() to align with the behavior in forward_features().

In particular, if the final output should go through the norm layer, perhaps it should be handled like this:

  • FastViT

if feat_idx == last_idx:
x = self.final_conv(x)

  • EfficientFormer

last_idx = self.num_stages - 1
if torch.jit.is_scripting() or not stop_early: # can't slice blocks in torchscript
stages = self.stages
else:
stages = self.stages[:max_index + 1]
feat_idx = 0
for feat_idx, stage in enumerate(stages):
x = stage(x)
if feat_idx < last_idx:
B, C, H, W = x.shape
if feat_idx in take_indices:
if feat_idx == last_idx:
x_inter = self.norm(x) if norm else x
intermediates.append(x_inter.reshape(B, H // 2, W // 2, -1).permute(0, 3, 1, 2))
else:
intermediates.append(x)
if intermediates_only:
return intermediates
if feat_idx == last_idx:
x = self.norm(x)

Otherwise, in cases where stop_early=True or feat_idx != last_idx, we would skip the norm layer which I assume is the intended behavior for intermediate outputs.

last_idx = len(self.blocks) - 1

for i, blk in enumerate(blocks):
      pixel_embed, patch_embed = blk(pixel_embed, patch_embed)
      if i in take_indices:
          # normalize intermediates with final norm layer if enabled
          intermediates.append(self.norm(patch_embed) if norm and i == last_idx else patch_embed)
 
if intermediates_only:
      return intermediates

if i == last_idx:
      patch_embed = self.norm(patch_embed)

return patch_embed, intermediates

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@brianhou0208 yeah, there are some compromises for models that aren't 'uniform' like ViT where it only makes sense for a final norm/conv layer to apply to the very last output and would break if you tried to force it on intermediate output. That's why some of the modles have those checks on last_idx to ensure if the stop early is used or the model was truncated that those last layers aren't applied.

For vit models, applying a final norm or projection to a block output that isn't the last actually works reasonably and gives you useable results (accuracy falls off obviously as you prune further back), so if the sizing matches that's why the final norm is applied when possible.

Other sanity check, with
feat = model.forward_features(inpt)
feat2, intermediates = model.forward_intermediates(inpt, intermediates_only=False)

torch.allclose(feat, feat2) should hold

and either one can be passed through classifier_out = model.forward_head(feat)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I should probably add that check to the forward_intermediates unit test...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants