Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ Note: <sup>*</sup> indicates multi-scale testing.

(`Note please report accuracy numbers and provide trained models in your new repository to facilitate others to get sense of correctness and model behavior`)

[04/06/2022] Swin Transformer for Audio Classification: [Hierarchical Token Semantic Audio Transformer](https://github.com/RetroCirce/HTS-Audio-Transformer).

[12/21/2021] Swin Transformer for StyleGAN: [StyleSwin](https://github.com/microsoft/StyleSwin)

[12/13/2021] Swin Transformer for Face Recognition: [FaceX-Zoo](https://github.com/JDAI-CV/FaceX-Zoo)
Expand Down
46 changes: 28 additions & 18 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ def load_pretrained(config, model, logger):
checkpoint = torch.load(config.MODEL.PRETRAINED, map_location='cpu')
state_dict = checkpoint['model']


# for MoBY, preplace prefixes
if sorted(list(state_dict.keys()))[0].startswith('encoder'):
state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}

# delete relative_position_index since we always re-init it
relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
for k in relative_position_index_keys:
Expand Down Expand Up @@ -105,24 +110,29 @@ def load_pretrained(config, model, logger):
state_dict[k] = absolute_pos_embed_pretrained_resized

# check classifier, if not match, then re-init classifier to zero
head_bias_pretrained = state_dict['head.bias']
Nc1 = head_bias_pretrained.shape[0]
Nc2 = model.head.bias.shape[0]
if (Nc1 != Nc2):
if Nc1 == 21841 and Nc2 == 1000:
logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
map22kto1k_path = f'data/map22kto1k.txt'
with open(map22kto1k_path) as f:
map22kto1k = f.readlines()
map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
else:
torch.nn.init.constant_(model.head.bias, 0.)
torch.nn.init.constant_(model.head.weight, 0.)
del state_dict['head.weight']
del state_dict['head.bias']
logger.warning(f"Error in loading classifier head, re-init classifier head to 0")
if ('head.bias' in state_dict):
head_bias_pretrained = state_dict['head.bias']
Nc1 = head_bias_pretrained.shape[0]
Nc2 = model.head.bias.shape[0]
if (Nc1 != Nc2):
if Nc1 == 21841 and Nc2 == 1000:
logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
map22kto1k_path = f'data/map22kto1k.txt'
with open(map22kto1k_path) as f:
map22kto1k = f.readlines()
map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
else:
torch.nn.init.constant_(model.head.bias, 0.)
torch.nn.init.constant_(model.head.weight, 0.)
del state_dict['head.weight']
del state_dict['head.bias']
logger.warning(f"Error in loading classifier head, re-init classifier head to 0")
else:
torch.nn.init.constant_(model.head.bias, 0.)
torch.nn.init.constant_(model.head.weight, 0.)
logger.warning(f"Error in loading classifier head, re-init classifier head to 0")

msg = model.load_state_dict(state_dict, strict=False)
logger.warning(msg)
Expand Down