GraphNeuralNetworks.jl provides common graph convolutional layers with which you can assemble arbitrarily deep or complex models. GNN layers are compatible with
Flux.jl ones, so expert Flux users should be immediately able to define and train
their models.

In what follows, we discuss two different styles for model creation:
the *explicit modeling* style, more verbose but more flexible,
and the *implicit modeling* style based on [`GNNChain`](@ref), more concise but less flexible.

## Explicit modeling

In the explicit modeling style, the model is created according to the following steps:

1. Define a new type for your model (`GNN` in the example below). Layers and submodels are fields.
2. Apply `Flux.@functor` to the new type to make it Flux-compatible (parameter collection, GPU movement, etc.).
3. Optionally define a convenience constructor for your model.
4. Define the forward pass by implementing the function call method for your type.
5. Instantiate the model.

Here is an example of this construction:
```julia
using Flux, LightGraphs, GraphNeuralNetworks
using Flux: @functor

struct GNN                                # step 1
    conv1
    bn
    conv2
    dropout
    dense
end

@functor GNN                              # step 2

function GNN(din::Int, d::Int, dout::Int) # step 3
    GNN(GCNConv(din => d),
        BatchNorm(d),
        GraphConv(d => d, relu),
        Dropout(0.5),
        Dense(d, dout))
end

function (model::GNN)(g::GNNGraph, x)     # step 4
    x = model.conv1(g, x)
    x = relu.(model.bn(x))
    x = model.conv2(g, x)
    x = model.dropout(x)
    x = model.dense(x)
    return x
end

din, d, dout = 3, 4, 2
g = GNNGraph(random_regular_graph(10, 4))
X = randn(Float32, din, 10)
model = GNN(din, d, dout)                 # step 5
y = model(g, X)
```
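Since `GNN` is an ordinary Flux model, it can be trained with the standard Flux workflow. The following is a minimal sketch, not part of the library's documented API: the target `ytrue` and the mean-squared-error loss are made-up placeholders for illustration.

```julia
using Flux: ADAM

ytrue = randn(Float32, dout, 10)                  # hypothetical regression target
loss(g, x, y) = Flux.Losses.mse(model(g, x), y)   # illustrative loss choice

ps  = Flux.params(model)                          # parameter collection enabled by @functor
opt = ADAM(1f-3)

for epoch in 1:100
    gs = Flux.gradient(() -> loss(g, X, ytrue), ps)
    Flux.Optimise.update!(opt, ps, gs)
end
```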

## Implicit modeling with GNNChains

While very flexible, the explicit model definition of the previous section is a bit verbose.
In order to simplify things, we provide the [`GNNChain`](@ref) type. It is very similar
to Flux's well-known `Chain`. It allows composing layers in a sequential fashion as `Chain`
does, propagating the output of each layer to the next one. In addition, `GNNChain`
propagates the input graph as well, providing it as the first argument
to layers subtyping the [`GNNLayer`](@ref) abstract type.

Using `GNNChain`, the previous example becomes
```julia
using Flux, LightGraphs, GraphNeuralNetworks

din, d, dout = 3, 4, 2
g = GNNGraph(random_regular_graph(10, 4))
X = randn(Float32, din, 10)

model = GNNChain(GCNConv(din => d),
                 BatchNorm(d),
                 x -> relu.(x),
                 GraphConv(d => d, relu),
                 Dropout(0.5),
                 Dense(d, dout))

y = model(g, X)
```

The `GNNChain` only propagates the graph and the node features. More complex scenarios, e.g. when edge features are also updated, have to be handled using the explicit definition of the forward pass, as sketched below.
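
For instance, here is a minimal sketch of such an explicit forward pass. The `EdgeGNN` type and its `(g, x, e)` calling convention are illustrative assumptions, not an API from the library: node features are updated by a graph convolution, while edge features are transformed by a plain `Dense` layer.

```julia
# Hypothetical model updating both node features x and edge features e.
struct EdgeGNN
    nodeconv    # a graph layer, e.g. GCNConv
    edgedense   # a plain Flux layer acting on edge features
end

@functor EdgeGNN

function (m::EdgeGNN)(g::GNNGraph, x, e)
    x = m.nodeconv(g, x)   # update node features using the graph
    e = m.edgedense(e)     # update edge features
    return x, e
end

model = EdgeGNN(GCNConv(3 => 4), Dense(2, 2))
E = randn(Float32, 2, g.num_edges)   # assumes edge features stored as a (din_e, num_edges) matrix
x2, e2 = model(g, X, E)
```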