Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions src/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,7 @@ end


"""
predict(I::Tuple{T,T}, ytest::T; GPU::Bool=false) where {T<:NamedMatrix}
predict(I::Tuple{T,T}, ytest::T; GPU::Bool=false, returnweights::Bool=false) where {T<:NamedMatrix}

Predict interactions between query and target nodes using *de novo* network-based inference
model proposed by Wu, et al (2016).
Expand All @@ -390,6 +390,7 @@ model proposed by Wu, et al (2016).
- `I::Tuple{NamedMatrix,NamedMatrix}`: Feature-source-target trilayered adjacency matrices
- `ytest::NamedMatrix`: Query-target bipartite adjacency matrix
- `GPU::Bool`: Use GPU acceleration for calculation (default = false)
- `returnweights::Bool`: Return the weighting matrix employed for prediction

# References
1. Wu, et al (2016). SDTNBI: an integrated network and chemoinformatics tool for
Expand All @@ -399,7 +400,7 @@ model proposed by Wu, et al (2016).
Chemical Similarity-Guided Network-Based Inference. International Journal of Molecular
Sciences, 23(17), 9666. https://doi.org/10.3390/ijms23179666
"""
function predict(I::Tuple{T,T}, ytest::T; GPU::Bool=false) where {T<:NamedMatrix}
function predict(I::Tuple{T,T}, ytest::T; GPU::Bool=false, returnweights::Bool=false) where {T<:NamedMatrix}
# GPU calculations helper functions
_useGPU(x::AbstractArray) = GPU ? CuArray{Float32}(x) : x

Expand All @@ -419,13 +420,17 @@ function predict(I::Tuple{T,T}, ytest::T; GPU::Bool=false) where {T<:NamedMatrix
end

yhat = F[names(ytest, 1), names(ytest, 2)]
return yhat
if returnweights
return yhat, W
else
return yhat
end
end
predict(A::T, B::T, ytest::T; GPU::Bool=false) where {T<:NamedMatrix} =
predict((A, B), ytest; GPU=GPU)
predict(A::T, B::T, ytest::T; kwargs...) where {T<:NamedMatrix} =
predict((A, B), ytest; kwargs...)

"""
predict(A::T, ytrain::T; GPU::Bool=false) where {T<:NamedMatrix}
predict(A::T, ytrain::T; GPU::Bool=false, returnweights::Bool=false) where {T<:NamedMatrix}

Predict interactions between query and target nodes using *de novo* network-based inference
model proposed by Wu, et al (2016).
Expand All @@ -434,6 +439,7 @@ model proposed by Wu, et al (2016).
- `A::NamedMatrix`: Feature-source-target trilayered adjacency matrix
- `ytrain::NamedMatrix`: Source-target bipartite adjacency matrix
- `GPU::Bool`: Use GPU acceleration for calculation (default = false)
- `returnweights::Bool`: Return the weighting matrix employed for prediction

# References
1. Wu, et al (2016). SDTNBI: an integrated network and chemoinformatics tool for
Expand All @@ -443,7 +449,7 @@ model proposed by Wu, et al (2016).
Chemical Similarity-Guided Network-Based Inference. International Journal of Molecular
Sciences, 23(17), 9666. https://doi.org/10.3390/ijms23179666
"""
function predict(A::T, ytrain::T; GPU::Bool=false) where {T<:NamedMatrix}
function predict(A::T, ytrain::T; GPU::Bool=false, returnweights::Bool=false) where {T<:NamedMatrix}
# GPU calculations helper functions
_useGPU(x::AbstractArray) = GPU ? CuArray{Float32}(x) : x

Expand All @@ -462,7 +468,11 @@ function predict(A::T, ytrain::T; GPU::Bool=false) where {T<:NamedMatrix}
end

yhat = F[names(ytrain, 1), names(ytrain, 2)]
return yhat
if returnweights
return yhat, W
else
return yhat
end
end

"""
Expand Down
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ end

@test SimSpread.predict(A, B, y) == yhat
@test SimSpread.predict((A, B), y) == yhat
@test all(SimSpread.predict(A, B, y; returnweights=true) .== (yhat, spread(B)))
@test all(SimSpread.predict((A, B), y; returnweights=true) .== (yhat, spread(B)))
end

@testset "clean!" begin
Expand Down
Loading