Skip to content

Commit c9a3dc9

Browse files
committed
Add dwee model def, and weights for dwee and dpwee
1 parent ab2da45 commit c9a3dc9

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

timm/models/vision_transformer.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2703,11 +2703,14 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
27032703
'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
27042704
hf_hub_id='timm/',
27052705
input_size=(3, 256, 256), crop_pct=0.95),
2706+
'vit_dwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
2707+
hf_hub_id='timm/',
2708+
input_size=(3, 256, 256), crop_pct=0.95),
27062709
'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
27072710
hf_hub_id='timm/',
27082711
input_size=(3, 256, 256), crop_pct=0.95),
27092712
'vit_dpwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
2710-
#hf_hub_id='timm/',
2713+
hf_hub_id='timm/',
27112714
input_size=(3, 256, 256), crop_pct=0.95),
27122715
'vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k': _cfg(
27132716
hf_hub_id='timm/',
@@ -4208,6 +4211,17 @@ def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTr
42084211
return model
42094212

42104213

4214+
@register_model
4215+
def vit_dwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
4216+
model_args = dict(
4217+
patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
4218+
class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', attn_layer='diff',
4219+
)
4220+
model = _create_vision_transformer(
4221+
'vit_dwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
4222+
return model
4223+
4224+
42114225
@register_model
42124226
def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
42134227
model_args = dict(
@@ -4229,6 +4243,7 @@ def vit_dpwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> Vision
42294243
'vit_dpwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
42304244
return model
42314245

4246+
42324247
@register_model
42334248
def vit_little_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
42344249
model_args = dict(

0 commit comments

Comments
 (0)