diff --git a/src/core.jl b/src/core.jl index 641898e..5d2911b 100644 --- a/src/core.jl +++ b/src/core.jl @@ -381,7 +381,7 @@ end """ - predict(I::Tuple{T,T}, ytest::T; GPU::Bool=false) where {T<:NamedMatrix} + predict(I::Tuple{T,T}, ytest::T; GPU::Bool=false, returnweights::Bool=false) where {T<:NamedMatrix} Predict interactions between query and target nodes using *de novo* network-based inference model proposed by Wu, et al (2016). @@ -390,6 +390,7 @@ model proposed by Wu, et al (2016). - `I::Tuple{NamedMatrix,NamedMatrix}`: Feature-source-target trilayered adjacency matrices - `ytest::NamedMatrix`: Query-target bipartite adjacency matrix - `GPU::Bool`: Use GPU acceleration for calculation (default = false) +- `returnweights::Bool`: Return the weighting matrix employed for prediction # References 1. Wu, et al (2016). SDTNBI: an integrated network and chemoinformatics tool for @@ -399,7 +400,7 @@ model proposed by Wu, et al (2016). Chemical Similarity-Guided Network-Based Inference. International Journal of Molecular Sciences, 23(17), 9666. https://doi.org/10.3390/ijms23179666 """ -function predict(I::Tuple{T,T}, ytest::T; GPU::Bool=false) where {T<:NamedMatrix} +function predict(I::Tuple{T,T}, ytest::T; GPU::Bool=false, returnweights::Bool=false) where {T<:NamedMatrix} # GPU calculations helper functions _useGPU(x::AbstractArray) = GPU ? CuArray{Float32}(x) : x @@ -419,13 +420,17 @@ function predict(I::Tuple{T,T}, ytest::T; GPU::Bool=false) where {T<:NamedMatrix end yhat = F[names(ytest, 1), names(ytest, 2)] - return yhat + if returnweights + return yhat, W + else + return yhat + end end -predict(A::T, B::T, ytest::T; GPU::Bool=false) where {T<:NamedMatrix} = - predict((A, B), ytest; GPU=GPU) +predict(A::T, B::T, ytest::T; kwargs...) where {T<:NamedMatrix} = + predict((A, B), ytest; kwargs...) """ - predict(A::T, ytrain::T; GPU::Bool=false) where {T<:NamedMatrix} + predict(A::T, ytrain::T; GPU::Bool=false, returnweights::Bool=false) where {T<:NamedMatrix} Predict interactions between query and target nodes using *de novo* network-based inference model proposed by Wu, et al (2016). @@ -434,6 +439,7 @@ model proposed by Wu, et al (2016). - `A::NamedMatrix`: Feature-source-target trilayered adjacency matrix - `ytrain::NamedMatrix`: Source-target bipartite adjacency matrix - `GPU::Bool`: Use GPU acceleration for calculation (default = false) +- `returnweights::Bool`: Return the weighting matrix employed for prediction # References 1. Wu, et al (2016). SDTNBI: an integrated network and chemoinformatics tool for @@ -443,7 +449,7 @@ model proposed by Wu, et al (2016). Chemical Similarity-Guided Network-Based Inference. International Journal of Molecular Sciences, 23(17), 9666. https://doi.org/10.3390/ijms23179666 """ -function predict(A::T, ytrain::T; GPU::Bool=false) where {T<:NamedMatrix} +function predict(A::T, ytrain::T; GPU::Bool=false, returnweights::Bool=false) where {T<:NamedMatrix} # GPU calculations helper functions _useGPU(x::AbstractArray) = GPU ? CuArray{Float32}(x) : x @@ -462,7 +468,11 @@ function predict(A::T, ytrain::T; GPU::Bool=false) where {T<:NamedMatrix} end yhat = F[names(ytrain, 1), names(ytrain, 2)] - return yhat + if returnweights + return yhat, W + else + return yhat + end end """ diff --git a/test/runtests.jl b/test/runtests.jl index f249d11..4c691d8 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -155,6 +155,8 @@ end @test SimSpread.predict(A, B, y) == yhat @test SimSpread.predict((A, B), y) == yhat + @test all(SimSpread.predict(A, B, y; returnweights=true) .== (yhat, spread(B))) + @test all(SimSpread.predict((A, B), y; returnweights=true) .== (yhat, spread(B))) end @testset "clean!" begin