@@ -100,11 +100,13 @@ function broadcast_edges(g::GNNGraph, x)
100100 return gather (x, gi)
101101end
102102
103+ # return a permuted matrix according to the sorting of the sortby column
103104function _sort_col (matrix:: AbstractArray ; rev:: Bool = true , sortby:: Int = 1 )
104- index = sortperm (view (matrix, sortby, : ); rev)
105- return matrix[ :, index]
105+ index = sortperm (view (matrix, sortby, :); rev)
106+ return matrix[:, index]
106107end
107108
109+ # sort and reshape matrix
108110function _sort_matrix (matrix:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
109111 if sortby === nothing
110112 return sort (matrix, dims = 2 ; rev)[:, 1 : k]
@@ -113,32 +115,45 @@ function _sort_matrix(matrix::AbstractArray, k::Int; rev::Bool = true, sortby =
113115 end
114116end
115117
116- function _sort_batch (matrices:: AbstractArray , k:: Int ; rev:: Bool = true , sortby = nothing )
118+ # sort the iterator of batch matrices
119+ function _sort_batch (matrices, k:: Int ; rev:: Bool = true , sortby = nothing )
117120 return map (x -> _sort_matrix (x, k; rev, sortby), matrices)
118121end
119122
123+ # sort and reshape batch matrix
120124function _topk_batch (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
121125 sortby = nothing )
122126 tensor_matrix = reshape (matrix, size (matrix, 1 ), size (matrix, 2 ) ÷ number_graphs,
123127 number_graphs)
124- sorted_matrix = _sort_batch (collect ( eachslice (tensor_matrix, dims = 3 ) ), k; rev, sortby)
128+ sorted_matrix = _sort_batch (eachslice (tensor_matrix, dims = 3 ), k; rev, sortby)
125129 return reduce (hcat, sorted_matrix)
126130end
127131
132+ # topk for a feature matrix
128133function _topk (matrix:: AbstractArray , number_graphs:: Int , k:: Int ; rev:: Bool = true ,
129134 sortby = nothing )
130- if number_graphs== 1
135+ if number_graphs == 1
131136 return _sort_matrix (matrix, k; rev, sortby)
132137 else
133138 return _topk_batch (matrix, number_graphs, k; rev, sortby)
134139 end
135140end
136141
142+ """
143+ topk_nodes(g, feat, k; rev = true, sortby = nothing)
144+
145+ Graph-wise top-k on node features `feat` according to the `sortby` feature index.
146+ """
137147function topk_nodes (g:: GNNGraph , feat:: Symbol , k:: Int ; rev = true , sortby = nothing )
138148 matrix = getproperty (g. ndata, feat)
139149 return _topk (matrix, g. num_graphs, k; rev, sortby)
140150end
141151
152+ """
153+ topk_edges(g, feat, k; rev = true, sortby = nothing)
154+
155+ Graph-wise top-k on edge features `feat` according to the `sortby` feature index.
156+ """
142157function topk_edges (g:: GNNGraph , feat:: Symbol , k:: Int ; rev = true , sortby = nothing )
143158 matrix = getproperty (g. edata, feat)
144159 return _topk (matrix, g. num_graphs, k; rev, sortby)
0 commit comments