6 changes: 3 additions & 3 deletions scripts/CondaPkg.toml
@@ -1,5 +1,5 @@
channels = ["nvidia", "torch"]
channels = ["pytorch"]

[deps]
pytorch = ""
torchvision = ""
pytorch = ">=2,<3"
torchvision = ">=0.15"
2 changes: 2 additions & 0 deletions scripts/Project.toml
@@ -1,6 +1,8 @@
[deps]
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
49 changes: 24 additions & 25 deletions scripts/port_torchvision.jl
@@ -7,31 +7,31 @@ const tvmodels = pyimport("torchvision.models")

# name, weight, jlconstructor, pyconstructor
model_list = [
("vgg11", "IMAGENET1K_V1", () -> VGG(11), weights -> tvmodels.vgg11(weights=weights)),
("vgg13", "IMAGENET1K_V1", () -> VGG(13), weights -> tvmodels.vgg13(weights=weights)),
("vgg16", "IMAGENET1K_V1", () -> VGG(16), weights -> tvmodels.vgg16(weights=weights)),
("vgg19", "IMAGENET1K_V1", () -> VGG(19), weights -> tvmodels.vgg19(weights=weights)),
("resnet18", "IMAGENET1K_V1", () -> ResNet(18), weights -> tvmodels.resnet18(weights=weights)),
("resnet34", "IMAGENET1K_V1", () -> ResNet(34), weights -> tvmodels.resnet34(weights=weights)),
("resnet50", "IMAGENET1K_V1", () -> ResNet(50), weights -> tvmodels.resnet50(weights=weights)),
("resnet101", "IMAGENET1K_V1", () -> ResNet(101), weights -> tvmodels.resnet101(weights=weights)),
("resnet152", "IMAGENET1K_V1", () -> ResNet(152), weights -> tvmodels.resnet152(weights=weights)),
("resnet50", "IMAGENET1K_V2", () -> ResNet(50), weights -> tvmodels.resnet50(weights=weights)),
("resnet101", "IMAGENET1K_V2", () -> ResNet(101), weights -> tvmodels.resnet101(weights=weights)),
("resnet152", "IMAGENET1K_V2", () -> ResNet(152), weights -> tvmodels.resnet152(weights=weights)),
("resnext50_32x4d", "IMAGENET1K_V1", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(weights=weights)),
("resnext50_32x4d", "IMAGENET1K_V2", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(weights=weights)),
("resnext101_32x8d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(weights=weights)),
("resnext101_64x4d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=64, base_width=4), weights -> tvmodels.resnext101_64x4d(weights=weights)),
("resnext101_32x8d", "IMAGENET1K_V2", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(weights=weights)),
("wide_resnet50_2", "IMAGENET1K_V1", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(weights=weights)),
("wide_resnet50_2", "IMAGENET1K_V2", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(weights=weights)),
("wide_resnet101_2", "IMAGENET1K_V1", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(weights=weights)),
("wide_resnet101_2", "IMAGENET1K_V2", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(weights=weights)),

("vgg11", "IMAGENET1K_V1", () -> VGG(11), weights -> tvmodels.vgg11(; weights)),
("vgg13", "IMAGENET1K_V1", () -> VGG(13), weights -> tvmodels.vgg13(; weights)),
("vgg16", "IMAGENET1K_V1", () -> VGG(16), weights -> tvmodels.vgg16(; weights)),
("vgg19", "IMAGENET1K_V1", () -> VGG(19), weights -> tvmodels.vgg19(; weights)),
("resnet18", "IMAGENET1K_V1", () -> ResNet(18), weights -> tvmodels.resnet18(; weights)),
("resnet34", "IMAGENET1K_V1", () -> ResNet(34), weights -> tvmodels.resnet34(; weights)),
("resnet50", "IMAGENET1K_V1", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)),
("resnet101", "IMAGENET1K_V1", () -> ResNet(101), weights -> tvmodels.resnet101(; weights)),
("resnet152", "IMAGENET1K_V1", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)),
("resnet50", "IMAGENET1K_V2", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)),
("resnet101", "IMAGENET1K_V2", () -> ResNet(101), weights -> tvmodels.resnet101(; weights)),
("resnet152", "IMAGENET1K_V2", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)),
("resnext50_32x4d", "IMAGENET1K_V1", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(; weights)),
("resnext50_32x4d", "IMAGENET1K_V2", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(; weights)),
("resnext101_32x8d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)),
("resnext101_64x4d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=64, base_width=4), weights -> tvmodels.resnext101_64x4d(; weights)),
("resnext101_32x8d", "IMAGENET1K_V2", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)),
("wide_resnet50_2", "IMAGENET1K_V1", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)),
("wide_resnet50_2", "IMAGENET1K_V2", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)),
("wide_resnet101_2", "IMAGENET1K_V1", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; ; weights)),
("wide_resnet101_2", "IMAGENET1K_V2", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; weights)),
("vit_b_16", "IMAGENET1K_V1", () -> ViT(:base, imsize=(224,224), qkv_bias=true), weights -> tvmodels.vit_b_16(; weights)),
## NOT MATCHING BELOW
# ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(weights=weights)),
# ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(weights=weights)),
# ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(; weights)),
# ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(; weights)),
]


@@ -44,4 +44,3 @@ for (name, weights, jlconstructor, pyconstructor) in model_list
BSON.@save joinpath(@__DIR__, "$(name)_$weights.bson") model=jlmodel
println("Saved $(name)_$weights.bson")
end
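
The collapsed middle of this loop is where both models get built and the weights copied across before the save above; a minimal sketch of what one iteration presumably does (the local variable names are illustrative, only pytorch2flux! and the tuple fields come from this PR):

# Sketch of a single porting iteration; the collapsed code may differ in detail.
jlmodel = jlconstructor()          # fresh Flux/Metalhead model
pymodel = pyconstructor(weights)   # torchvision model carrying the requested pretrained weights
pytorch2flux!(jlmodel, pymodel)    # copy the PyTorch parameters into the Flux model in place
# ... followed by the BSON.@save shown above.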

41 changes: 41 additions & 0 deletions scripts/pytorch2flux.jl
@@ -75,12 +75,41 @@ function _list_state(node::Dense, channel, prefix)
end
end

function _list_state(node::Metalhead.Layers.ClassTokens, channel, prefix)
put!(channel, (prefix * ".classtoken", node.token))
end

function _list_state(node::Metalhead.Layers.ViPosEmbedding, channel, prefix)
put!(channel, (prefix * ".posembedding", node.vectors))
end

function _list_state(node::LayerNorm, channel, prefix)
put!(channel, (prefix * ".layernorm_scale", node.diag.scale))
put!(channel, (prefix * ".layernorm_bias", node.diag.bias))
end

function _list_state(node::Metalhead.Layers.LayerNormV2, channel, prefix)
put!(channel, (prefix * ".layernorm_scale", node.diag.scale))
put!(channel, (prefix * ".layernorm_bias", node.diag.bias))
end

function _list_state(node::Metalhead.Layers.MultiHeadSelfAttention, channel, prefix)
_list_state(node.qkv_layer, channel, prefix * ".qkv")
_list_state(node.projection, channel, prefix * ".proj")
end

function _list_state(node::Chain, channel, prefix)
for (i, n) in enumerate(node.layers)
_list_state(n, channel, prefix * ".layers[$i]")
end
end

function _list_state(node::SkipConnection, channel, prefix)
for (i, n) in enumerate(node.layers)
_list_state(n, channel, prefix * ".layers[$i]")
end
end

function _list_state(node::Parallel, channel, prefix)
# reverse to match PyTorch order, see https://github.com/FluxML/Metalhead.jl/issues/228
for (i, n) in enumerate(reverse(node.layers))
@@ -102,6 +131,18 @@ function pytorch2flux!(jlmodel, pymodel; verb=false)
state_dict = pymodel.state_dict()
pystate = OrderedDict((py2jl(k), th2jl(v)) for (k, v) in state_dict.items() if
!occursin("num_batches_tracked", py2jl(k)))

jlkeys = collect(keys(jlstate))
pykeys = collect(keys(pystate))

## handle class_token since it is not in the same order
jl_k = findfirst(k -> occursin("classtoken", k), jlkeys)
py_k = findfirst(k -> occursin("class_token", k), pykeys)
if jl_k !== nothing && py_k !== nothing
jlstate[jlkeys[jl_k]] .= pystate[pykeys[py_k]]
delete!(pystate, pykeys[py_k])
delete!(jlstate, jlkeys[jl_k])
end

for ((flux_key, flux_param), (pytorch_key, pytorch_param)) in zip(jlstate, pystate)
# @show flux_key size(flux_param) pytorch_key size(pytorch_param)
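For context, the _list_state methods added above each emit (key, parameter) pairs into a channel, and the unchanged part of this file presumably drains that channel into the jlstate ordered dict that is zipped against pystate in the loop above. A rough sketch of such driving code (an assumption for illustration, not the file's actual implementation):

# Sketch: walk the Flux model and collect its parameters in definition order.
function collect_state(model; prefix = "model")
    ch = Channel() do c
        _list_state(model, c, prefix)
    end
    return OrderedDict(k => v for (k, v) in ch)
end
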
7 changes: 6 additions & 1 deletion scripts/utils.jl
@@ -25,9 +25,14 @@ function np2jl(x::Py)
end

function th2jl(x::Py)
x_jl = pyconvert(Array, x)
x_jl = pyconvert(Array, x.detach().numpy())
x_jl = permutedims(x_jl, ndims(x_jl):-1:1)
return x_jl
end
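
The detach().numpy() conversion plus the dimension reversal accounts for PyTorch tensors being row-major while Julia arrays are column-major. A small sketch of the effect (assuming torch is imported through PythonCall):

# Sketch: a PyTorch conv weight of shape (out, in, kH, kW) arrives in Julia
# as (kW, kH, in, out), i.e. spatial dims first, then input and output channels.
torch = pyimport("torch")
t = torch.randn(64, 3, 7, 7)
w = th2jl(t)
size(w)   # (7, 7, 3, 64)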

py2jl(x::Py) = pyconvert(Any, x)


## SAVE STATE
using Functors
state_arrays(x) = fmapstructure(x -> x isa AbstractArray ? x : missing, x)
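
state_arrays maps a model to a plain nested structure via Functors, keeping only the array leaves and replacing everything else with missing. A quick sketch of the intended use on a toy model (assumes Flux is loaded):

# Sketch: keep only the weight/bias arrays of a small model's structure.
m = Chain(Dense(2, 3, relu), Dense(3, 1))
st = state_arrays(m)   # nested tuples of arrays, with activation functions replaced by missing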
2 changes: 1 addition & 1 deletion src/layers/Layers.jl
@@ -16,7 +16,7 @@ import Flux.testmode!
include("../utilities.jl")

include("attention.jl")
export MHAttention
export MultiHeadSelfAttention

include("conv.jl")
export conv_norm, basic_conv_bn, dwsep_conv_norm
34 changes: 11 additions & 23 deletions src/layers/attention.jl
@@ -1,5 +1,5 @@
"""
MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
MultiHeadSelfAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
Member Author: made the name more informative

attn_dropout_prob = 0., proj_dropout_prob = 0.)

Multi-head self-attention layer.
@@ -12,39 +12,27 @@ Multi-head self-attention layer.
- `attn_dropout_prob`: dropout probability after the self-attention layer
- `proj_dropout_prob`: dropout probability after the projection layer
"""
struct MHAttention{P, Q, R}
struct MultiHeadSelfAttention{P, Q, R}
nheads::Int
qkv_layer::P
attn_drop::Q
projection::R
end
@functor MHAttention
@functor MultiHeadSelfAttention

function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
function MultiHeadSelfAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
attn_dropout_prob = 0.0, proj_dropout_prob = 0.0)
@assert planes % nheads==0 "planes should be divisible by nheads"
qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
attn_drop = Dropout(attn_dropout_prob)
proj = Chain(Dense(planes, planes), Dropout(proj_dropout_prob))
return MHAttention(nheads, qkv_layer, attn_drop, proj)
return MultiHeadSelfAttention(nheads, qkv_layer, attn_drop, proj)
end

function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
nfeatures, seq_len, batch_size = size(x)
x_reshaped = reshape(x, nfeatures, seq_len * batch_size)
qkv = m.qkv_layer(x_reshaped)
qkv_reshaped = reshape(qkv, nfeatures ÷ m.nheads, m.nheads, seq_len, 3 * batch_size)
query, key, value = chunk(qkv_reshaped, 3; dims = 4)
scale = convert(T, sqrt(size(query, 1) / m.nheads))
key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads,
seq_len * batch_size)
query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
m.nheads, seq_len * batch_size)
attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
m.nheads, seq_len * batch_size)
pre_projection = reshape(batched_mul(attention, value_reshaped),
(nfeatures, seq_len, batch_size))
y = m.projection(reshape(pre_projection, size(pre_projection, 1), :))
return reshape(y, :, seq_len, batch_size)
function (m::MultiHeadSelfAttention)(x::AbstractArray{<:Number, 3})
qkv = m.qkv_layer(x)
q, k, v = chunk(qkv, 3, dims = 1)
y, α = NNlib.dot_product_attention(q, k, v; m.nheads, fdrop = m.attn_drop)
y = m.projection(y)
return y
end
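
A quick usage sketch of the rewritten forward pass (shapes follow the features × sequence × batch convention that NNlib.dot_product_attention expects):

# Sketch: 64 feature planes, 8 heads, a length-10 sequence, batch of 2.
mha = MultiHeadSelfAttention(64, 8)
x = randn(Float32, 64, 10, 2)
y = mha(x)   # size(y) == (64, 10, 2)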
42 changes: 42 additions & 0 deletions src/layers/normalise.jl
@@ -24,3 +24,45 @@ function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6)
end

(m::ChannelLayerNorm)(x) = m.diag(Flux.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ))


"""
LayerNormV2(size..., λ=identity; affine=true, eps=1f-5)

Same as Flux's LayerNorm but eps is added before taking the square root in the denominator.
Therefore, LayerNormV2 matches pytorch's LayerNorm.
"""
struct LayerNormV2{F,D,T,N}
λ::F
diag::D
ϵ::T
size::NTuple{N,Int}
affine::Bool
end

function LayerNormV2(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, eps::Real=1f-5)
diag = affine ? Flux.Scale(size..., λ) : λ!=identity ? Base.Fix1(broadcast, λ) : identity
return LayerNormV2(λ, diag, eps, size, affine)
end
LayerNormV2(size::Integer...; kw...) = LayerNormV2(Int.(size); kw...)
LayerNormV2(size_act...; kw...) = LayerNormV2(Int.(size_act[1:end-1]), size_act[end]; kw...)

@functor LayerNormV2

function (a::LayerNormV2)(x::AbstractArray)
eps = convert(float(eltype(x)), a.ϵ) # avoids promotion for Float16 data, but should ε change too?
a.diag(_normalise(x; dims=1:length(a.size), eps))
end

function Base.show(io::IO, l::LayerNormV2)
print(io, "LayerNormV2(", join(l.size, ", "))
l.λ === identity || print(io, ", ", l.λ)
Flux.hasaffine(l) || print(io, ", affine=false")
print(io, ")")
end

@inline function _normalise(x::AbstractArray; dims=ndims(x), eps=Flux.ofeltype(x, 1e-5))
μ = mean(x, dims=dims)
σ² = var(x, dims=dims, mean=μ, corrected=false)
return @. (x - μ) / sqrt(σ² + eps)
end
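
To make the docstring's point concrete: Flux's LayerNorm divides by (σ .+ ε) whereas LayerNormV2 divides by sqrt.(σ² .+ ε), matching torch.nn.LayerNorm, so for the same eps the two differ by a small amount. A sketch:

# Sketch: same feature size and eps, slightly different outputs.
x = randn(Float32, 8, 4)        # 8 features, batch of 4
ln = LayerNorm(8)               # Flux: (x .- μ) ./ (σ .+ ε)
lnv2 = LayerNormV2(8)           # here: (x .- μ) ./ sqrt.(σ² .+ ε)
maximum(abs, ln(x) .- lnv2(x))  # small but nonzero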
17 changes: 10 additions & 7 deletions src/vit-based/vit.jl
@@ -13,9 +13,10 @@ Transformer as used in the base ViT architecture.
- `dropout_prob`: dropout probability
"""
function transformer_encoder(planes::Integer, depth::Integer, nheads::Integer;
mlp_ratio = 4.0, dropout_prob = 0.0)
mlp_ratio = 4.0, dropout_prob = 0.0, qkv_bias=false)
layers = [Chain(SkipConnection(prenorm(planes,
MHAttention(planes, nheads;
MultiHeadSelfAttention(planes, nheads;
qkv_bias,
attn_dropout_prob = dropout_prob,
proj_dropout_prob = dropout_prob)),
+),
@@ -51,7 +52,8 @@ Creates a Vision Transformer (ViT) model.
function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3,
patch_size::Dims{2} = (16, 16), embedplanes::Integer = 768,
depth::Integer = 6, nheads::Integer = 16, mlp_ratio = 4.0, dropout_prob = 0.1,
emb_dropout_prob = 0.1, pool::Symbol = :class, nclasses::Integer = 1000)
emb_dropout_prob = 0.1, pool::Symbol = :class, nclasses::Integer = 1000,
qkv_bias = false)
@assert pool in [:class, :mean]
"Pool type must be either `:class` (class token) or `:mean` (mean pooling)"
npatches = prod(imsize .÷ patch_size)
@@ -60,9 +62,9 @@ function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3,
ViPosEmbedding(embedplanes, npatches + 1),
Dropout(emb_dropout_prob),
transformer_encoder(embedplanes, depth, nheads; mlp_ratio,
dropout_prob),
dropout_prob, qkv_bias),
pool === :class ? x -> x[:, 1, :] : seconddimmean),
Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses)))
Member Author: this final tanh had no reason to exist

end

const VIT_CONFIGS = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3),
Expand Down Expand Up @@ -100,9 +102,10 @@ end
@functor ViT

function ViT(config::Symbol; imsize::Dims{2} = (256, 256), patch_size::Dims{2} = (16, 16),
pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000,
qkv_bias=false)
Member: Unless it is typical to adjust this toggle, I think it should not get exposed going from vit to ViT. The logic within the codebase has been to make the uppercase exports as simple as possible.

Member Author (@CarloLucibello, May 5, 2023): I had to add it since the default for torchvision is true, while here it is false. The torchvision model is given by

ViT(:base, imsize=(224,224), qkv_bias=true)

I think we should change the defaults here to match that before the tag of the breaking release, but this can be done in another PR.

Member: Okay, so change the default to true and remove the keyword? I assume you almost always want it as true.

Member Author: yes, I'll do it in the next PR

_checkconfig(config, keys(VIT_CONFIGS))
layers = vit(imsize; inchannels, patch_size, nclasses, VIT_CONFIGS[config]...)
layers = vit(imsize; inchannels, patch_size, nclasses, qkv_bias, VIT_CONFIGS[config]...)
if pretrain
loadpretrain!(layers, string("vit", config))
end