diff --git a/docs/make.jl b/docs/make.jl
index c7c874b1..1f3d8047 100644
--- a/docs/make.jl
+++ b/docs/make.jl
@@ -59,6 +59,7 @@ makedocs(;
             "Hyperparameter Tuning" => "tutorials/hyperparameter_tuning.md",
             "Slurm" => "tutorials/slurm.md",
             "Cross-validation" => "tutorials/folds.md",
+            "Loss Functions" => "tutorials/losses.md",
         ],
         "Research" => [
             "Overview" => "research/overview.md"
diff --git a/docs/src/tutorials/losses.md b/docs/src/tutorials/losses.md
new file mode 100644
index 00000000..164b9747
--- /dev/null
+++ b/docs/src/tutorials/losses.md
@@ -0,0 +1,154 @@
+## Losses and LoggingLoss
+
+```@example loss
+using EasyHybrid
+using EasyHybrid: compute_loss
+using Statistics: mean # used by the custom losses below
+```
+
+````@docs; canonical=false
+EasyHybrid.compute_loss
+````
+
+::: warning
+
+- `y_nan` is a boolean mask (or a function returning a mask per target) used to ignore missing values.
+- For uncertainty-aware losses, pass target values as `(y_vals, y_sigma)` and write custom losses to accept that tuple.
+
+:::
+
+::: tip Tips and quick reference
+
+- Prefer `f(ŷ_masked, y_masked)` for custom losses; `y_masked` may be a vector or `(y, σ)`.
+- Use `Val(:metric)` only for predefined `loss_fn` variants.
+- Quick calls:
+  - `compute_loss(..., :mse, sum)`: predefined metric
+  - `compute_loss(..., custom_loss, sum)`: custom loss
+  - `compute_loss(..., (f, (arg1, arg2,)), sum)`: with additional arguments
+  - `compute_loss(..., (f, (kw=val,)), sum)`: with keyword arguments
+  - `compute_loss(..., (f, (arg1,), (kw=val,)), sum)`: with additional and keyword arguments
+  - `compute_loss(..., (y, y_sigma), ..., custom_loss_uncertainty, sum)`: with uncertainties
+
+:::
+
+### Simple usage
+
+Predefined metrics:
+
+```@example loss
+# synthetic data
+ŷ = Dict(:t1 => [1.0, 2.0], :t2 => [0.5, 1.0])
+y(t) = t == :t1 ? [1.1, 1.9] : [0.4, 1.1]
+y_nan(t) = trues(2)
+targets = [:t1, :t2]
+```
+
+```@ansi loss
+mse_total = compute_loss(ŷ, y, y_nan, targets, :mse, sum)      # total MSE across targets
+losses = compute_loss(ŷ, y, y_nan, targets, [:mse, :mae], sum) # multiple metrics in a NamedTuple
+```
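+
+Entries where the mask is `false` are dropped before the loss is evaluated. A quick sketch (the partial mask below is made up for the synthetic data above): the second sample of `:t1` is ignored.
+
+```@ansi loss
+y_nan_partial(t) = t == :t1 ? [true, false] : trues(2)
+compute_loss(ŷ, y, y_nan_partial, targets, :mse, sum)
+```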
+
+### Custom functions, args, kwargs
+
+Custom losses receive masked predictions and masked targets:
+
+```@example loss
+custom_loss(ŷ, y) = mean(abs2, ŷ .- y)
+weighted_loss(ŷ, y, w) = w * mean(abs2, ŷ .- y)
+scaled_loss(ŷ, y; scale=1.0) = scale * mean(abs2, ŷ .- y)
+complex_loss(ŷ, y, w; scale=1.0) = scale * w * mean(abs2, ŷ .- y);
+nothing # hide
+```
+
+Usage variants:
+
+```@ansi loss
+compute_loss(ŷ, y, y_nan, targets, custom_loss, sum)
+compute_loss(ŷ, y, y_nan, targets, (weighted_loss, (0.5,)), sum)
+compute_loss(ŷ, y, y_nan, targets, (scaled_loss, (scale=2.0,)), sum)
+compute_loss(ŷ, y, y_nan, targets, (complex_loss, (0.5,), (scale=2.0,)), sum)
+```
+
+### Uncertainty-aware losses
+
+Signal uncertainty by providing targets as `(y_vals, y_sigma)` and write the loss to accept that tuple:
+
+```@example loss
+function custom_loss_uncertainty(ŷ, y_and_sigma)
+    y_vals, σ = y_and_sigma
+    return mean(((ŷ .- y_vals).^2) ./ (σ .^2 .+ 1e-6))
+end
+```
+
+Top-level usage (both `y` and `y_sigma` can be functions or containers):
+
+```@example loss
+y_sigma(t) = t == :t1 ? [0.1, 0.2] : [0.2, 0.1]
+loss = compute_loss(ŷ, (y, y_sigma), y_nan, targets,
+                    custom_loss_uncertainty, sum)
+```
+
+::: info Behavior
+
+- `compute_loss` packs per-target `(y_vals_target, σ_target)` tuples and forwards them to `loss_fn`.
+- Predefined metrics are meant to use only `y_vals` when a `(y, σ)` tuple is supplied; full support for this is still a TODO.
+
+:::
+
+## LoggingLoss
+
+The `LoggingLoss` helper aggregates per-target loss specifications for training and evaluation.
+
+````@docs; canonical=false
+LoggingLoss
+````
+
+Internally, training uses `logging.training_loss` and evaluation uses `logging.loss_types`.
+Note that `LoggingLoss` can mix symbols and functions, as in the sketch below.
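+
+```julia
+# a construction sketch; the keyword constructor is assumed from the fields
+# that `lossfn` reads (`training_loss`, `loss_types`, `agg`, `train_mode`)
+logging = LoggingLoss(training_loss = :mse,
+                      loss_types = [:mse, custom_loss],
+                      agg = sum)
+```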
+""" +struct PerTarget{T<:Tuple} + losses::T +end + +const LossSpec = Union{Symbol, Function, Tuple, LPPP, PerTarget, LDataPPP} """ LoggingLoss @@ -94,19 +160,7 @@ Main loss function for hybrid models that handles both training and evaluation m - In evaluation mode (`logging.train_mode = false`): - `(loss_values, st, ŷ)`: NamedTuple of losses, state and predictions """ -function lossfn(HM::LuxCore.AbstractLuxContainerLayer, x, (y_t, y_nan), ps, st, logging::LoggingLoss) - targets = HM.targets - ŷ, y, y_nan, st = get_predictions_targets(HM, x, (y_t, y_nan), ps, st, targets) - if logging.train_mode - loss_value = compute_loss(ŷ, y, y_nan, targets, logging.training_loss, logging.agg) - return loss_value, st - else - loss_value = compute_loss(ŷ, y, y_nan, targets, logging.loss_types, logging.agg) - return loss_value, st, ŷ - end -end - -function lossfn(HM::Union{SingleNNHybridModel, MultiNNHybridModel, SingleNNModel, MultiNNModel}, x, (y_t, y_nan), ps, st, logging::LoggingLoss) +function lossfn(HM::Union{LuxCore.AbstractLuxContainerLayer, SingleNNHybridModel, MultiNNHybridModel, SingleNNModel, MultiNNModel}, x, (y_t, y_nan), ps, st, logging::LoggingLoss) targets = HM.targets ŷ, y, y_nan, st = get_predictions_targets(HM, x, (y_t, y_nan), ps, st, targets) if logging.train_mode @@ -139,54 +193,40 @@ Get predictions and targets from the hybrid model and return them along with the - `st`: Updated model state """ function get_predictions_targets(HM, x, (y_t, y_nan), ps, st, targets) - ŷ, st = HM(x, ps, st) #TODO the output st can contain more than st, e.g. Rb is that what we want? - y = y_t(HM.targets) - y_nan = y_nan(HM.targets) - return ŷ, y, y_nan, st #TODO has to be done otherwise e.g. Rb is passed as a st and messes up the training -end - -function get_predictions_targets( - HM, - x::AbstractDimArray, - ys::Tuple{<:AbstractDimArray,<:AbstractDimArray}, - ps, st, targets -) - y_t, y_nan = ys ŷ, st = HM(x, ps, st) - y = y_t[col=At(targets)] - y_nan = y_nan[col=At(targets)] + y = _get_target_y(y_t, targets) + y_nan = _get_target_nan(y_nan, targets) return ŷ, y, y_nan, st end -""" - compute_loss(ŷ, y, y_nan, targets, loss_spec, agg::Function) - compute_loss(ŷ, y, y_nan, targets, loss_types::Vector, agg::Function) +function _apply_loss(ŷ, y, y_nan, loss_spec::Symbol) + return loss_fn(ŷ, y, y_nan, Val(loss_spec)) +end -Compute loss values for predictions against targets using specified loss functions. 
+"""
+struct PerTarget{T<:Tuple}
+    losses::T
+end
+
+const LossSpec = Union{Symbol, Function, Tuple, LPPP, PerTarget, LDataPPP}
 
 """
     LoggingLoss
@@ -94,19 +160,7 @@ Main loss function for hybrid models that handles both training and evaluation m
 - In evaluation mode (`logging.train_mode = false`):
   - `(loss_values, st, ŷ)`: NamedTuple of losses, state and predictions
 """
-function lossfn(HM::LuxCore.AbstractLuxContainerLayer, x, (y_t, y_nan), ps, st, logging::LoggingLoss)
-    targets = HM.targets
-    ŷ, y, y_nan, st = get_predictions_targets(HM, x, (y_t, y_nan), ps, st, targets)
-    if logging.train_mode
-        loss_value = compute_loss(ŷ, y, y_nan, targets, logging.training_loss, logging.agg)
-        return loss_value, st
-    else
-        loss_value = compute_loss(ŷ, y, y_nan, targets, logging.loss_types, logging.agg)
-        return loss_value, st, ŷ
-    end
-end
-
-function lossfn(HM::Union{SingleNNHybridModel, MultiNNHybridModel, SingleNNModel, MultiNNModel}, x, (y_t, y_nan), ps, st, logging::LoggingLoss)
+function lossfn(HM::Union{LuxCore.AbstractLuxContainerLayer, SingleNNHybridModel, MultiNNHybridModel, SingleNNModel, MultiNNModel}, x, (y_t, y_nan), ps, st, logging::LoggingLoss)
     targets = HM.targets
     ŷ, y, y_nan, st = get_predictions_targets(HM, x, (y_t, y_nan), ps, st, targets)
     if logging.train_mode
@@ -139,54 +193,40 @@ Get predictions and targets from the hybrid model and return them along with the
 - `st`: Updated model state
 """
 function get_predictions_targets(HM, x, (y_t, y_nan), ps, st, targets)
-    ŷ, st = HM(x, ps, st) #TODO the output st can contain more than st, e.g. Rb is that what we want?
-    y = y_t(HM.targets)
-    y_nan = y_nan(HM.targets)
-    return ŷ, y, y_nan, st #TODO has to be done otherwise e.g. Rb is passed as a st and messes up the training
-end
-
-function get_predictions_targets(
-    HM,
-    x::AbstractDimArray,
-    ys::Tuple{<:AbstractDimArray,<:AbstractDimArray},
-    ps, st, targets
-)
-    y_t, y_nan = ys
     ŷ, st = HM(x, ps, st)
-    y = y_t[col=At(targets)]
-    y_nan = y_nan[col=At(targets)]
+    y = _get_target_y(y_t, targets)
+    y_nan = _get_target_nan(y_nan, targets)
     return ŷ, y, y_nan, st
 end
 
-"""
-    compute_loss(ŷ, y, y_nan, targets, loss_spec, agg::Function)
-    compute_loss(ŷ, y, y_nan, targets, loss_types::Vector, agg::Function)
+function _apply_loss(ŷ, y, y_nan, loss_spec::Symbol)
+    return loss_fn(ŷ, y, y_nan, Val(loss_spec))
+end
 
-Compute loss values for predictions against targets using specified loss functions.
+function _apply_loss(ŷ, y, y_nan, loss_spec::Function)
+    return loss_fn(ŷ, y, y_nan, loss_spec)
+end
 
-# Arguments
-- `ŷ`: Model predictions
-- `y`: Target values (function or AbstractDimArray)
-- `y_nan`: NaN mask (function or AbstractDimArray)
-- `targets`: Target variable names
-- `loss_spec`: Single loss specification (Symbol, Function, or Tuple)
-- `loss_types`: Vector of loss specifications
-- `agg`: Function to aggregate losses across targets
+function _apply_loss(ŷ, y, y_nan, loss_spec::Tuple)
+    return loss_fn(ŷ, y, y_nan, loss_spec)
+end
 
-# Returns
-- Single loss value when using `loss_spec`
-- NamedTuple of losses when using `loss_types`
-"""
-function compute_loss(ŷ, y, y_nan, targets, loss_spec, agg::Function)
-    losses = [_apply_loss(ŷ[k], y(k), y_nan(k), loss_spec) for k in targets]
-    return agg(losses)
+function _apply_loss(ŷ, y, y_nan, loss_spec::LPPP)
+    return loss_fn(ŷ, y, y_nan, loss_spec)
 end
 
-function compute_loss(ŷ, y::AbstractDimArray, y_nan::AbstractDimArray, targets, loss_spec, agg::Function)
-    losses = [_apply_loss(ŷ[k], y[col=At(k)], y_nan[col=At(k)], loss_spec) for k in targets]
-    return agg(losses)
+function _apply_loss(ŷ, y, y_nan, target, ℒ_mix::LDataPPP{D, P}) where {D, P}
+    data_loss = ℒ_mix.ℒ_data === nothing ? 0.0f0 : _apply_loss(ŷ[target], y, y_nan, ℒ_mix.ℒ_data)
+    phys_loss = sum(𝓁 -> _apply_loss(ŷ, y, y_nan, 𝓁), ℒ_mix.ℒ_phys; init=0.0f0)
+    return data_loss + phys_loss
+end
+
+function _apply_loss(ŷ_all, y, y_nan, target, ℒ_single)
+    return _apply_loss(ŷ_all[target], y, y_nan, ℒ_single)
 end
 
 """
     _apply_loss(ŷ, y, y_nan, loss_spec)
 
@@ -201,42 +241,112 @@ Helper function to apply the appropriate loss function based on the specificatio
 # Arguments
 - `ŷ`: Model predictions
 - `y`: Target values
 - `y_nan`: NaN mask
 - `loss_spec`: Loss specification
 
 # Returns
 - Computed loss value
 """
-function _apply_loss(ŷ, y, y_nan, loss_spec::Symbol)
-    return loss_fn(ŷ, y, y_nan, Val(loss_spec))
-end
-
-function _apply_loss(ŷ, y, y_nan, loss_spec::Function)
-    return loss_fn(ŷ, y, y_nan, loss_spec)
-end
+function _apply_loss end
 
-function _apply_loss(ŷ, y, y_nan, loss_spec::Tuple)
-    return loss_fn(ŷ, y, y_nan, loss_spec)
+function compute_loss(ŷ, y, y_nan, targets, loss_spec, agg::Function)
+    losses = assemble_loss(ŷ, y, y_nan, targets, loss_spec)
+    return agg(losses)
 end
 
-function compute_loss(ŷ, y, y_nan, targets, loss_types::Vector, agg::Function)
+function compute_loss(ŷ, y, y_nan, targets, loss_types::Vector, agg::Function)
     out_loss_types = [
         begin
-            losses = [_apply_loss(ŷ[k], y(k), y_nan(k), loss_type) for k in targets]
+            losses = assemble_loss(ŷ, y, y_nan, targets, loss_type)
             agg_loss = agg(losses)
             NamedTuple{(targets..., Symbol(agg))}([losses..., agg_loss])
         end
         for loss_type in loss_types]
-    _names = [_loss_name(lt) for lt in loss_types]
+    _names = [_loss_name(lt) for lt in loss_types]
     return NamedTuple{Tuple(_names)}([out_loss_types...])
 end
 
-function compute_loss(ŷ, y::AbstractDimArray, y_nan::AbstractDimArray, targets, loss_types::Vector, agg::Function)
-    out_loss_types = [
-        begin
-            losses = [_apply_loss(ŷ[k], y[col=At(k)], y_nan[col=At(k)], loss_type) for k in targets]
-            agg_loss = agg(losses)
-            NamedTuple{(targets..., Symbol(agg))}([losses..., agg_loss])
-        end
-        for loss_type in loss_types]
-    _names = [_loss_name(lt) for lt in loss_types]
-    return NamedTuple{Tuple(_names)}([out_loss_types...])
+
+"""
+    compute_loss(ŷ, y, y_nan, targets, training_loss, agg::Function)
+    compute_loss(ŷ, y, y_nan, targets, loss_types::Vector, agg::Function)
+
+Compute the loss for the given predictions and targets using the specified training loss (or vector of losses) type and aggregation function.
+
+# Arguments:
+- `ŷ`: Predicted values.
+- `y`: Target values (as a function, tuple `(y, y_sigma)`, or AbstractDimArray).
+- `y_nan`: Mask for NaN values.
+- `targets`: The targets for which the loss is computed.
+- `training_loss`: The loss type to use during training, e.g., `:mse` or a custom function.
+- `loss_types::Vector`: A vector of loss types to compute, e.g., `[:mse, :mae]`.
+- `agg::Function`: The aggregation function to apply to the computed losses, e.g., `sum` or `mean`.
+
+Returns a single loss value if `training_loss` is provided, or a NamedTuple of losses for each type in `loss_types`.
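+
+# Example
+A minimal sketch with synthetic data:
+```julia
+ŷ = Dict(:t1 => [1.0, 2.0])
+compute_loss(ŷ, t -> [1.1, 1.9], t -> trues(2), [:t1], :mse, sum)
+```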
+"""
+function compute_loss end
+
+_get_target_y(y, target) = y(target)
+_get_target_y(y::AbstractDimArray, target) = y[col=At(target)]
+_get_target_y(y::AbstractDimArray, targets::Vector) = y[col=At(targets)]
+
+function _get_target_y(y::Tuple, target)
+    y_obs, y_sigma = y
+    sigma = y_sigma isa Number ? y_sigma : y_sigma(target)
+    y_obs_val = _get_target_y(y_obs, target)
+    return (y_obs_val, sigma)
+end
+
+"""
+    _get_target_y(y, target)
+
+Helper function to extract target-specific values from `y`, handling cases where `y` may be a tuple of `(y_obs, y_sigma)`.
+"""
+function _get_target_y end
+
+_get_target_nan(y_nan, target) = y_nan(target)
+_get_target_nan(y_nan::AbstractDimArray, target) = y_nan[col=At(target)]
+_get_target_nan(y_nan::AbstractDimArray, targets::Vector) = y_nan[col=At(targets)]
+
+"""
+    _get_target_nan(y_nan, target)
+
+Helper function to extract target-specific values from `y_nan`.
+"""
+function _get_target_nan end
+
+function assemble_loss(ŷ, y, y_nan, targets, ℒ_mix::LDataPPP{D, P}) where {D, P}
+    data_losses = ℒ_mix.ℒ_data === nothing ? [] :
+        [_apply_loss(ŷ[target], _get_target_y(y, target), _get_target_nan(y_nan, target), ℒ_mix.ℒ_data)
+         for target in targets]
+    phys_losses = [_apply_loss(ŷ, nothing, nothing, 𝓁) for 𝓁 in ℒ_mix.ℒ_phys]
+    return vcat(data_losses, phys_losses)
+end
+
+function assemble_loss(ŷ, y, y_nan, targets, loss_spec::PerTarget)
+    @assert length(targets) == length(loss_spec.losses) "Length of targets and PerTarget losses tuple must match"
+    losses = [
+        _apply_loss(
+            ŷ,
+            _get_target_y(y, target),
+            _get_target_nan(y_nan, target),
+            target,
+            loss_t
+        ) for (target, loss_t) in zip(targets, loss_spec.losses)
+    ]
+    return losses
+end
+
+# Fallback for plain specs (Symbol, Function, Tuple): one loss per target.
+function assemble_loss(ŷ, y, y_nan, targets, loss_spec)
+    return [
+        _apply_loss(ŷ[target], _get_target_y(y, target), _get_target_nan(y_nan, target), loss_spec)
+        for target in targets
+    ]
+end
+
+"""
+    assemble_loss(ŷ, y, y_nan, targets, loss_spec)
+
+Helper function to assemble a vector of losses for each target based on the provided loss specification.
+
+# Arguments
+- `ŷ`: Predictions for all targets.
+- `y`: Target values (can be a function, tuple, or AbstractDimArray).
+- `y_nan`: NaN mask (function or array).
+- `targets`: List of target names.
+- `loss_spec`: Loss specification (Symbol, Function, Tuple, `PerTarget`, or `DataAndPhysicsLoss`).
+
+# Returns
+- Vector of losses for each target.
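+
+# Example
+A minimal sketch:
+```julia
+ŷ = Dict(:t1 => [1.0], :t2 => [2.0])
+assemble_loss(ŷ, t -> [1.5], t -> trues(1), [:t1, :t2], :mae)  # two-element vector
+```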
"" : _loss_name(loss_spec.data_loss) + if isempty(loss_spec.physics_loss) + return Symbol(data_name) + end + num_physics = length(loss_spec.physics_loss) + return Symbol(data_name, "_plus_", num_physics, "_physics") +end -# Arguments: -- `ŷ`: Predicted values. -- `y`: Target values. -- `y_nan`: Mask for NaN values. -- `targets`: The targets for which the loss is computed. -- `training_loss`: The loss type to use during training, e.g., `:mse`. -- `loss_types::Vector`: A vector of loss types to compute, e.g., `[:mse, :mae]`. -- `agg::Function`: The aggregation function to apply to the computed losses, e.g., `sum` or `mean`. +""" + _loss_name(loss_spec::Symbol|Function|Tuple) -Returns a single loss value if `training_loss` is provided, or a NamedTuple of losses for each type in `loss_types`. +Helper function to generate a meaningful name for a loss specification """ -function compute_loss end \ No newline at end of file +function _loss_name end \ No newline at end of file diff --git a/src/utils/loss_fn.jl b/src/utils/loss_fn.jl index a7bc152b..dd34de6d 100644 --- a/src/utils/loss_fn.jl +++ b/src/utils/loss_fn.jl @@ -41,6 +41,13 @@ loss = loss_fn(ŷ, y, y_nan, (scaled_loss, (scale=2.0,))) # With both args and kwargs complex_loss(ŷ, y, w; scale=1.0) = scale * w * mean(abs2, ŷ .- y) loss = loss_fn(ŷ, y, y_nan, (complex_loss, (0.5,), (scale=2.0,))) + +# Generic function: accepts y either as Array or as (y, σ) +function uncertainty_loss(ŷ, y_y_σ::Tuple) + y, y_σ = y_y_σ + return mean(abs2, (ŷ .- y) ./ (y_σ .^2 .+ 1e-6)) +end +loss = loss_fn(ŷ, y_y_σ::Tuple, y_nan, uncertainty_loss) ``` You can define additional predefined loss functions by adding more methods: @@ -53,6 +60,24 @@ end """ function loss_fn end +function _mask_y(y, y_nan) + return y[y_nan] +end +function _mask_y(y_tuple::Tuple, y_nan) + yvals, y_σ = y_tuple + y_σ_nan = y_σ isa AbstractArray ? y_σ[y_nan] : y_σ + return (yvals[y_nan], y_σ_nan) +end + +""" + _mask_y(y, y_nan) + _mask_y((y, y_σ), y_nan) + +Helper function to mask target values based on NaN mask. Handles both Array and Tuple (y, y_σ) formats. +""" +function _mask_y end + + function loss_fn(ŷ, y, y_nan, ::Val{:rmse}) return sqrt(mean(abs2, (ŷ[y_nan] .- y[y_nan]))) end @@ -76,19 +101,46 @@ function loss_fn(ŷ, y, y_nan, ::Val{:nse}) return sum((ŷ[y_nan] .- y[y_nan]).^2) / sum((y[y_nan] .- mean(y[y_nan])).^2) end +# Generic function: accepts y either as Array or as (y, σ) function loss_fn(ŷ, y, y_nan, training_loss::Function) - return training_loss(ŷ[y_nan], y[y_nan]) + masked_y = _mask_y(y, y_nan) + return training_loss(ŷ[y_nan], masked_y) end + function loss_fn(ŷ, y, y_nan, training_loss::Tuple{Function, Tuple}) f, args = training_loss - return f(ŷ[y_nan], y[y_nan], args...) + masked_y = _mask_y(y, y_nan) + return f(ŷ[y_nan], masked_y, args...) end function loss_fn(ŷ, y, y_nan, training_loss::Tuple{Function, NamedTuple}) f, kwargs = training_loss - return f(ŷ[y_nan], y[y_nan]; kwargs...) + masked_y = _mask_y(y, y_nan) + return f(ŷ[y_nan], masked_y; kwargs...) end + function loss_fn(ŷ, y, y_nan, training_loss::Tuple{Function, Tuple, NamedTuple}) f, args, kwargs = training_loss - return f(ŷ[y_nan], y[y_nan], args...; kwargs...) + masked_y = _mask_y(y, y_nan) + return f(ŷ[y_nan], masked_y, args...; kwargs...) +end + +function loss_fn(ŷ_all, y, y_nan, training_loss::LPPP{<:Function}) + f = training_loss.loss + return f(ŷ_all) +end + +function loss_fn(ŷ_all, y, y_nan, training_loss::LPPP{<:Tuple{Function, Tuple}}) + f, args = training_loss.loss + return f(ŷ_all, args...) 
+"""
+function _mask_y end
+
+
 function loss_fn(ŷ, y, y_nan, ::Val{:rmse})
     return sqrt(mean(abs2, (ŷ[y_nan] .- y[y_nan])))
 end
@@ -76,19 +101,46 @@ function loss_fn(ŷ, y, y_nan, ::Val{:nse})
     return sum((ŷ[y_nan] .- y[y_nan]).^2) / sum((y[y_nan] .- mean(y[y_nan])).^2)
 end
 
+# Generic function: accepts y either as Array or as (y, σ)
 function loss_fn(ŷ, y, y_nan, training_loss::Function)
-    return training_loss(ŷ[y_nan], y[y_nan])
+    masked_y = _mask_y(y, y_nan)
+    return training_loss(ŷ[y_nan], masked_y)
 end
+
 function loss_fn(ŷ, y, y_nan, training_loss::Tuple{Function, Tuple})
     f, args = training_loss
-    return f(ŷ[y_nan], y[y_nan], args...)
+    masked_y = _mask_y(y, y_nan)
+    return f(ŷ[y_nan], masked_y, args...)
 end
 
 function loss_fn(ŷ, y, y_nan, training_loss::Tuple{Function, NamedTuple})
     f, kwargs = training_loss
-    return f(ŷ[y_nan], y[y_nan]; kwargs...)
+    masked_y = _mask_y(y, y_nan)
+    return f(ŷ[y_nan], masked_y; kwargs...)
 end
+
 function loss_fn(ŷ, y, y_nan, training_loss::Tuple{Function, Tuple, NamedTuple})
     f, args, kwargs = training_loss
-    return f(ŷ[y_nan], y[y_nan], args...; kwargs...)
+    masked_y = _mask_y(y, y_nan)
+    return f(ŷ[y_nan], masked_y, args...; kwargs...)
+end
+
+# Physics losses receive the full prediction container, unmasked
+function loss_fn(ŷ_all, y, y_nan, training_loss::LPPP{<:Function})
+    f = training_loss.loss
+    return f(ŷ_all)
+end
+
+function loss_fn(ŷ_all, y, y_nan, training_loss::LPPP{<:Tuple{Function, Tuple}})
+    f, args = training_loss.loss
+    return f(ŷ_all, args...)
+end
+
+function loss_fn(ŷ_all, y, y_nan, training_loss::LPPP{<:Tuple{Function, NamedTuple}})
+    f, kwargs = training_loss.loss
+    return f(ŷ_all; kwargs...)
+end
+
+function loss_fn(ŷ_all, y, y_nan, training_loss::LPPP{<:Tuple{Function, Tuple, NamedTuple}})
+    f, args, kwargs = training_loss.loss
+    return f(ŷ_all, args...; kwargs...)
 end
\ No newline at end of file
diff --git a/test/test_logging_loss.jl b/test/test_logging_loss.jl
index 5ecc9475..125a43fb 100644
--- a/test/test_logging_loss.jl
+++ b/test/test_logging_loss.jl
@@ -1,5 +1,5 @@
 using Test
-using EasyHybrid
+using EasyHybrid: LoggingLoss, PerTarget, LossSum, Physics
 using Statistics
 using DimensionalData
 import EasyHybrid: compute_loss
@@ -86,6 +86,7 @@ end
     # Test data setup
     ŷ = Dict(:var1 => [1.0, 2.0, 3.0], :var2 => [2.0, 3.0, 4.0])
    y(target) = target == :var1 ? [1.1, 1.9, 3.2] : [1.8, 3.1, 3.9]
+    y_sigma(target) = target == :var1 ? [0.1, 0.2, 0.1] : [0.2, 0.1, 0.2]
     y_nan(target) = trues(3)
     targets = [:var1, :var2]
 
@@ -121,6 +122,36 @@ end
     complex_loss(ŷ, y, w; scale=1.0) = scale * w * mean(abs2, ŷ .- y)
     loss = compute_loss(ŷ, y, y_nan, targets, (complex_loss, (0.5,), (scale=2.0,)), sum)
     @test loss isa Number
+
+    # custom loss with uncertainty
+    function custom_loss_uncertainty(ŷ, y_and_sigma)
+        y_vals, y_σ = y_and_sigma
+        return mean(((ŷ .- y_vals).^2) ./ (y_σ .^2 .+ 1e-6))
+    end
+    loss = compute_loss(ŷ, (y, y_sigma), y_nan, targets, custom_loss_uncertainty, sum)
+    @test loss isa Number
+    # a single sigma number
+    loss = compute_loss(ŷ, (y, 0.01), y_nan, targets, custom_loss_uncertainty, sum)
+    # @test loss isa Number
+    losses = compute_loss(ŷ, (y, y_sigma), y_nan, targets, [custom_loss_uncertainty,], sum)
+    @test losses isa NamedTuple
+
+    @testset "Per-target losses" begin
+        # Mix of predefined and custom; the mask is all-true, so expected
+        # values can be computed on the unmasked vectors directly
+        loss_spec = PerTarget((:mse, custom_loss))
+        loss = compute_loss(ŷ, y, y_nan, targets, loss_spec, sum)
+        expected_loss = EasyHybrid.loss_fn(ŷ[:var1], y(:var1), y_nan(:var1), Val(:mse)) + custom_loss(ŷ[:var2], y(:var2))
+        @test loss ≈ expected_loss
+
+        # Mix of custom losses with arguments
+        loss_spec_args = PerTarget(((weighted_loss, (0.5,)), (scaled_loss, (scale=2.0,))))
+        loss_args = compute_loss(ŷ, y, y_nan, targets, loss_spec_args, sum)
+        expected_loss_args = weighted_loss(ŷ[:var1], y(:var1), 0.5) + scaled_loss(ŷ[:var2], y(:var2); scale=2.0)
+        @test loss_args ≈ expected_loss_args
+
+        # Mismatched number of losses and targets
+        @test_throws AssertionError compute_loss(ŷ, y, y_nan, targets, PerTarget((:mse,)), sum)
+    end
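+
+    @testset "Physics + data loss" begin
+        # a sketch exercising the `+` overloads; `Physics` wraps a penalty
+        # that sees the full prediction container once per batch
+        phys = Physics(ŷ_all -> 0.1 * sum(abs2, ŷ_all[:var1]))
+        loss_spec = :mse + phys
+        loss = compute_loss(ŷ, y, y_nan, targets, loss_spec, sum)
+        expected = EasyHybrid.loss_fn(ŷ[:var1], y(:var1), y_nan(:var1), Val(:mse)) +
+                   EasyHybrid.loss_fn(ŷ[:var2], y(:var2), y_nan(:var2), Val(:mse)) +
+                   0.1 * sum(abs2, ŷ[:var1])
+        @test loss ≈ expected
+    end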
"Predefined loss functions" begin # RMSE test @test loss_fn(ŷ, y, y_nan, Val(:rmse)) ≈ sqrt(mean(abs2, ŷ .- y))