Description
How do I use Swin Transformer to predict the categories of images in my own folder? When I load my checkpoint, I get a size mismatch error:
  File "D:\face\Swin-Transformer-main\test.py", line 67, in
    model = load_model(CHECKPOINT_PATH)
  File "D:\face\Swin-Transformer-main\test.py", line 61, in load_model
    model.load_state_dict(state_dict, strict=False)
  File "C:\ProgramData\Anaconda3\envs\pytorch\lib\site-packages\torch\nn\modules\module.py", line 1671, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for SwinTransformer:
    size mismatch for layers.1.downsample.norm.weight: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for layers.1.downsample.norm.bias: copying a param with shape torch.Size([768]) from checkpoint, the shape in current model is torch.Size([384]).
    size mismatch for layers.1.downsample.reduction.weight: copying a param with shape torch.Size([384, 768]) from checkpoint, the shape in current model is torch.Size([192, 384]).
    size mismatch for layers.2.downsample.norm.weight: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for layers.2.downsample.norm.bias: copying a param with shape torch.Size([1536]) from checkpoint, the shape in current model is torch.Size([768]).
    size mismatch for layers.2.downsample.reduction.weight: copying a param with shape torch.Size([768, 1536]) from checkpoint, the shape in current model is torch.Size([384, 768]).
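The shapes in the error point to a layout difference rather than a wrong model size: what the checkpoint stores as layers.1.downsample has exactly the widths the model expects at layers.2.downsample. A likely explanation (an assumption; check it against your installed timm version) is that this repo's SwinTransformer attaches PatchMerging at the end of each stage, while newer timm releases attach it at the start of the next stage. Note also that strict=False only tolerates missing or unexpected keys; it does not suppress shape mismatches, which is why load_state_dict still raises. The short sketch below uses a hypothetical function, patch_merging_shapes, written only for this illustration, to recompute the PatchMerging shapes for Swin-T (embed_dim=96) under both assumed layouts:

```python
def patch_merging_shapes(embed_dim=96, num_stages=4, merge_at_stage_start=False):
    """Hypothetical helper: {stage: (norm_shape, reduction_weight_shape)} for PatchMerging.

    PatchMerging concatenates 2x2 neighbouring tokens (4 * C channels), applies a
    LayerNorm over 4 * C, then a Linear layer 4 * C -> 2 * C.
    """
    shapes = {}
    for i in range(num_stages):
        if merge_at_stage_start:
            # Assumed newer timm layout: stage i (i >= 1) merges the output of stage i - 1.
            if i == 0:
                continue
            in_dim = embed_dim * 2 ** (i - 1)
        else:
            # Assumed original-repo layout: stage i merges its own output (no merge after the last stage).
            if i == num_stages - 1:
                continue
            in_dim = embed_dim * 2 ** i
        # nn.Linear stores its weight as (out_features, in_features).
        shapes[i] = ((4 * in_dim,), (2 * in_dim, 4 * in_dim))
    return shapes


ckpt_layout = patch_merging_shapes(merge_at_stage_start=False)
timm_layout = patch_merging_shapes(merge_at_stage_start=True)

for i in (1, 2):
    print(f"layers.{i}.downsample: checkpoint {ckpt_layout[i]} vs model {timm_layout[i]}")
```

For stages 1 and 2 this reproduces the shape pairs reported in the traceback, and every checkpoint stage i lines up with model stage i + 1.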
Here is my code.
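import timm
import torch


def load_model(checkpoint_path):
    # Build Swin-T through timm; the explicit keyword arguments repeat
    # the swin_tiny_patch4_window7_224 configuration, with a 7-class head.
    model = timm.create_model(
        'swin_tiny_patch4_window7_224',
        pretrained=False,
        num_classes=7,
        img_size=224,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=7,
        drop_path_rate=0.2,
        mlp_ratio=4.0,
        qkv_bias=True,
        patch_norm=True
    )
    # Checkpoints saved by this repo's training script keep the weights under 'model'.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint.get('model', checkpoint)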
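Two ways out seem possible. The simplest is probably to build the model with this repo's own SwinTransformer class (models/swin_transformer.py), whose parameter names should match checkpoints trained with this repo. If you want to keep timm, the keys need renaming before loading; recent timm versions appear to ship such a conversion as checkpoint_filter_fn in timm.models.swin_transformer, which is worth trying first if your version has it. The sketch below shows the idea with a hypothetical helper, remap_swin_keys, under the layout assumptions stated in its docstring. It also renames head.* to head.fc.*, because with strict=False a mismatched head name would be silently skipped and leave the 7-class classifier randomly initialized.

```python
import re


def remap_swin_keys(state_dict):
    """Hypothetical helper: convert original-repo Swin keys to the newer timm layout.

    Assumptions (check against your timm version):
      * the checkpoint stores PatchMerging at the END of stage i as layers.i.downsample,
        while the timm model stores it at the START of stage i + 1 as layers.{i+1}.downsample;
      * the timm classifier is named head.fc instead of head.
    """
    remapped = {}
    for key, value in state_dict.items():
        # Index/mask buffers are rebuilt when the model is constructed, so skip them.
        if "relative_position_index" in key or "attn_mask" in key:
            continue
        # Shift every downsample block one stage later.
        key = re.sub(
            r"layers\.(\d+)\.downsample",
            lambda m: f"layers.{int(m.group(1)) + 1}.downsample",
            key,
        )
        # Rename the classifier unless it already uses the head.fc naming.
        if key.startswith("head.") and not key.startswith("head.fc."):
            key = "head.fc." + key[len("head."):]
        remapped[key] = value
    return remapped


# Usage sketch (assumes torch, a recent timm, and a checkpoint trained with this repo):
# model = timm.create_model("swin_tiny_patch4_window7_224", pretrained=False, num_classes=7)
# checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
# state_dict = remap_swin_keys(checkpoint.get("model", checkpoint))
# result = model.load_state_dict(state_dict, strict=False)
# print(result.missing_keys, result.unexpected_keys)  # inspect: ideally both are empty
```

Printing the missing and unexpected keys after loading is a cheap way to confirm the mapping, since strict=False will not complain about names that fail to line up.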