yax #157 (Draft)

4 changes: 4 additions & 0 deletions projects/RbQ10/Project.toml
@@ -3,4 +3,8 @@
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
YAXArrays = "c21b50f5-aa40-41ea-b809-c0f5e47bfa5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[sources]
YAXArrays = {rev = "main", url = "https://github.com/JuliaDataCubes/YAXArrays.jl"}
24 changes: 22 additions & 2 deletions projects/RbQ10/Q10_dd.jl
@@ -37,7 +37,7 @@
grads = backtrace(l)[1]
# TODO: test DimArray inputs
using DimensionalData, ChainRulesCore
# mat = Matrix(df)'
mat = Array(Matrix(df)')
mat = Float32.(Array(Matrix(df)'))
da = DimArray(mat, (Dim{:col}(Symbol.(names(df))), Dim{:row}(1:size(df,1))))

##! new dispatch
@@ -91,7 +91,7 @@
ar = rand(3,3)
A = DimArray(ar, (Y([:a,:b,:c]), X(1:3)));
grad = Zygote.gradient(x -> sum(x[Y=At(:b)]), A)

xy = EasyHybrid.split_data((ds_p_f, ds_t), 0.8, shuffle=true, rng=Random.default_rng())
# xy = EasyHybrid.split_data((ds_p_f, ds_t), 0.8, shuffle=true, rng=Random.default_rng())

EasyHybrid.get_prediction_target_names(RbQ10)

@@ -100,3 +100,23 @@
xy1 = EasyHybrid.prepare_data(RbQ10, da)
(x_train, y_train), (x_val, y_val) = EasyHybrid.split_data(da, RbQ10) # ; shuffleobs=false, split_data_at=0.8

out = train(RbQ10, da, (:Q10, ); nepochs=200, batchsize=512, opt=Adam(0.01));

using YAXArrays
axDims = dims(da)

ds_yax = YAXArray(axDims, da.data)

ds_p_f = ds_yax[col=At(forcing_names ∪ predictor_names)]
ds_t = ds_yax[col=At(target_names)]
ds_t_nan = .!isnan.(ds_t) # produces 1×35064 YAXArray{Float32, 2}, not a Bool
Collaborator: This has to be done with map, otherwise the result is not a Bool array.

ds_t_nan = map(x -> !isnan(x), ds_t) # 1×35064 YAXArray{Bool, 2}
length(ds_t_nan)
# is_no_nan = .!isnan.(y)


ls = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss())
Collaborator: This runs now.

Member Author: Nice, yes, it does indeed seem to be the NaN mask. Let's use the most generic version so that all input cases work.
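A possible shape for that "most generic version" (an illustrative sketch, not part of this diff): map preserves the container type and always yields eltype Bool, so a single helper should cover plain Arrays, DimArrays, and YAXArrays alike.

not_nan_mask(y) = map(x -> !isnan(x), y)  # hypothetical helper name; behaviour as reported in this PR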

ls_logs = EasyHybrid.lossfn(RbQ10, ds_p_f, (ds_t, ds_t_nan), ps, st, LoggingLoss(train_mode=false))
acc_ = EasyHybrid.evaluate_acc(RbQ10, ds_p_f, ds_t, ds_t_nan, ps, st, [:mse, :r2], :mse, sum)

# TODO how to proceed - would it work already for multiple targets?
out_yax = train(RbQ10, ds_yax, (:Q10, ); nepochs=200, batchsize=512, opt=Adam(0.01));
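To make the map-vs-broadcast point from the comments above concrete, here is a minimal self-contained sketch. The toy data and names are illustrative, and the eltypes noted in the comments are the ones reported in this PR with YAXArrays#main, not guaranteed across versions.

using YAXArrays, DimensionalData

vals = Float32[1.0 NaN 3.0]                        # toy 1×3 target with one NaN
yax = YAXArray((Dim{:row}(1:1), Dim{:col}(1:3)), vals)

mask_bc = .!isnan.(yax)                            # reported above: stays a Float32-eltype YAXArray
mask_map = map(x -> !isnan(x), yax)                # reported above: YAXArray with eltype Bool

eltype(mask_bc), eltype(mask_map)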
10 changes: 7 additions & 3 deletions src/train.jl
@@ -122,8 +122,10 @@
function train(hybridModel, data, save_ps;
opt_state = Optimisers.setup(opt, ps)

# ? initial losses
is_no_nan_t = .!isnan.(y_train)
is_no_nan_v = .!isnan.(y_val)
# is_no_nan_t = .!isnan.(y_train)
is_no_nan_t = map(x -> !isnan(x), y_train)
# is_no_nan_v = .!isnan.(y_val)
is_no_nan_v = map(x -> !isnan(x), y_val)

l_init_train, _, init_ŷ_train = evaluate_acc(hybridModel, x_train, y_train, is_no_nan_t, ps, st, loss_types, training_loss, agg)
l_init_val, _, init_ŷ_val = evaluate_acc(hybridModel, x_val, y_val, is_no_nan_v, ps, st, loss_types, training_loss, agg)
@@ -194,7 +196,9 @@
for epoch in 1:nepochs
for epoch in 1:nepochs
for (x, y) in train_loader
# ? check NaN indices before going forward, and pass filtered `x, y`.
is_no_nan = .!isnan.(y)
# is_no_nan = .!isnan.(y)
is_no_nan = map(x -> !isnan(x), y) # doing this due to YAXArray Bool issue

if length(is_no_nan)>0 # ! be careful here, multivariate needs fine tuning
l, backtrace = Zygote.pullback((ps) -> lossfn(hybridModel, x, (y, is_no_nan), ps, st,
LoggingLoss(training_loss=training_loss, agg=agg)), ps)
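For context, the pullback pattern used in the loop above, reduced to a minimal self-contained form (the toy loss and parameters are illustrative, not from this code base):

using Zygote

toy_loss(ps) = sum(abs2, ps.w)            # stand-in for lossfn(hybridModel, ...)
ps = (w = [1.0, 2.0],)
l, back = Zygote.pullback(toy_loss, ps)   # loss value plus a closure for the gradient
grads = back(one(l))[1]                   # gradient w.r.t. ps, here (w = [2.0, 4.0],)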
12 changes: 11 additions & 1 deletion src/utils/loss_fn.jl
@@ -57,7 +57,17 @@
function loss_fn(ŷ, y, y_nan, ::Val{:rmse})
return sqrt(mean(abs2, (ŷ[y_nan] .- y[y_nan])))
end
function loss_fn(ŷ, y, y_nan, ::Val{:mse})
return mean(abs2, (ŷ[y_nan] .- y[y_nan]))
# Option 1: Convert to Array and compute MSE
Collaborator: Option 1 would be converting to an Array, but what I am not sure about is when it would use a view, when it copies, and when the data actually comes into memory.

#yh = Array(ŷ[y_nan])
#yt = Array(y[y_nan])
#return mean(abs2, yh .- yt)

# Option 2: Use the YAXArray directly, but map has to be used instead of broadcasting
return mean(x -> x, map((a,b)->(a-b)^2, ŷ[y_nan], y[y_nan]))
Collaborator: Broadcasting does not work here either; map works, or we convert it to an Array. Not quite sure how we should proceed from here. Should our model give a DimArray back? I'm not at all a YAXArrays / DimensionalData expert ;-)

Member Author: In principle all types should work, so YAXArrays should be fine.


# Option 3 gives an error:
#return mean(abs2, (ŷ[y_nan] .- y[y_nan]))
# errors with: ERROR: MethodError: no method matching to_yax(::Vector{Float32})
# The function `to_yax` exists, but no method is defined for this combination of argument types.
# I guess our model output would need to be a yax and not a Vector{Float32}.
end
function loss_fn(ŷ, y, y_nan, ::Val{:mae})
return mean(abs, (ŷ[y_nan] .- y[y_nan]))
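Following up on the Option 1 / Option 2 discussion above, a sketch of what a fully materializing :mse method could look like if Option 1 were chosen. The function name is illustrative, and whether Array copies, views, or loads lazily backed data is exactly the open question raised in the comments.

using Statistics

function mse_materialized(ŷ, y, y_nan)
    yh = Array(ŷ[y_nan])        # masked predictions as a plain Vector
    yt = Array(y[y_nan])        # masked targets as a plain Vector
    return mean(abs2, yh .- yt)
end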