Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/" # Location of package manifests
schedule:
interval: "monthly"
- package-ecosystem: "julia"
directories: # Location of Julia projects
- "/"
schedule:
interval: "weekly"
16 changes: 0 additions & 16 deletions .github/workflows/CompatHelper.yml

This file was deleted.

8 changes: 7 additions & 1 deletion src/builtins/ThresholdPredictors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,14 @@ function MMI.fitted_params(model::ThresholdUnion, fitresult)
)
end

# to strip off any "report" part of the atomic model's output of `predict`:
prediction_part(output, atomic_model) =
prediction_part(output, Val(:predict in MMI.reporting_operations(atomic_model)))
prediction_part(output, ::Val{true}) = first(output)
prediction_part(output, ::Val{false}) = output

function MMI.predict(model::ThresholdUnion, fitresult, X)
yhat = MMI.predict(model.model, fitresult[1], X)
yhat = prediction_part(MMI.predict(model.model, fitresult[1], X), model.model)
threshold = (1 - fitresult[2], fitresult[2])
return _predict_threshold(yhat, threshold)
end
Expand Down
15 changes: 15 additions & 0 deletions test/builtins/ThresholdPredictors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,21 @@ end
@test MLJBase.predict(mach2, (; x = rand(2))) == yhat
end

@testset "wrapping models with non-empty `reporting_observations`" begin
# to resolve https://github.com/JuliaAI/MLJModels.jl/issues/606

X = (x = rand(3),)
y = coerce([0, 1, 0], OrderedFactor)
mode_class = y[1] # `0`

clf = ConstantClassifier()
pipe = Standardizer() |> clf
point_predictor = BinaryThresholdPredictor(pipe)

mach = MLJBase.machine(point_predictor, X, y) |> MLJBase.fit!
@test MLJBase.predict(mach, X) == fill(mode_class, 3)
end

end # module

true
41 changes: 18 additions & 23 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,28 @@ import Pkg

using Test, MLJModels, MLJTransforms

@testset "metadata" begin
@testset "metadata.jl" begin
@test include("metadata.jl")
end
@testset "model search" begin
@test include("model_search.jl")
end
@testset "loading model code" begin
@test include("loading.jl")
end
end

@testset "built-in models" begin
@testset "Constant.jl" begin
@test include("builtins/Constant.jl")
end
@testset "ThresholdPredictors" begin
@test include("builtins/ThresholdPredictors.jl")
end
end
test_files = [
"metadata.jl",
"model_search.jl",
"loading.jl",
joinpath("builtins", "Constant.jl"),
joinpath("builtins", "ThresholdPredictors.jl"),
]

if parse(Bool, get(ENV, "MLJ_TEST_REGISTRY", "false"))
@testset "registry" begin
@test include("registry.jl")
end
push!(test_files, "registry.jl")
else
@info "Test of the MLJ Registry is being skipped. Set environment variable "*
"MLJ_TEST_REGISTRY = \"true\" to include them.\n"*
"The Registry test takes about ten minutes. "
end

files = isempty(ARGS) ? test_files : ARGS

for file in files
quote
@testset $file begin
include($file)
end
end |> eval
end
Loading