fix ViT model output + rewrite attention layer + adapt torchvision script #230
Changes from all commits: 256423e, d9a1334, ed97207, 4ea2813, ff6c1be, 35316ec
```diff
@@ -1,5 +1,5 @@
-channels = ["nvidia", "torch"]
+channels = ["pytorch"]
 
 [deps]
-pytorch = ""
-torchvision = ""
+pytorch = ">=2,<3"
+torchvision = ">=0.15"
```
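The dependency file above looks like a CondaPkg.toml, which pins the Python packages that the torchvision porting script pulls in through CondaPkg/PythonCall. As a hedged sketch (the script structure and variable names here are assumptions, not the PR's actual code), the porting script might load a reference torchvision ViT like this:

```julia
using CondaPkg, PythonCall  # CondaPkg resolves the pinned pytorch/torchvision above

# Load torchvision's pretrained ViT-B/16 as the reference model.
tv = pyimport("torchvision")
pymodel = tv.models.vit_b_16(weights = "IMAGENET1K_V1")

# Collect the PyTorch state dict as Julia arrays, keyed by parameter name.
# Note: dimension order may still need permuting for Flux's conventions.
state = Dict(pyconvert(String, k) => pyconvert(Array, v.detach().numpy())
             for (k, v) in pymodel.state_dict().items())
```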
```diff
@@ -13,9 +13,10 @@ Transformer as used in the base ViT architecture.
   - `dropout_prob`: dropout probability
 """
 function transformer_encoder(planes::Integer, depth::Integer, nheads::Integer;
-                             mlp_ratio = 4.0, dropout_prob = 0.0)
+                             mlp_ratio = 4.0, dropout_prob = 0.0, qkv_bias = false)
     layers = [Chain(SkipConnection(prenorm(planes,
-                                           MHAttention(planes, nheads;
+                                           MultiHeadSelfAttention(planes, nheads;
+                                                                  qkv_bias,
                                            attn_dropout_prob = dropout_prob,
                                            proj_dropout_prob = dropout_prob)),
                                    +),
```
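The new `qkv_bias` keyword threads from the user-facing constructor down to the attention layer. As a minimal sketch of what it controls (this is not Metalhead's actual `MultiHeadSelfAttention` internals, just the usual convention): the fused query/key/value projection either carries a learnable bias vector or none at all.

```julia
using Flux

# Illustrative fused Q/K/V projection; `toy_qkv_projection` is a made-up name.
function toy_qkv_projection(planes::Integer; qkv_bias = false)
    # One Dense maps each token to its stacked queries, keys, and values.
    return Dense(planes, 3 * planes; bias = qkv_bias)
end

qkv_plain = toy_qkv_projection(768)                       # no bias (this PR's default)
qkv_with_bias = toy_qkv_projection(768; qkv_bias = true)  # matches torchvision's layout
@assert qkv_with_bias.bias isa AbstractVector
```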
```diff
@@ -51,7 +52,8 @@ Creates a Vision Transformer (ViT) model.
 function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3,
              patch_size::Dims{2} = (16, 16), embedplanes::Integer = 768,
              depth::Integer = 6, nheads::Integer = 16, mlp_ratio = 4.0, dropout_prob = 0.1,
-             emb_dropout_prob = 0.1, pool::Symbol = :class, nclasses::Integer = 1000)
+             emb_dropout_prob = 0.1, pool::Symbol = :class, nclasses::Integer = 1000,
+             qkv_bias = false)
     @assert pool in [:class, :mean]
     "Pool type must be either `:class` (class token) or `:mean` (mean pooling)"
     npatches = prod(imsize .÷ patch_size)
```
```diff
@@ -60,9 +62,9 @@ function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3,
                        ViPosEmbedding(embedplanes, npatches + 1),
                        Dropout(emb_dropout_prob),
                        transformer_encoder(embedplanes, depth, nheads; mlp_ratio,
-                                           dropout_prob),
+                                           dropout_prob, qkv_bias),
                        pool === :class ? x -> x[:, 1, :] : seconddimmean),
-                 Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
+                 Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses)))
 end
```

**Member (Author)** (on the removed activation): this final `tanh_fast` […]
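Removing `tanh_fast` is the "fix ViT model output" part of the PR title: the classifier head now emits unbounded logits instead of values squashed into (-1, 1). A small sketch of the difference, using hypothetical `head_old`/`head_new` names:

```julia
using Flux

embedplanes, nclasses = 768, 1000
head_old = Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))
head_new = Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses))

x = randn(Float32, embedplanes, 4)   # four pooled token embeddings
@assert all(abs.(head_old(x)) .< 1)  # tanh_fast confines every logit to (-1, 1)
extrema(head_new(x))                 # raw logits, suitable for softmax / logitcrossentropy
```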
Unchanged context between the hunks:

```diff
 const VIT_CONFIGS = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3),
```

```diff
@@ -100,9 +102,10 @@ end
 @functor ViT
 
 function ViT(config::Symbol; imsize::Dims{2} = (256, 256), patch_size::Dims{2} = (16, 16),
-             pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
+             pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000,
+             qkv_bias = false)
```
**Member:** Unless it is typical to adjust this toggle, I think it should not get exposed going from […]

**Member (Author):** I had to add it since the default for torchvision is `true`. I think we should change the defaults here to match that before the tag of the breaking release, but this can be done in another PR.

**Member:** Okay, so change the default to `true`?

**Member (Author):** yes, I'll do it in the next PR
The hunk continues below the thread:

```diff
     _checkconfig(config, keys(VIT_CONFIGS))
-    layers = vit(imsize; inchannels, patch_size, nclasses, VIT_CONFIGS[config]...)
+    layers = vit(imsize; inchannels, patch_size, nclasses, qkv_bias, VIT_CONFIGS[config]...)
     if pretrain
         loadpretrain!(layers, string("vit", config))
     end
```
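Until the default flips, loading torchvision-ported weights presumably requires passing the keyword explicitly. A hedged usage sketch, assuming the API as it stands in this PR (`:tiny` is the one config visible in the diff):

```julia
using Metalhead

# Build a ViT whose attention layers carry Q/K/V biases, matching torchvision.
model = ViT(:tiny; qkv_bias = true)

# Defaults from the diff: 256×256 RGB input, 16×16 patches, 1000 classes.
x = rand(Float32, 256, 256, 3, 1)
size(model(x))  # expected (1000, 1)
```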
**Member (Author)** (on the `MHAttention` → `MultiHeadSelfAttention` rename): made the name more informative