Skip to content

Commit 7e9d2b2

Browse files
fix (#137) missing problem in forward
1 parent 85af0b3 commit 7e9d2b2

File tree

3 files changed

+26
-6
lines changed

3 files changed

+26
-6
lines changed

projects/Respiration_Fluxnet/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,6 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
66
EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
77
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
88
OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
9+
RetryManagers = "7a8cfeaa-437c-4519-b75e-06295593f538"
910
TidierPlots = "337ecbd1-5042-4e2a-ae6f-ca776f97570a"
1011
WGLMakie = "276b4fcb-3e11-5398-bf8b-a0c2d153d008"

projects/Respiration_Fluxnet/script.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@
55
using Pkg
66
project_path = "projects/Respiration_Fluxnet"
77
Pkg.activate(project_path)
8-
9-
#Pkg.develop(path=pwd())
8+
Pkg.develop(path=pwd())
109
#Pkg.instantiate()
1110

1211
# start using the package
@@ -25,6 +24,8 @@ include("Data/load_data.jl")
2524

2625
site = "US-SRG"
2726

27+
fluxnet_data = load_fluxnet_nc(joinpath(project_path, "Data", "data20240123", "$site.nc"), timevar="date")
28+
2829
# explore data structure
2930
println(names(fluxnet_data.timeseries))
3031
println(fluxnet_data.scalars)
@@ -77,11 +78,10 @@ parameters = (
7778
target_FluxPartModel = [:NEE]
7879
forcing_FluxPartModel = [:SW_IN, :TA]
7980

80-
predictors = (Rb = [:SWC_shallow, :P, :WS],
81+
predictors = (Rb = [:SWC_shallow, :P, :WS, :cos_dayofyear, :sin_dayofyear],
8182
RUE = [:TA, :P, :WS, :SWC_shallow, :VPD, :SW_IN_POT, :dSW_IN_POT, :dSW_IN_POT_DAY])
8283

8384
global_param_names = [:Q10]
84-
8585
hybrid_model = constructHybridModel(
8686
predictors,
8787
forcing_FluxPartModel,
@@ -91,7 +91,7 @@ hybrid_model = constructHybridModel(
9191
global_param_names,
9292
scale_nn_outputs=true,
9393
hidden_layers = [32, 32],
94-
activation = tanh,
94+
activation = sigmoid,
9595
input_batchnorm = true,
9696
start_from_default = false
9797
)
@@ -178,7 +178,7 @@ using TidierPlots
178178
using WGLMakie
179179
beautiful_makie_theme = Attributes(fonts=(;regular="CMU Serif"))
180180

181-
ggplot(forward_run, aes(x=:GPP_NT, y=:GPP_pred)) + geom_point() + beautiful_makie_theme
181+
ggplot(forward_run, @aes(x=GPP_NT, y=GPP_pred)) + geom_point() + beautiful_makie_theme
182182

183183
idx = .!isnan.(forward_run.GPP_NT) .& .!isnan.(forward_run.GPP_pred)
184184
EasyHybrid.poplot(forward_run.GPP_NT[idx], forward_run.GPP_pred[idx], "GPP", xlabel = "Nighttime GPP", ylabel = "Hybrid GPP")

src/models/GenericHybridModel.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,15 @@ end
382382
function (m::SingleNNHybridModel)(df::DataFrame, ps, st)
383383
@warn "Only makes sense in test mode, not training!"
384384

385+
386+
# Process numeric or missing-containing columns
387+
for col in names(df)
388+
what_type = eltype(df[!, col])
389+
if what_type <: Union{Missing, Real} || what_type <: Real
390+
df[!, col] = Float64.(coalesce.(df[!, col], NaN))
391+
end
392+
end
393+
385394
all_data = to_keyedArray(df)
386395
x, _ = prepare_data(m, all_data)
387396
out, _ = m(x, ps, LuxCore.testmode(st))
@@ -470,8 +479,18 @@ end
470479
function (m::MultiNNHybridModel)(df::DataFrame, ps, st)
471480
@warn "Only makes sense in test mode, not training!"
472481

482+
# Process numeric or missing-containing columns
483+
for col in names(df)
484+
what_type = eltype(df[!, col])
485+
if what_type <: Union{Missing, Real} || what_type <: Real
486+
df[!, col] = Float64.(coalesce.(df[!, col], NaN))
487+
end
488+
end
489+
473490
all_data = to_keyedArray(df)
491+
474492
x, _ = prepare_data(m, all_data)
493+
@show typeof(x)
475494
out, _ = m(x, ps, LuxCore.testmode(st))
476495
dfnew = copy(df)
477496
for k in keys(out)

0 commit comments

Comments
 (0)