Skip to content

Commit 9812009

Browse files
committed
refactor: assume get_op_name gives string not vector
1 parent 730dcf0 commit 9812009

File tree

3 files changed

+16
-17
lines changed

3 files changed

+16
-17
lines changed

src/Strings.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ function dispatch_op_name(::Val{deg}, ::Nothing, idx)::Vector{Char} where {deg}
1313
end
1414
function dispatch_op_name(::Val{deg}, operators::AbstractOperatorEnum, idx) where {deg}
1515
if deg == 1
16-
return get_op_name(operators.unaops[idx])::Vector{Char}
16+
return collect(get_op_name(operators.unaops[idx])::String)
1717
else
18-
return get_op_name(operators.binops[idx])::Vector{Char}
18+
return collect(get_op_name(operators.binops[idx])::String)
1919
end
2020
end
2121

22-
const OP_NAME_CACHE = (; x=Dict{UInt64,Vector{Char}}(), lock=Threads.SpinLock())
22+
const OP_NAME_CACHE = (; x=Dict{UInt64,String}(), lock=Threads.SpinLock())
2323

2424
function get_op_name(op)
2525
h = hash(op)
@@ -29,18 +29,17 @@ function get_op_name(op)
2929
if haskey(cache, h)
3030
return cache[h]
3131
end
32-
op_s = sizehint!(Char[], 10)
33-
if op isa Broadcast.BroadcastFunction
34-
append!(op_s, string(op.f))
35-
if length(op_s) == 1 && first(op_s) in ('+', '-', '*', '/', '^')
32+
op_s = if op isa Broadcast.BroadcastFunction
33+
base_op_s = string(op.f)
34+
if length(base_op_s) == 1 && first(base_op_s) in ('+', '-', '*', '/', '^')
3635
# Like `.+`
37-
pushfirst!(op_s, '.')
36+
string('.', base_op_s)
3837
else
3938
# Like `cos.`
40-
push!(op_s, '.')
39+
string(base_op_s, '.')
4140
end
4241
else
43-
append!(op_s, string(op))
42+
string(op)
4443
end
4544
cache[h] = op_s
4645
return op_s

test/test_params.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ maximum_residual = 1e-2
2828

2929
custom_cos(x) = cos(x)^2
3030

31-
DE.get_op_name(::typeof(safe_log)) = ['l', 'o', 'g']
32-
DE.get_op_name(::typeof(safe_log2)) = ['l', 'o', 'g', '2']
33-
DE.get_op_name(::typeof(safe_log10)) = ['l', 'o', 'g', '1', '0']
34-
DE.get_op_name(::typeof(safe_log1p)) = ['l', 'o', 'g', '1', 'p']
35-
DE.get_op_name(::typeof(safe_acosh)) = ['a', 'c', 'o', 's', 'h']
36-
DE.get_op_name(::typeof(safe_sqrt)) = ['s', 'q', 'r', 't']
31+
DE.get_op_name(::typeof(safe_log)) = "log"
32+
DE.get_op_name(::typeof(safe_log2)) = "log2"
33+
DE.get_op_name(::typeof(safe_log10)) = "log10"
34+
DE.get_op_name(::typeof(safe_log1p)) = "log1p"
35+
DE.get_op_name(::typeof(safe_acosh)) = "acosh"
36+
DE.get_op_name(::typeof(safe_sqrt)) = "sqrt"
3737
end
3838

3939
HEADER_GUARD_TEST_PARAMS = true

test/test_print.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ end
3535

3636
@isdefined(safe_pow) || @eval begin
3737
safe_pow(x::T, y::T) where {T<:Number} = (x < 0 && y != round(y)) ? T(NaN) : x^y
38-
DE.get_op_name(::typeof(safe_pow)) = ['^']
38+
DE.get_op_name(::typeof(safe_pow)) = "^"
3939
end
4040
for binop in [safe_pow, ^]
4141
opts = OperatorEnum(;

0 commit comments

Comments
 (0)