@@ -11,7 +11,7 @@ const ADJMAT_T = AbstractMatrix
1111const SPARSE_T = AbstractSparseMatrix # subset of ADJMAT_T
1212
1313"""
14- GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, num_graphs, graph_indicator, dir])
14+ GNNGraph(data; [graph_type, nf, ef, gf, num_nodes, graph_indicator, dir])
1515 GNNGraph(g::GNNGraph; [nf, ef, gf])
1616
1717A type representing a graph structure and storing also arrays
@@ -50,7 +50,6 @@ from the LightGraphs' graph library can be used on it.
5050- `dir`. The assumed edge direction when given adjacency matrix or adjacency list input data `g`.
5151 Possible values are `:out` and `:in`. Default `:out`.
5252- `num_nodes`. The number of nodes. If not specified, inferred from `g`. Default `nothing`.
53- - `num_graphs`. The number of graphs. Larger than 1 in case of batched graphs. Default `1`.
5453- `graph_indicator`. For batched graphs, a vector containeing the graph assigment of each node. Default `nothing`.
5554- `nf`: Node features. Either nothing, or an array whose last dimension has size num_nodes. Default `nothing`.
5655- `ef`: Edge features. Either nothing, or an array whose last dimension has size num_edges. Default `nothing`.
@@ -123,17 +122,17 @@ function GNNGraph(data;
123122
124123 @assert graph_type ∈ [:coo , :dense , :sparse ] " Invalid graph_type $graph_type requested"
125124 @assert dir ∈ [:in , :out ]
125+
126126 if graph_type == :coo
127127 g, num_nodes, num_edges = to_coo (data; num_nodes, dir)
128128 elseif graph_type == :dense
129129 g, num_nodes, num_edges = to_dense (data; dir)
130130 elseif graph_type == :sparse
131131 g, num_nodes, num_edges = to_sparse (data; dir)
132132 end
133- if num_graphs > 1
134- @assert len (graph_indicator) = num_nodes " When batching multiple graphs `graph_indicator` should be filled with the nodes' memberships."
135- end
136-
133+
134+ num_graphs = ! isnothing (graph_indicator) ? maximum (graph_indicator) : 1
135+
137136 # # Possible future implementation of feature maps.
138137 # # Currently this doesn't play well with zygote due to
139138 # # https://github.com/FluxML/Zygote.jl/issues/717
@@ -154,8 +153,8 @@ GNNGraph((s, t)::NTuple{2}; kws...) = GNNGraph((s, t, nothing); kws...)
154153
155154function GNNGraph (g:: AbstractGraph ; kws... )
156155 s = LightGraphs. src .(LightGraphs. edges (g))
157- t = LightGraphs. dst .(LightGraphs. edges (g))
158- GNNGraph ((s, t); kws... )
156+ t = LightGraphs. dst .(LightGraphs. edges (g))
157+ GNNGraph ((s, t); num_nodes = nv (g), kws... )
159158end
160159
161160function GNNGraph (g:: GNNGraph ;
@@ -436,36 +435,77 @@ function _catgraphs(g1::GNNGraph{<:COO_T}, g2::GNNGraph{<:COO_T})
436435 )
437436end
438437
439- # Cat public interfaces
438+ # ## Cat public interfaces #############
440439
441- ```
440+ """
442441 blockdiag(xs::GNNGraph...)
443442
444443Batch togheter multiple `GNNGraph`s into a single one
445444containing the total number of nodes and edges of the original graphs.
446445
447446Equivalent to [`Flux.batch`](@ref).
448- ```
447+ """
449448function SparseArrays. blockdiag (g1:: GNNGraph , gothers:: GNNGraph... )
450- @assert length (gothers) >= 1
451449 g = g1
452450 for go in gothers
453451 g = _catgraphs (g, go)
454452 end
455453 return g
456454end
457455
458- ```
456+ """
459457 batch(xs::Vector{<:GNNGraph})
460458
461459Batch togheter multiple `GNNGraph`s into a single one
462460containing the total number of nodes and edges of the original graphs.
463461
464462Equivalent to [`SparseArrays.blockdiag`](@ref).
465- ```
463+ """
466464Flux. batch (xs:: Vector{<:GNNGraph} ) = blockdiag (xs... )
467465# ########################
468466
467+ """
468+ subgraph(g::GNNGraph, i)
469+
470+ Return the subgraph of `g` induced by those nodes `v`
471+ for which `g.graph_indicator[v] ∈ i`. In other words, it
472+ extract the component graphs from a batched graph.
473+
474+ It also returns a vector `nodes` mapping the new nodes to the old ones.
475+ The node `i` in the subgraph corresponds to the node `nodes[i]` in `g`.
476+ """
477+ subgraph (g:: GNNGraph , i:: Int ) = subgraph (g:: GNNGraph{<:COO_T} , [i])
478+
479+ function subgraph (g:: GNNGraph{<:COO_T} , i:: AbstractVector )
480+ node_mask = g. graph_indicator .∈ Ref (i)
481+
482+ nodes = (1 : g. num_nodes)[node_mask]
483+ nodemap = Dict (v => vnew for (vnew, v) in enumerate (nodes))
484+
485+ graphmap = Dict (i => inew for (inew, i) in enumerate (i))
486+ graph_indicator = [graphmap[i] for i in g. graph_indicator[node_mask]]
487+
488+ s, t, w = g. graph
489+ edge_mask = s .∈ Ref (nodes)
490+ s = [nodemap[i] for i in s[edge_mask]]
491+ t = [nodemap[i] for i in t[edge_mask]]
492+ w = isnothing (w) ? nothing : w[edge_mask]
493+ @show size (g. nf) size (node_mask)
494+ nf = isnothing (g. nf) ? nothing : g. nf[:,node_mask]
495+ ef = isnothing (g. ef) ? nothing : g. ef[:,edge_mask]
496+ gf = isnothing (g. gf) ? nothing : g. gf[:,i]
497+
498+ num_nodes = length (graph_indicator)
499+ num_edges = length (s)
500+ num_graphs = length (i)
501+
502+ gnew = GNNGraph ((s,t,w),
503+ num_nodes, num_edges, num_graphs,
504+ graph_indicator,
505+ nf, ef, gf)
506+ return gnew, nodes
507+ end
508+
469509@non_differentiable normalized_laplacian (x... )
470510@non_differentiable normalized_adjacency (x... )
471511@non_differentiable scaled_laplacian (x... )
0 commit comments