@@ -12,29 +12,31 @@ abstract type GNNLayer end
1212 GNNChain(name = layer, ...)
1313
1414Collects multiple layers / functions to be called in sequence
15- on a given input. Supports indexing and slicing, `m[2]` or `m[1:end-1]`,
16- and if names are given, `m[:name] == m[1]` etc.
17-
18- ## Examples
15+ on given input graph and input node features.
1916
20- ```
21- julia> m = GNNChain(x -> x^2, x -> x+1);
17+ It allows to compose layers in a sequential fashion as `Flux.Chain`
18+ does, propagating the output of each layer to the next one.
19+ In addition, `GNNChain` handles the input graph as well, providing it
20+ as a first argument only to layers subtyping the [`GNNLayer`](@ref) abstract type.
2221
23- julia> m(5) == 26
24- true
22+ `GNNChain` supports indexing and slicing, `m[2]` or `m[1:end-1]`,
23+ and if names are given, `m[:name] == m[1]` etc.
2524
26- julia> m = GNNChain(Dense(10, 5, tanh), Dense(5, 2));
25+ # Examples
2726
28- julia> x = rand(10, 32);
27+ ```juliarepl
28+ julia> m = GNNChain(GCNConv(2=>5), BatchNorm(5), x -> relu.(x), Dense(5, 4));
2929
30- julia> m(x) == m[2](m[1](x))
31- true
30+ julia> x = randn(Float32, 2, 3);
3231
33- julia> m2 = GNNChain(enc = GNNChain(Flux.flatten, Dense(10, 5, tanh)),
34- dec = Dense(5, 2));
32+ julia> g = GNNGraph([1,1,2,3], [2,3,1,1]);
3533
36- julia> m2(x) == (m2[:dec] ∘ m2[:enc])(x)
37- true
34+ julia> m(g, x)
35+ 4×3 Matrix{Float32}:
36+ 0.157941 0.15443 0.193471
37+ 0.0819516 0.0503105 0.122523
38+ 0.225933 0.267901 0.241878
39+ -0.0134364 -0.0120716 -0.0172505
3840```
3941"""
4042struct GNNChain{T}
0 commit comments