diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..9c3da475 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,12 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" # Location of package manifests + schedule: + interval: "monthly" + - package-ecosystem: "julia" + directories: # Location of Julia projects + - "/" + schedule: + interval: "weekly" diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml deleted file mode 100644 index d1217f23..00000000 --- a/.github/workflows/CompatHelper.yml +++ /dev/null @@ -1,16 +0,0 @@ -name: CompatHelper -on: - schedule: - - cron: '00 00 * * *' - workflow_dispatch: -jobs: - CompatHelper: - runs-on: ubuntu-latest - steps: - - name: Pkg.add("CompatHelper") - run: julia -e 'using Pkg; Pkg.add("CompatHelper")' - - name: CompatHelper.main() - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} - run: julia -e 'using CompatHelper; CompatHelper.main(; master_branch = "dev")' diff --git a/src/builtins/ThresholdPredictors.jl b/src/builtins/ThresholdPredictors.jl index 9b06f8a5..c0b82423 100644 --- a/src/builtins/ThresholdPredictors.jl +++ b/src/builtins/ThresholdPredictors.jl @@ -245,8 +245,14 @@ function MMI.fitted_params(model::ThresholdUnion, fitresult) ) end +# to strip off any "report" part of the atomic model's output of `predict`: +prediction_part(output, atomic_model) = + prediction_part(output, Val(:predict in MMI.reporting_operations(atomic_model))) +prediction_part(output, ::Val{true}) = first(output) +prediction_part(output, ::Val{false}) = output + function MMI.predict(model::ThresholdUnion, fitresult, X) - yhat = MMI.predict(model.model, fitresult[1], X) + yhat = prediction_part(MMI.predict(model.model, fitresult[1], X), model.model) threshold = (1 - fitresult[2], fitresult[2]) return _predict_threshold(yhat, threshold) end diff --git a/test/builtins/ThresholdPredictors.jl b/test/builtins/ThresholdPredictors.jl index f8525d69..e074325e 100644 --- a/test/builtins/ThresholdPredictors.jl +++ b/test/builtins/ThresholdPredictors.jl @@ -323,6 +323,21 @@ end @test MLJBase.predict(mach2, (; x = rand(2))) == yhat end +@testset "wrapping models with non-empty `reporting_observations`" begin + # to resolve https://github.com/JuliaAI/MLJModels.jl/issues/606 + + X = (x = rand(3),) + y = coerce([0, 1, 0], OrderedFactor) + mode_class = y[1] # `0` + + clf = ConstantClassifier() + pipe = Standardizer() |> clf + point_predictor = BinaryThresholdPredictor(pipe) + + mach = MLJBase.machine(point_predictor, X, y) |> MLJBase.fit! + @test MLJBase.predict(mach, X) == fill(mode_class, 3) +end + end # module true diff --git a/test/runtests.jl b/test/runtests.jl index da955718..7550b512 100755 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,33 +2,28 @@ import Pkg using Test, MLJModels, MLJTransforms -@testset "metadata" begin - @testset "metadata.jl" begin - @test include("metadata.jl") - end - @testset "model search" begin - @test include("model_search.jl") - end - @testset "loading model code" begin - @test include("loading.jl") - end -end - -@testset "built-in models" begin - @testset "Constant.jl" begin - @test include("builtins/Constant.jl") - end - @testset "ThresholdPredictors" begin - @test include("builtins/ThresholdPredictors.jl") - end -end +test_files = [ + "metadata.jl", + "model_search.jl", + "loading.jl", + joinpath("builtins", "Constant.jl"), + joinpath("builtins", "ThresholdPredictors.jl"), +] if parse(Bool, get(ENV, "MLJ_TEST_REGISTRY", "false")) - @testset "registry" begin - @test include("registry.jl") - end + push!(test_files, "registry.jl") else @info "Test of the MLJ Registry is being skipped. Set environment variable "* "MLJ_TEST_REGISTRY = \"true\" to include them.\n"* "The Registry test takes about ten minutes. " end + +files = isempty(ARGS) ? test_files : ARGS + +for file in files + quote + @testset $file begin + include($file) + end + end |> eval +end