@@ -11,7 +11,7 @@ const ADJMAT_T = AbstractMatrix
1111const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
1212
1313"""
14- GNNGraph(data; [graph_type, dir, num_nodes, nf, ef, gf ])
14+ GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, num_graphs, graph_indicator, dir ])
1515 GNNGraph(g::GNNGraph; [nf, ef, gf])
1616
1717A type representing a graph structure and storing also arrays
@@ -43,11 +43,13 @@ from the LightGraphs' graph library can be used on it.
4343 - `:dense`. A dense adjacency matrix representation.
4444 Default `:coo`.
4545- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
46- Possible values are `:out` and `:in`. Defaul `:out`.
47- - `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default nothing.
48- - `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default nothing.
49- - `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default nothing.
50- - `gf`: Global features. Default nothing.
46+ Possible values are `:out` and `:in`. Default `:out`.
47+ - `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
48+ - `num_graphs`. The number of graphs. Larger than 1 in case of batched graphs. Default `1`.
49+ - `graph_indicator`. For batched graphs, a vector containeing the graph assigment of each node. Default `nothing`.
50+ - `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`.
51+ - `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`.
52+ - `gf`: Global features. Default `nothing`.
5153
5254# Usage.
5355
@@ -87,6 +89,8 @@ struct GNNGraph{T<:Union{COO_T,ADJMAT_T}}
8789 graph:: T
8890 num_nodes:: Int
8991 num_edges:: Int
92+ num_graphs:: Int
93+ graph_indicator
9094 nf
9195 ef
9296 gf
99103@functor GNNGraph
100104
101105function GNNGraph (data;
102- num_nodes = nothing ,
106+ num_nodes = nothing ,
107+ num_graphs = 1 ,
108+ graph_indicator = nothing ,
103109 graph_type = :coo ,
104110 dir = :out ,
105111 nf = nothing ,
@@ -119,6 +125,9 @@ function GNNGraph(data;
119125 elseif graph_type == :sparse
120126 g, num_nodes, num_edges = to_sparse (data; dir)
121127 end
128+ if num_graphs > 1
129+ @assert len (graph_indicator) = num_nodes " When batching multiple graphs `graph_indicator` should be filled with the nodes' memberships."
130+ end
122131
123132 # # Possible future implementation of feature maps.
124133 # # Currently this doesn't play well with zygote due to
@@ -127,8 +136,9 @@ function GNNGraph(data;
127136 # edata["e"] = ef
128137 # gdata["g"] = gf
129138
130-
131- GNNGraph (g, num_nodes, num_edges, nf, ef, gf)
139+ GNNGraph (g, num_nodes, num_edges,
140+ num_graphs, graph_indicator,
141+ nf, ef, gf)
132142end
133143
134144# COO convenience constructors
@@ -147,7 +157,7 @@ function GNNGraph(g::GNNGraph;
147157 nf= node_feature (g), ef= edge_feature (g), gf= global_feature (g))
148158 # ndata=copy(g.ndata), edata=copy(g.edata), gdata=copy(g.gdata), # copy keeps the refs to old data
149159
150- GNNGraph (g. graph, g. num_nodes, g. num_edges, nf, ef, gf) # ndata, edata, gdata,
160+ GNNGraph (g. graph, g. num_nodes, g. num_edges, g . num_graphs, g . graph_indicator, nf, ef, gf) # ndata, edata, gdata,
151161end
152162
153163
@@ -370,6 +380,7 @@ function add_self_loops(g::GNNGraph{<:COO_T})
370380 t = [t; nodes]
371381
372382 GNNGraph ((s, t, nothing ), g. num_nodes, length (s),
383+ g. num_graphs, g. graph_indicator,
373384 node_feature (g), edge_feature (g), global_feature (g))
374385end
375386
@@ -379,6 +390,7 @@ function add_self_loops(g::GNNGraph{<:ADJMAT_T}; add_to_existing=true)
379390 A += I
380391 num_edges = g. num_edges + g. num_nodes
381392 GNNGraph (A, g. num_nodes, num_edges,
393+ g. num_graphs, g. graph_indicator,
382394 node_feature (g), edge_feature (g), global_feature (g))
383395end
384396
@@ -392,10 +404,46 @@ function remove_self_loops(g::GNNGraph{<:COO_T})
392404 s = s[mask_old_loops]
393405 t = t[mask_old_loops]
394406
395- GNNGraph ((s, t, nothing ), g. num_nodes, length (s),
407+ GNNGraph ((s, t, nothing ), g. num_nodes, length (s),
408+ g. num_graphs, g. graph_indicator,
396409 node_feature (g), edge_feature (g), global_feature (g))
397410end
398411
412+ function _catgraphs (g1:: GNNGraph{<:COO_T} , g2:: GNNGraph{<:COO_T} )
413+ s1, t1 = edge_index (g1)
414+ s2, t2 = edge_index (g2)
415+ nv1, nv2 = g1. num_nodes, g2. num_nodes
416+ s = vcat (s1, nv1 .+ s2)
417+ t = vcat (t1, nv1 .+ t2)
418+ w = cat_features (edge_weight (g1), edge_weight (g2))
419+
420+ ind1 = isnothing (g1. graph_indicator) ? fill! (similar (s1, Int, nv1), 1 ) : g1. graph_indicator
421+ ind2 = isnothing (g2. graph_indicator) ? fill! (similar (s2, Int, nv2), 1 ) : g2. graph_indicator
422+ graph_indicator = vcat (ind1, g1. num_graphs .+ ind2)
423+
424+ GNNGraph (
425+ (s, t, w),
426+ nv1 + nv2, g1. num_edges + g2. num_edges,
427+ g1. num_graphs + g2. num_graphs, graph_indicator,
428+ cat_features (node_feature (g1), node_feature (g2)),
429+ cat_features (edge_feature (g1), edge_feature (g2)),
430+ cat_features (global_feature (g1), global_feature (g2)),
431+ )
432+ end
433+
434+ # Cat public interfaces
435+ function SparseArrays. blockdiag (g1:: GNNGraph , gothers:: GNNGraph... )
436+ @assert length (gothers) >= 1
437+ g = g1
438+ for go in gothers
439+ g = _catgraphs (g, go)
440+ end
441+ return g
442+ end
443+
444+ Flux. batch (xs:: Vector{<:GNNGraph} ) = blockdiag (xs... )
445+ # ########################
446+
399447@non_differentiable normalized_laplacian (x... )
400448@non_differentiable normalized_adjacency (x... )
401449@non_differentiable scaled_laplacian (x... )
0 commit comments