Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/" # Location of package manifests
schedule:
interval: "monthly"
- package-ecosystem: "julia"
directories: # Location of Julia projects
- "/"
schedule:
interval: "weekly"
16 changes: 0 additions & 16 deletions .github/workflows/CompatHelper.yml

This file was deleted.

8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ jobs:
arch:
- x64
steps:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@v1
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v2
with:
version: ${{ matrix.version }}
arch: ${{ matrix.arch }}
- uses: julia-actions/cache@v1
- uses: julia-actions/cache@v3
env:
cache-name: cache-artifacts
with:
Expand All @@ -45,7 +45,7 @@ jobs:
# This environment variable enables the integration tests:
MLJ_TEST_REGISTRY: "false"
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
- uses: codecov/codecov-action@v6
with:
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "MLJModels"
uuid = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
authors = ["Anthony D. Blaom <anthony.blaom@gmail.com>"]
version = "0.18.4"
version = "0.18.5"

[deps]
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
Expand Down
8 changes: 7 additions & 1 deletion src/builtins/ThresholdPredictors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,14 @@ function MMI.fitted_params(model::ThresholdUnion, fitresult)
)
end

# to strip off any "report" part of the atomic model's output of `predict`:
prediction_part(output, atomic_model) =
prediction_part(output, Val(:predict in MMI.reporting_operations(atomic_model)))
prediction_part(output, ::Val{true}) = first(output)
prediction_part(output, ::Val{false}) = output

function MMI.predict(model::ThresholdUnion, fitresult, X)
yhat = MMI.predict(model.model, fitresult[1], X)
yhat = prediction_part(MMI.predict(model.model, fitresult[1], X), model.model)
threshold = (1 - fitresult[2], fitresult[2])
return _predict_threshold(yhat, threshold)
end
Expand Down
9,266 changes: 4,670 additions & 4,596 deletions src/registry/Metadata.toml

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion src/registry/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ MLJGLMInterface = "caf8df21-4939-456d-ac9c-5fefbfb04c0c"
MLJIteration = "614be32b-d00c-4edb-bd02-1eb411ab5e55"
MLJLIBSVMInterface = "61c7150f-6c77-4bb1-949c-13197eac2a52"
MLJLinearModels = "6ee0df7b-362f-4a72-a706-9e79364fb692"
MLJModelRegistryTools = "0a96183e-380b-4aa6-be10-c555140810f2"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLJMultivariateStatsInterface = "1b6a4a23-ba22-4f51-9698-8599985d3728"
MLJNaiveBayesInterface = "33e4bacb-b9e2-458e-9a13-5d9a90b235fa"
Expand Down
15 changes: 15 additions & 0 deletions test/builtins/ThresholdPredictors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,21 @@ end
@test MLJBase.predict(mach2, (; x = rand(2))) == yhat
end

@testset "wrapping models with non-empty `reporting_observations`" begin
# to resolve https://github.com/JuliaAI/MLJModels.jl/issues/606

X = (x = rand(3),)
y = coerce([0, 1, 0], OrderedFactor)
mode_class = y[1] # `0`

clf = ConstantClassifier()
pipe = Standardizer() |> clf
point_predictor = BinaryThresholdPredictor(pipe)

mach = MLJBase.machine(point_predictor, X, y) |> MLJBase.fit!
@test MLJBase.predict(mach, X) == fill(mode_class, 3)
end

end # module

true
41 changes: 18 additions & 23 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,28 @@ import Pkg

using Test, MLJModels, MLJTransforms

@testset "metadata" begin
@testset "metadata.jl" begin
@test include("metadata.jl")
end
@testset "model search" begin
@test include("model_search.jl")
end
@testset "loading model code" begin
@test include("loading.jl")
end
end

@testset "built-in models" begin
@testset "Constant.jl" begin
@test include("builtins/Constant.jl")
end
@testset "ThresholdPredictors" begin
@test include("builtins/ThresholdPredictors.jl")
end
end
test_files = [
"metadata.jl",
"model_search.jl",
"loading.jl",
joinpath("builtins", "Constant.jl"),
joinpath("builtins", "ThresholdPredictors.jl"),
]

if parse(Bool, get(ENV, "MLJ_TEST_REGISTRY", "false"))
@testset "registry" begin
@test include("registry.jl")
end
push!(test_files, "registry.jl")
else
@info "Test of the MLJ Registry is being skipped. Set environment variable "*
"MLJ_TEST_REGISTRY = \"true\" to include them.\n"*
"The Registry test takes about ten minutes. "
end

files = isempty(ARGS) ? test_files : ARGS

for file in files
quote
@testset $file begin
include($file)
end
end |> eval
end
Loading