@@ -206,17 +206,33 @@ function eval_grad_tree_array(
206
206
variable:: Union{Bool,Val} = Val (false ),
207
207
turbo:: Union{Bool,Val} = Val (false ),
208
208
) where {T<: Number }
209
- n_gradients = if isa (variable, Val{true }) || (isa (variable, Bool) && variable)
209
+ variable_mode = isa (variable, Val{true }) || (isa (variable, Bool) && variable)
210
+ constant_mode = isa (variable, Val{false }) || (isa (variable, Bool) && ! variable)
211
+ both_mode = isa (variable, Val{:both })
212
+
213
+ n_gradients = if variable_mode
210
214
size (cX, 1 ):: Int
211
- else
215
+ elseif constant_mode
212
216
count_constants (tree):: Int
217
+ elseif both_mode
218
+ size (cX, 1 ) + count_constants (tree)
213
219
end
214
- result = if isa (variable, Val{true }) || (variable isa Bool && variable)
220
+
221
+ result = if variable_mode
215
222
eval_grad_tree_array (tree, n_gradients, nothing , cX, operators, Val (true ))
216
- else
223
+ elseif constant_mode
217
224
index_tree = index_constants (tree)
218
- eval_grad_tree_array (tree, n_gradients, index_tree, cX, operators, Val (false ))
219
- end
225
+ eval_grad_tree_array (
226
+ tree, n_gradients, index_tree, cX, operators, Val (false )
227
+ )
228
+ elseif both_mode
229
+ # features come first because we can use size(cX, 1) to skip them
230
+ index_tree = index_constants (tree)
231
+ eval_grad_tree_array (
232
+ tree, n_gradients, index_tree, cX, operators, Val (:both )
233
+ )
234
+ end :: ResultOk2
235
+
220
236
return (result. x, result. dx, result. ok)
221
237
end
222
238
@@ -226,11 +242,9 @@ function eval_grad_tree_array(
226
242
index_tree:: Union{NodeIndex,Nothing} ,
227
243
cX:: AbstractMatrix{T} ,
228
244
operators:: OperatorEnum ,
229
- :: Val{variable} ,
230
- ):: ResultOk2 where {T<: Number ,variable}
231
- result = _eval_grad_tree_array (
232
- tree, n_gradients, index_tree, cX, operators, Val (variable)
233
- )
245
+ :: Val{mode} ,
246
+ ):: ResultOk2 where {T<: Number ,mode}
247
+ result = _eval_grad_tree_array (tree, n_gradients, index_tree, cX, operators, Val (mode))
234
248
! result. ok && return result
235
249
return ResultOk2 (
236
250
result. x, result. dx, ! (is_bad_array (result. x) || is_bad_array (result. dx))
@@ -260,30 +274,18 @@ end
260
274
index_tree:: Union{NodeIndex,Nothing} ,
261
275
cX:: AbstractMatrix{T} ,
262
276
operators:: OperatorEnum ,
263
- :: Val{variable } ,
264
- ):: ResultOk2 where {T<: Number ,variable }
277
+ :: Val{mode } ,
278
+ ):: ResultOk2 where {T<: Number ,mode }
265
279
nuna = get_nuna (operators)
266
280
nbin = get_nbin (operators)
267
281
deg1_branch_skeleton = quote
268
282
grad_deg1_eval (
269
- tree,
270
- n_gradients,
271
- index_tree,
272
- cX,
273
- operators. unaops[i],
274
- operators,
275
- Val (variable),
283
+ tree, n_gradients, index_tree, cX, operators. unaops[i], operators, Val (mode)
276
284
)
277
285
end
278
286
deg2_branch_skeleton = quote
279
287
grad_deg2_eval (
280
- tree,
281
- n_gradients,
282
- index_tree,
283
- cX,
284
- operators. binops[i],
285
- operators,
286
- Val (variable),
288
+ tree, n_gradients, index_tree, cX, operators. binops[i], operators, Val (mode)
287
289
)
288
290
end
289
291
deg1_branch = if nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
310
312
end
311
313
quote
312
314
if tree. degree == 0
313
- grad_deg0_eval (tree, n_gradients, index_tree, cX, Val (variable ))
315
+ grad_deg0_eval (tree, n_gradients, index_tree, cX, Val (mode ))
314
316
elseif tree. degree == 1
315
317
$ deg1_branch
316
318
else
@@ -324,8 +326,8 @@ function grad_deg0_eval(
324
326
n_gradients,
325
327
index_tree:: Union{NodeIndex,Nothing} ,
326
328
cX:: AbstractMatrix{T} ,
327
- :: Val{variable } ,
328
- ):: ResultOk2 where {T<: Number ,variable }
329
+ :: Val{mode } ,
330
+ ):: ResultOk2 where {T<: Number ,mode }
329
331
const_part = deg0_eval (tree, cX). x
330
332
331
333
zero_mat = if isa (cX, Array)
@@ -334,17 +336,26 @@ function grad_deg0_eval(
334
336
hcat ([fill_similar (zero (T), cX, axes (cX, 2 )) for _ in 1 : n_gradients]. .. )'
335
337
end
336
338
337
- if variable == tree. constant
339
+ if (mode isa Bool && mode == tree. constant)
340
+ # No gradients at this leaf node
338
341
return ResultOk2 (const_part, zero_mat, true )
339
342
end
340
343
341
- index = if variable
342
- tree. feature
343
- else
344
+ index = if (mode isa Bool && mode)
345
+ tree. feature:: UInt16
346
+ elseif (mode isa Bool && ! mode)
344
347
(index_tree === nothing ? zero (UInt16) : index_tree. val:: UInt16 )
348
+ elseif mode == :both
349
+ index_tree:: NodeIndex
350
+ if tree. constant
351
+ index_tree. val:: UInt16 + UInt16 (size (cX, 1 ))
352
+ else
353
+ tree. feature:: UInt16
354
+ end
345
355
end
356
+
346
357
derivative_part = zero_mat
347
- derivative_part[index, :] . = one (T)
358
+ fill! ( @view ( derivative_part[index, :]), one (T) )
348
359
return ResultOk2 (const_part, derivative_part, true )
349
360
end
350
361
@@ -355,15 +366,15 @@ function grad_deg1_eval(
355
366
cX:: AbstractMatrix{T} ,
356
367
op:: F ,
357
368
operators:: OperatorEnum ,
358
- :: Val{variable } ,
359
- ):: ResultOk2 where {T<: Number ,F,variable }
369
+ :: Val{mode } ,
370
+ ):: ResultOk2 where {T<: Number ,F,mode }
360
371
result = eval_grad_tree_array (
361
372
tree. l,
362
373
n_gradients,
363
374
index_tree === nothing ? index_tree : index_tree. l,
364
375
cX,
365
376
operators,
366
- Val (variable ),
377
+ Val (mode ),
367
378
)
368
379
! result. ok && return result
369
380
@@ -389,15 +400,15 @@ function grad_deg2_eval(
389
400
cX:: AbstractMatrix{T} ,
390
401
op:: F ,
391
402
operators:: OperatorEnum ,
392
- :: Val{variable } ,
393
- ):: ResultOk2 where {T<: Number ,F,variable }
403
+ :: Val{mode } ,
404
+ ):: ResultOk2 where {T<: Number ,F,mode }
394
405
result_l = eval_grad_tree_array (
395
406
tree. l,
396
407
n_gradients,
397
408
index_tree === nothing ? index_tree : index_tree. l,
398
409
cX,
399
410
operators,
400
- Val (variable ),
411
+ Val (mode ),
401
412
)
402
413
! result_l. ok && return result_l
403
414
result_r = eval_grad_tree_array (
@@ -406,7 +417,7 @@ function grad_deg2_eval(
406
417
index_tree === nothing ? index_tree : index_tree. r,
407
418
cX,
408
419
operators,
409
- Val (variable ),
420
+ Val (mode ),
410
421
)
411
422
! result_r. ok && return result_r
412
423
0 commit comments