diff --git a/scripts/CondaPkg.toml b/scripts/CondaPkg.toml
index 5b167a017..47da7bc4a 100644
--- a/scripts/CondaPkg.toml
+++ b/scripts/CondaPkg.toml
@@ -1,5 +1,5 @@
-channels = ["nvidia", "torch"]
+channels = ["pytorch"]
 
 [deps]
-pytorch = ""
-torchvision = ""
+pytorch = ">=2,<3"
+torchvision = ">=0.15"
diff --git a/scripts/Project.toml b/scripts/Project.toml
index b372d4497..c5803ce83 100644
--- a/scripts/Project.toml
+++ b/scripts/Project.toml
@@ -1,6 +1,8 @@
 [deps]
 BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
+Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
 PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d"
diff --git a/scripts/port_torchvision.jl b/scripts/port_torchvision.jl
index 92e2f0b9c..044689fd8 100644
--- a/scripts/port_torchvision.jl
+++ b/scripts/port_torchvision.jl
@@ -7,31 +7,31 @@ const tvmodels = pyimport("torchvision.models")
 
 # name, weight, jlconstructor, pyconstructor
 model_list = [
-    ("vgg11", "IMAGENET1K_V1", () -> VGG(11), weights -> tvmodels.vgg11(weights=weights)),
-    ("vgg13", "IMAGENET1K_V1", () -> VGG(13), weights -> tvmodels.vgg13(weights=weights)),
-    ("vgg16", "IMAGENET1K_V1", () -> VGG(16), weights -> tvmodels.vgg16(weights=weights)),
-    ("vgg19", "IMAGENET1K_V1", () -> VGG(19), weights -> tvmodels.vgg19(weights=weights)),
-    ("resnet18", "IMAGENET1K_V1", () -> ResNet(18), weights -> tvmodels.resnet18(weights=weights)),
-    ("resnet34", "IMAGENET1K_V1", () -> ResNet(34), weights -> tvmodels.resnet34(weights=weights)),
-    ("resnet50", "IMAGENET1K_V1", () -> ResNet(50), weights -> tvmodels.resnet50(weights=weights)),
-    ("resnet101", "IMAGENET1K_V1", () -> ResNet(101), weights -> tvmodels.resnet101(weights=weights)),
-    ("resnet152", "IMAGENET1K_V1", () -> ResNet(152), weights -> tvmodels.resnet152(weights=weights)),
-    ("resnet50", "IMAGENET1K_V2", () -> ResNet(50), weights -> tvmodels.resnet50(weights=weights)),
-    ("resnet101", "IMAGENET1K_V2", () -> ResNet(101), weights -> tvmodels.resnet101(weights=weights)),
-    ("resnet152", "IMAGENET1K_V2", () -> ResNet(152), weights -> tvmodels.resnet152(weights=weights)),
-    ("resnext50_32x4d", "IMAGENET1K_V1", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(weights=weights)),
-    ("resnext50_32x4d", "IMAGENET1K_V2", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(weights=weights)),
-    ("resnext101_32x8d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(weights=weights)),
-    ("resnext101_64x4d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=64, base_width=4), weights -> tvmodels.resnext101_64x4d(weights=weights)),
-    ("resnext101_32x8d", "IMAGENET1K_V2", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(weights=weights)),
-    ("wide_resnet50_2", "IMAGENET1K_V1", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(weights=weights)),
-    ("wide_resnet50_2", "IMAGENET1K_V2", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(weights=weights)),
-    ("wide_resnet101_2", "IMAGENET1K_V1", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(weights=weights)),
-    ("wide_resnet101_2", "IMAGENET1K_V2", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(weights=weights)),
-
+    ("vgg11", "IMAGENET1K_V1", () -> VGG(11), weights -> tvmodels.vgg11(; weights)),
("vgg13", "IMAGENET1K_V1", () -> VGG(13), weights -> tvmodels.vgg13(; weights)), + ("vgg16", "IMAGENET1K_V1", () -> VGG(16), weights -> tvmodels.vgg16(; weights)), + ("vgg19", "IMAGENET1K_V1", () -> VGG(19), weights -> tvmodels.vgg19(; weights)), + ("resnet18", "IMAGENET1K_V1", () -> ResNet(18), weights -> tvmodels.resnet18(; weights)), + ("resnet34", "IMAGENET1K_V1", () -> ResNet(34), weights -> tvmodels.resnet34(; weights)), + ("resnet50", "IMAGENET1K_V1", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)), + ("resnet101", "IMAGENET1K_V1", () -> ResNet(101), weights -> tvmodels.resnet101(; weights)), + ("resnet152", "IMAGENET1K_V1", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)), + ("resnet50", "IMAGENET1K_V2", () -> ResNet(50), weights -> tvmodels.resnet50(; weights)), + ("resnet101", "IMAGENET1K_V2", () -> ResNet(101), weights -> tvmodels.resnet101(; weights)), + ("resnet152", "IMAGENET1K_V2", () -> ResNet(152), weights -> tvmodels.resnet152(; weights)), + ("resnext50_32x4d", "IMAGENET1K_V1", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(; weights)), + ("resnext50_32x4d", "IMAGENET1K_V2", () -> ResNeXt(50; cardinality=32, base_width=4), weights -> tvmodels.resnext50_32x4d(; weights)), + ("resnext101_32x8d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)), + ("resnext101_64x4d", "IMAGENET1K_V1", () -> ResNeXt(101; cardinality=64, base_width=4), weights -> tvmodels.resnext101_64x4d(; weights)), + ("resnext101_32x8d", "IMAGENET1K_V2", () -> ResNeXt(101; cardinality=32, base_width=8), weights -> tvmodels.resnext101_32x8d(; weights)), + ("wide_resnet50_2", "IMAGENET1K_V1", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)), + ("wide_resnet50_2", "IMAGENET1K_V2", () -> WideResNet(50), weights -> tvmodels.wide_resnet50_2(; weights)), + ("wide_resnet101_2", "IMAGENET1K_V1", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; ; weights)), + ("wide_resnet101_2", "IMAGENET1K_V2", () -> WideResNet(101), weights -> tvmodels.wide_resnet101_2(; weights)), + ("vit_b_16", "IMAGENET1K_V1", () -> ViT(:base, imsize=(224,224), qkv_bias=true), weights -> tvmodels.vit_b_16(; weights)), ## NOT MATCHING BELOW - # ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(weights=weights)), - # ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(weights=weights)), + # ("squeezenet1_0", "IMAGENET1K_V1", () -> SqueezeNet(), weights -> tvmodels.squeezenet1_0(; weights)), + # ("densenet121", "IMAGENET1K_V1", () -> DenseNet(121), weights -> tvmodels.densenet121(; weights)), ] @@ -44,4 +44,3 @@ for (name, weights, jlconstructor, pyconstructor) in model_list BSON.@save joinpath(@__DIR__, "$(name)_$weights.bson") model=jlmodel println("Saved $(name)_$weights.bson") end - diff --git a/scripts/pytorch2flux.jl b/scripts/pytorch2flux.jl index 44aa55336..cad47e6e6 100644 --- a/scripts/pytorch2flux.jl +++ b/scripts/pytorch2flux.jl @@ -75,12 +75,41 @@ function _list_state(node::Dense, channel, prefix) end end +function _list_state(node::Metalhead.Layers.ClassTokens, channel, prefix) + put!(channel, (prefix * ".classtoken", node.token)) +end + +function _list_state(node::Metalhead.Layers.ViPosEmbedding, channel, prefix) + put!(channel, (prefix * ".posembedding", node.vectors)) +end + +function _list_state(node::LayerNorm, channel, prefix) + put!(channel, (prefix * ".layernorm_scale", node.diag.scale)) + 
put!(channel, (prefix * ".layernorm_bias", node.diag.bias)) +end + +function _list_state(node::Metalhead.Layers.LayerNormV2, channel, prefix) + put!(channel, (prefix * ".layernorm_scale", node.diag.scale)) + put!(channel, (prefix * ".layernorm_bias", node.diag.bias)) +end + +function _list_state(node::Metalhead.Layers.MultiHeadSelfAttention, channel, prefix) + _list_state(node.qkv_layer, channel, prefix * ".qkv") + _list_state(node.projection, channel, prefix * ".proj") +end + function _list_state(node::Chain, channel, prefix) for (i, n) in enumerate(node.layers) _list_state(n, channel, prefix * ".layers[$i]") end end +function _list_state(node::SkipConnection, channel, prefix) + for (i, n) in enumerate(node.layers) + _list_state(n, channel, prefix * ".layers[$i]") + end +end + function _list_state(node::Parallel, channel, prefix) # reverse to match PyTorch order, see https://github.com/FluxML/Metalhead.jl/issues/228 for (i, n) in enumerate(reverse(node.layers)) @@ -102,6 +131,18 @@ function pytorch2flux!(jlmodel, pymodel; verb=false) state_dict = pymodel.state_dict() pystate = OrderedDict((py2jl(k), th2jl(v)) for (k, v) in state_dict.items() if !occursin("num_batches_tracked", py2jl(k))) + + jlkeys = collect(keys(jlstate)) + pykeys = collect(keys(pystate)) + + ## handle class_token since it is not in the same order + jl_k = findfirst(k -> occursin("classtoken", k), jlkeys) + py_k = findfirst(k -> occursin("class_token", k), pykeys) + if jl_k !== nothing && py_k !== nothing + jlstate[jlkeys[jl_k]] .= pystate[pykeys[py_k]] + delete!(pystate, pykeys[py_k]) + delete!(jlstate, jlkeys[jl_k]) + end for ((flux_key, flux_param), (pytorch_key, pytorch_param)) in zip(jlstate, pystate) # @show flux_key size(flux_param) pytorch_key size(pytorch_param) diff --git a/scripts/utils.jl b/scripts/utils.jl index 75b9ad08a..78d0ba1fc 100644 --- a/scripts/utils.jl +++ b/scripts/utils.jl @@ -25,9 +25,14 @@ function np2jl(x::Py) end function th2jl(x::Py) - x_jl = pyconvert(Array, x) + x_jl = pyconvert(Array, x.detach().numpy()) x_jl = permutedims(x_jl, ndims(x_jl):-1:1) return x_jl end py2jl(x::Py) = pyconvert(Any, x) + + +## SAVE STATE +using Functors +state_arrays(x) = fmapstructure(x -> x isa AbstractArray ? x : missing, x) diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index d395ccbc9..7fb19dab8 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -16,7 +16,7 @@ import Flux.testmode! include("../utilities.jl") include("attention.jl") -export MHAttention +export MultiHeadSelfAttention include("conv.jl") export conv_norm, basic_conv_bn, dwsep_conv_norm diff --git a/src/layers/attention.jl b/src/layers/attention.jl index b8560f12f..4c6787d93 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,5 +1,5 @@ """ - MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, + MultiHeadSelfAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_dropout_prob = 0., proj_dropout_prob = 0.) Multi-head self-attention layer. @@ -12,39 +12,27 @@ Multi-head self-attention layer. 
 - `attn_dropout_prob`: dropout probability after the self-attention layer
 - `proj_dropout_prob`: dropout probability after the projection layer
 """
-struct MHAttention{P, Q, R}
+struct MultiHeadSelfAttention{P, Q, R}
     nheads::Int
     qkv_layer::P
     attn_drop::Q
     projection::R
 end
 
-@functor MHAttention
+@functor MultiHeadSelfAttention
 
-function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
+function MultiHeadSelfAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
                      attn_dropout_prob = 0.0, proj_dropout_prob = 0.0)
     @assert planes % nheads==0 "planes should be divisible by nheads"
     qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
     attn_drop = Dropout(attn_dropout_prob)
     proj = Chain(Dense(planes, planes), Dropout(proj_dropout_prob))
-    return MHAttention(nheads, qkv_layer, attn_drop, proj)
+    return MultiHeadSelfAttention(nheads, qkv_layer, attn_drop, proj)
 end
 
-function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
-    nfeatures, seq_len, batch_size = size(x)
-    x_reshaped = reshape(x, nfeatures, seq_len * batch_size)
-    qkv = m.qkv_layer(x_reshaped)
-    qkv_reshaped = reshape(qkv, nfeatures ÷ m.nheads, m.nheads, seq_len, 3 * batch_size)
-    query, key, value = chunk(qkv_reshaped, 3; dims = 4)
-    scale = convert(T, sqrt(size(query, 1) / m.nheads))
-    key_reshaped = reshape(permutedims(key, (2, 1, 3, 4)), m.nheads, nfeatures ÷ m.nheads,
-                           seq_len * batch_size)
-    query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
-                             m.nheads, seq_len * batch_size)
-    attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
-    value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
-                             m.nheads, seq_len * batch_size)
-    pre_projection = reshape(batched_mul(attention, value_reshaped),
-                             (nfeatures, seq_len, batch_size))
-    y = m.projection(reshape(pre_projection, size(pre_projection, 1), :))
-    return reshape(y, :, seq_len, batch_size)
+function (m::MultiHeadSelfAttention)(x::AbstractArray{<:Number, 3})
+    qkv = m.qkv_layer(x)
+    q, k, v = chunk(qkv, 3, dims = 1)
+    y, α = NNlib.dot_product_attention(q, k, v; m.nheads, fdrop = m.attn_drop)
+    y = m.projection(y)
+    return y
 end
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index bb83f042d..25b3af374 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -24,3 +24,45 @@ function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-6)
 end
 
 (m::ChannelLayerNorm)(x) = m.diag(Flux.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ))
+
+
+"""
+    LayerNormV2(size..., λ=identity; affine=true, eps=1f-5)
+
+Same as Flux's LayerNorm, but `eps` is added before taking the square root in the
+denominator. Therefore, LayerNormV2 matches PyTorch's LayerNorm.
+"""
+struct LayerNormV2{F,D,T,N}
+    λ::F
+    diag::D
+    ϵ::T
+    size::NTuple{N,Int}
+    affine::Bool
+end
+
+function LayerNormV2(size::Tuple{Vararg{Int}}, λ=identity; affine::Bool=true, eps::Real=1f-5)
+    diag = affine ? Flux.Scale(size..., λ) : λ != identity ? Base.Fix1(broadcast, λ) : identity
+    return LayerNormV2(λ, diag, eps, size, affine)
+end
+LayerNormV2(size::Integer...; kw...) = LayerNormV2(Int.(size); kw...)
+LayerNormV2(size_act...; kw...) = LayerNormV2(Int.(size_act[1:end-1]), size_act[end]; kw...)
+
+@functor LayerNormV2
+
+function (a::LayerNormV2)(x::AbstractArray)
+    eps = convert(float(eltype(x)), a.ϵ)  # avoids promotion for Float16 data, but should ε change too?
+    a.diag(_normalise(x; dims=1:length(a.size), eps))
+end
+
+function Base.show(io::IO, l::LayerNormV2)
+    print(io, "LayerNormV2(", join(l.size, ", "))
+    l.λ === identity || print(io, ", ", l.λ)
+    Flux.hasaffine(l) || print(io, ", affine=false")
+    print(io, ")")
+end
+
+@inline function _normalise(x::AbstractArray; dims=ndims(x), eps=Flux.ofeltype(x, 1e-5))
+    μ = mean(x, dims=dims)
+    σ² = var(x, dims=dims, mean=μ, corrected=false)
+    return @. (x - μ) / sqrt(σ² + eps)
+end
diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl
index 02c57941f..8d9342120 100644
--- a/src/vit-based/vit.jl
+++ b/src/vit-based/vit.jl
@@ -13,9 +13,10 @@ Transformer as used in the base ViT architecture.
   - `dropout_prob`: dropout probability
 """
 function transformer_encoder(planes::Integer, depth::Integer, nheads::Integer;
-                             mlp_ratio = 4.0, dropout_prob = 0.0)
+                             mlp_ratio = 4.0, dropout_prob = 0.0, qkv_bias = false)
     layers = [Chain(SkipConnection(prenorm(planes,
-                                           MHAttention(planes, nheads;
+                                           MultiHeadSelfAttention(planes, nheads;
+                                                       qkv_bias,
                                                        attn_dropout_prob = dropout_prob,
                                                        proj_dropout_prob = dropout_prob)),
                                    +),
@@ -51,7 +52,8 @@ Creates a Vision Transformer (ViT) model.
 function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3,
              patch_size::Dims{2} = (16, 16), embedplanes::Integer = 768,
              depth::Integer = 6, nheads::Integer = 16, mlp_ratio = 4.0, dropout_prob = 0.1,
-             emb_dropout_prob = 0.1, pool::Symbol = :class, nclasses::Integer = 1000)
+             emb_dropout_prob = 0.1, pool::Symbol = :class, nclasses::Integer = 1000,
+             qkv_bias = false)
     @assert pool in [:class, :mean]
     "Pool type must be either `:class` (class token) or `:mean` (mean pooling)"
     npatches = prod(imsize .÷ patch_size)
@@ -60,9 +62,9 @@ function vit(imsize::Dims{2} = (256, 256); inchannels::Integer = 3,
                           ViPosEmbedding(embedplanes, npatches + 1),
                           Dropout(emb_dropout_prob),
                           transformer_encoder(embedplanes, depth, nheads; mlp_ratio,
-                                              dropout_prob),
+                                              dropout_prob, qkv_bias),
                           pool === :class ? x -> x[:, 1, :] : seconddimmean),
-                  Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
+                  Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses)))
 end
 
 const VIT_CONFIGS = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3),
@@ -100,9 +102,10 @@ end
 @functor ViT
 
 function ViT(config::Symbol; imsize::Dims{2} = (256, 256), patch_size::Dims{2} = (16, 16),
-             pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000)
+             pretrain::Bool = false, inchannels::Integer = 3, nclasses::Integer = 1000,
+             qkv_bias = false)
     _checkconfig(config, keys(VIT_CONFIGS))
-    layers = vit(imsize; inchannels, patch_size, nclasses, VIT_CONFIGS[config]...)
+    layers = vit(imsize; inchannels, patch_size, nclasses, qkv_bias, VIT_CONFIGS[config]...)
     if pretrain
         loadpretrain!(layers, string("vit", config))
     end
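
Reviewer notes (commentary, not part of the patch).

For one model_list entry the port is: build the Flux model, build the torchvision model with pretrained weights, copy the state across, save to BSON. A minimal sketch of that round trip for the new vit_b_16 entry, assuming the unchanged driver loop in port_torchvision.jl keeps this construct/copy/save shape and that pytorch2flux.jl sits next to the calling script:

using PythonCall, Metalhead, BSON
const tvmodels = pyimport("torchvision.models")
include("pytorch2flux.jl")   # defines pytorch2flux!

jlmodel = ViT(:base; imsize = (224, 224), qkv_bias = true)   # Julia-side constructor from model_list
pymodel = tvmodels.vit_b_16(; weights = "IMAGENET1K_V1")     # Python-side constructor
pytorch2flux!(jlmodel, pymodel)                              # copies the PyTorch state in place
BSON.@save "vit_b_16_IMAGENET1K_V1.bson" model = jlmodel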
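
The new MultiHeadSelfAttention forward delegates the reshape/softmax/batched-multiply sequence that the old MHAttention spelled out by hand to NNlib.dot_product_attention (available in NNlib >= 0.8.12). A single-head sanity check of that equivalence, written against the documented (features, seq_len, batch) convention of dot_product_attention; this is a sketch, not a test shipped in this PR:

using Flux, NNlib
import Metalhead

planes, seq_len, batch = 8, 5, 2
mha = Metalhead.Layers.MultiHeadSelfAttention(planes, 1)   # one head keeps the math readable
x = randn(Float32, planes, seq_len, batch)

# Manual attention: α = softmax(kᵀq / √d), y = projection(v * α).
q, k, v = Flux.chunk(mha.qkv_layer(x), 3, dims = 1)
α = softmax(batched_mul(batched_transpose(k), q) ./ sqrt(Float32(planes)); dims = 1)
y_manual = mha.projection(batched_mul(v, α))
@assert y_manual ≈ mha(x)   # dropout is inactive outside a gradient context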
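
Why LayerNormV2 is needed at all: Flux's LayerNorm adds eps to the standard deviation after the square root, while PyTorch adds it to the variance inside it, so a weight-for-weight port through plain LayerNorm can never reproduce torchvision outputs exactly. A quick illustration of the gap (sketch only, nothing in the patch depends on it):

using Flux
import Metalhead

x = randn(Float32, 768, 10)
ln_flux = Flux.LayerNorm(768)               # (x - μ) / (σ + ε)
ln_v2 = Metalhead.Layers.LayerNormV2(768)   # (x - μ) / √(σ² + ε), as in PyTorch

# Both layers start from identity affine parameters, so any gap is the ε placement:
maximum(abs, ln_flux(x) .- ln_v2(x))   # small but nonzero, on the order of ε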