# An example of graph classification

using Flux
- using Flux: @functor, dropout, onecold, onehotbatch
+ using Flux: @functor, dropout, onecold, onehotbatch, getindex
using Flux.Losses: logitbinarycrossentropy
+ using Flux.Data: DataLoader
using GraphNeuralNetworks
using MLDatasets: TUDataset
using Statistics, Random
using CUDA
CUDA.allowscalar(false)

- function eval_loss_accuracy(model, g, X, y)
-     ŷ = model(g, X) |> vec
-     l = logitbinarycrossentropy(ŷ, y)
-     acc = mean((2 .* ŷ .- 1) .* (2 .* y .- 1) .> 0)
-     return (loss = round(l, digits=4), acc = round(acc*100, digits=2))
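+ # Accumulate loss and accuracy over the loader's batches, weighting each batch by its size.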
+ function eval_loss_accuracy(model, data_loader, device)
+     loss = 0.
+     acc = 0.
+     ntot = 0
+     for (g, X, y) in data_loader
+         g, X, y = g |> device, X |> device, y |> device
+         n = length(y)
+         ŷ = model(g, X) |> vec
+         loss += logitbinarycrossentropy(ŷ, y) * n
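+         # a prediction counts as correct when (2ŷ - 1) and (2y - 1) have the same sign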
+         acc += mean((2 .* ŷ .- 1) .* (2 .* y .- 1) .> 0) * n
+         ntot += n
+     end
+     return (loss = round(loss/ntot, digits=4), acc = round(acc*100/ntot, digits=2))
end

struct GNNData
@@ -22,6 +31,16 @@ struct GNNData
    y
end

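+ # Indexing a GNNData selects the corresponding graphs, their node features, and their labels.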
+ Base.getindex(data::GNNData, i::Int) = getindex(data, [i])
+
+ function Base.getindex(data::GNNData, i::AbstractVector)
+     sg, nodemap = subgraph(data.g, i)
+     return (sg, data.X[:, nodemap], data.y[i])
+ end
+
+ # Flux's DataLoader compatibility.
+ Flux.Data._nobs(data::GNNData) = data.g.num_graphs
+ Flux.Data._getobs(data::GNNData, i) = data[i]
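+ # `_nobs` reports the number of graphs in the dataset, `_getobs` delegates to `getindex` above.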

function getdataset(idxs)
    data = TUDataset("MUTAG")[idxs]
@@ -37,9 +56,10 @@
# arguments for the `train` function
Base.@kwdef mutable struct Args
    η = 1f-3             # learning rate
-     epochs = 1000        # number of epochs
+     batchsize = 64       # batch size (number of graphs in each batch)
+     epochs = 200         # number of epochs
    seed = 17            # set seed > 0 for reproducibility
-     use_cuda = false     # if true use cuda (if available)
+     usecuda = true       # if true use cuda (if available)
    nhidden = 128        # dimension of hidden features
    infotime = 10        # report every `infotime` epochs
end
@@ -48,7 +68,7 @@ function train(; kws...)
    args = Args(; kws...)
    args.seed > 0 && Random.seed!(args.seed)

-     if args.use_cuda && CUDA.functional()
+     if args.usecuda && CUDA.functional()
        device = gpu
        args.seed > 0 && CUDA.seed!(args.seed)
        @info "Training on GPU"
@@ -61,12 +81,15 @@ function train(; kws...)

    permindx = randperm(188)
    ntrain = 150
-     gtrain, Xtrain, ytrain = getdataset(permindx[1:ntrain])
-     gtest, Xtest, ytest = getdataset(permindx[ntrain+1:end])
+     dtrain = getdataset(permindx[1:ntrain])
+     dtest = getdataset(permindx[ntrain+1:end])
+
+     train_loader = DataLoader(dtrain, batchsize=args.batchsize, shuffle=true)
+     test_loader = DataLoader(dtest, batchsize=args.batchsize, shuffle=false)
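+     # Each batch yielded by the loaders is a (graph, features, labels) tuple produced by `_getobs`.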

    # DEFINE MODEL

-     nin = size(Xtrain, 1)
+     nin = size(dtrain.X, 1)
    nhidden = args.nhidden

    model = GNNChain(GCNConv(nin => nhidden, relu),
@@ -82,22 +105,23 @@ function train(; kws...)
    # LOGGING FUNCTION

    function report(epoch)
-         train = eval_loss_accuracy(model, gtrain, Xtrain, ytrain)
-         test = eval_loss_accuracy(model, gtest, Xtest, ytest)
+         train = eval_loss_accuracy(model, train_loader, device)
+         test = eval_loss_accuracy(model, test_loader, device)
        println("Epoch: $epoch  Train: $(train)  Test: $(test)")
    end

    # TRAIN

    report(0)
    for epoch in 1:args.epochs
-         # for (g, X, y) in train_loader
+         for (g, X, y) in train_loader
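+             # move the current minibatch to the selected device (CPU or GPU)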
+             g, X, y = g |> device, X |> device, y |> device
            gs = Flux.gradient(ps) do
-                 ŷ = model(gtrain, Xtrain) |> vec
-                 logitbinarycrossentropy(ŷ, ytrain)
+                 ŷ = model(g, X) |> vec
+                 logitbinarycrossentropy(ŷ, y)
            end
            Flux.Optimise.update!(opt, ps, gs)
-         # end
+         end

        epoch % args.infotime == 0 && report(epoch)
    end