Skip to content

Commit 5f07cf3

Browse files
Support CUDA on Julia 1.9+ via a package extension.
1 parent dfb67d6 commit 5f07cf3

File tree

5 files changed

+233
-8
lines changed

5 files changed

+233
-8
lines changed

Project.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,24 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1616
[compat]
1717
ArgCheck = "2"
1818
CEnum = "0.4"
19+
CUDA = "4, 5"
1920
DataStructures = "0.18"
2021
DocStringExtensions = "0.8, 0.9"
2122
Requires = "1"
23+
cuDNN = "1.1"
2224
julia = "1.6"
2325

26+
[extensions]
27+
CUDAExt = ["CUDA", "cuDNN"]
28+
2429
[extras]
2530
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
31+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
32+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
2633

2734
[targets]
2835
test = ["Test"]
36+
37+
[weakdeps]
38+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
39+
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

ext/CUDAExt.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
module CUDAExt
2+
3+
# These functions are only defined for diagnostic purposes. Otherwise
4+
# the CUDA extension only relies on the CUDA and cuDNN dependencies to
5+
# have loaded the libraries needed by ONNXRunTime's CUDA execution
6+
# provider.
7+
import CUDA
8+
cuda_functional() = CUDA.functional()
9+
cuda_runtime_version() = CUDA.runtime_version()
10+
11+
end

src/ONNXRunTime.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module ONNXRunTime
2-
using Requires:@require
2+
if !isdefined(Base, :get_extension)
3+
using Requires: @require
4+
end
35

46
function _perm(arr::AbstractArray{T,N}) where {T,N}
57
ntuple(i->N+1-i, N)
@@ -14,9 +16,11 @@ end
1416
include("capi.jl")
1517
include("highlevel.jl")
1618

17-
function __init__()
18-
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
19-
CUDA.functional() && include("cuda.jl")
19+
@static if !isdefined(Base, :get_extension)
20+
function __init__()
21+
@require CUDA="052768ef-5323-5732-b1bb-66c8b64840ba" begin
22+
CUDA.functional() && include("cuda.jl")
23+
end
2024
end
2125
end
2226

src/highlevel.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,31 @@ function load_inference(path::AbstractString; execution_provider::Symbol=:cpu,
6565
if execution_provider === :cpu
6666
session_options = CreateSessionOptions(api)
6767
elseif execution_provider === :cuda
68-
if !(isdefined(@__MODULE__, :CUDA))
69-
@warn """
70-
The $(repr(execution_provider)) requires the CUDA.jl package to be available. Try adding `import CUDA` to your code.
71-
"""
68+
if isdefined(Base, :get_extension)
69+
CUDAExt = Base.get_extension(@__MODULE__, :CUDAExt)
70+
if isnothing(CUDAExt)
71+
@warn """
72+
The $(repr(execution_provider)) execution provider requires the CUDA.jl and cuDNN.jl packages to be available. Try adding `import CUDA, cuDNN` to your code.
73+
"""
74+
elseif !getfield(CUDAExt, :cuda_functional)()
75+
@warn """
76+
The $(repr(execution_provider)) execution provider requires CUDA to be functional. See `CUDA.functional`.
77+
"""
78+
elseif !(v"11.8" <= getfield(CUDAExt, :cuda_runtime_version)() < v"12")
79+
# Note: The supported version range is a property
80+
# inherited from the CUDA runtime library and needs to
81+
# be updated when the library is updated. It may be a
82+
# good idea to centralize this information somewhere.
83+
@warn """
84+
The $(repr(execution_provider)) execution provider requires a CUDA runtime version of at least 11.8 but less than 12. See `CUDA.set_runtime_version!`.
85+
"""
86+
end
87+
else
88+
if !isdefined(@__MODULE__, :CUDA)
89+
@warn """
90+
The $(repr(execution_provider)) execution provider requires the CUDA.jl package to be available. Try adding `import CUDA` to your code.
91+
"""
92+
end
7293
end
7394
session_options = CreateSessionOptions(api)
7495
cuda_options = OrtCUDAProviderOptions()

test/test_cuda_extension.jl

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
# This file is not included from `runtests.jl` nor run in CI.
2+
#
3+
# Run it with `julia tests/test_cuda_extension.jl`. This requires that
4+
# Julia is installed with juliaup and will involve downloading of a
5+
# lot of big artifacts. The output will contain lots of error messages
6+
# from caught errors; what matters is that all testsets pass.
7+
8+
using Test
9+
10+
juliaup_found = false
11+
try run(pipeline(`juliaup --version`, stdout = devnull, stderr = devnull))
12+
global juliaup_found = true
13+
catch e
14+
end
15+
16+
if !juliaup_found
17+
error("`juliaup` needs to be installed for the CUDA extension tests")
18+
end
19+
20+
wait(run(`juliaup add 1.6`, wait = false))
21+
wait(run(`juliaup add 1.9`, wait = false))
22+
23+
package_path = dirname(@__DIR__)
24+
onnx_path = joinpath(@__DIR__, "data", "copy2d.onnx")
25+
26+
function with_environment(f::Function; cuda_runtime_version)
27+
mktempdir() do env
28+
write(joinpath(env, "LocalPreferences.toml"),
29+
"""
30+
[CUDA_Runtime_jll]
31+
version = "$(cuda_runtime_version)"
32+
""")
33+
write(joinpath(env, "Project.toml"),
34+
"""
35+
[extras]
36+
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"
37+
""")
38+
f(env)
39+
end
40+
end
41+
42+
@testset "Julia 1.6 CUDA 3" begin
43+
with_environment(cuda_runtime_version = "11.8") do env
44+
install_script = """
45+
using Pkg
46+
Pkg.develop(path = "$(package_path)")
47+
Pkg.add(name = "CUDA", version = "3")
48+
"""
49+
@test success(run(`julia +1.6 --project=$(env) -e "$(install_script)"`))
50+
# Correct dependency for :cuda.
51+
test_script = """
52+
using ONNXRunTime, CUDA
53+
load_inference("$(onnx_path)", execution_provider = :cuda)
54+
"""
55+
@test success(run(`julia +1.6 --project=$(env) -e "$(test_script)"`))
56+
# CUDA not loaded.
57+
test_script = """
58+
using ONNXRunTime
59+
load_inference("$(onnx_path)", execution_provider = :cuda)
60+
"""
61+
@test_throws ProcessFailedException run(`julia +1.6 --project=$(env) -e "$(test_script)"`)
62+
# CUDA not loaded but running on CPU, so it's fine.
63+
test_script = """
64+
using ONNXRunTime
65+
load_inference("$(onnx_path)", execution_provider = :cpu)
66+
"""
67+
@test success(run(`julia +1.6 --project=$(env) -e "$(test_script)"`))
68+
end
69+
end
70+
71+
@testset "Julia 1.9 CUDA 3" begin
72+
with_environment(cuda_runtime_version = "11.8") do env
73+
install_script = """
74+
using Pkg
75+
Pkg.develop(path = "$(package_path)")
76+
Pkg.add(name = "CUDA", version = "3")
77+
"""
78+
# CUDA 3 is not possible to install together with ONNXRunTime
79+
# on Julia 1.9 due to Compat requirements.
80+
@test_throws ProcessFailedException run(`julia +1.9 --project=$(env) -e "$(install_script)"`)
81+
end
82+
end
83+
84+
@testset "Julia 1.9 CUDA.jl $(cuda_version) CUDA runtime 11.8" for cuda_version in (4, 5)
85+
with_environment(cuda_runtime_version = "11.8") do env
86+
install_script = """
87+
using Pkg
88+
Pkg.develop(path = "$(package_path)")
89+
Pkg.add(name = "CUDA", version = "$(cuda_version)")
90+
Pkg.add(name = "cuDNN")
91+
"""
92+
@test success(run(`julia +1.9 --project=$(env) -e "$(install_script)"`))
93+
# Correct dependencies for :cuda.
94+
test_script = """
95+
using ONNXRunTime, CUDA, cuDNN
96+
load_inference("$(onnx_path)", execution_provider = :cuda)
97+
"""
98+
@test success(run(`julia +1.9 --project=$(env) -e "$(test_script)"`))
99+
# Neither CUDA nor cuDNN loaded.
100+
test_script = """
101+
using ONNXRunTime
102+
load_inference("$(onnx_path)", execution_provider = :cuda)
103+
"""
104+
@test_throws ProcessFailedException run(`julia +1.9 --project=$(env) -e "$(test_script)"`)
105+
# Neither CUDA nor cuDNN loaded but running on CPU, so it's fine.
106+
test_script = """
107+
using ONNXRunTime
108+
load_inference("$(onnx_path)", execution_provider = :cpu)
109+
"""
110+
# CUDA not loaded. Well, cuDNN pulls in CUDA so this passes anyway.
111+
test_script = """
112+
using ONNXRunTime
113+
using cuDNN
114+
load_inference("$(onnx_path)", execution_provider = :cuda)
115+
"""
116+
@test success(run(`julia +1.9 --project=$(env) -e "$(test_script)"`))
117+
# CUDA not loaded but running on CPU, so it's fine.
118+
test_script = """
119+
using ONNXRunTime
120+
using cuDNN
121+
load_inference("$(onnx_path)", execution_provider = :cpu)
122+
"""
123+
@test success(run(`julia +1.9 --project=$(env) -e "$(test_script)"`))
124+
# cuDNN not loaded.
125+
test_script = """
126+
using ONNXRunTime
127+
using CUDA
128+
load_inference("$(onnx_path)", execution_provider = :cuda)
129+
"""
130+
@test_throws ProcessFailedException run(`julia +1.9 --project=$(env) -e "$(test_script)"`)
131+
# cuDNN not loaded but running on CPU, so it's fine.
132+
test_script = """
133+
using ONNXRunTime
134+
using CUDA
135+
load_inference("$(onnx_path)", execution_provider = :cpu)
136+
"""
137+
@test success(run(`julia +1.9 --project=$(env) -e "$(test_script)"`))
138+
end
139+
end
140+
141+
@testset "Julia 1.9 CUDA.jl $(cuda_version) CUDA runtime 11.6" for cuda_version in (4, 5)
142+
with_environment(cuda_runtime_version = "11.6") do env
143+
install_script = """
144+
using Pkg
145+
Pkg.develop(path = "$(package_path)")
146+
Pkg.add(name = "CUDA", version = "$(cuda_version)")
147+
Pkg.add(name = "cuDNN")
148+
"""
149+
@test success(run(`julia +1.9 --project=$(env) -e "$(install_script)"`))
150+
# Correct dependencies for :cuda. CUDA runtime version is
151+
# lower than officially supported but close enough to at least
152+
# load so there will be a warning but no error.
153+
test_script = """
154+
using ONNXRunTime, CUDA, cuDNN
155+
load_inference("$(onnx_path)", execution_provider = :cuda)
156+
"""
157+
@test success(run(`julia +1.9 --project=$(env) -e "$(test_script)"`))
158+
end
159+
end
160+
161+
@testset "Julia 1.9 CUDA.jl $(cuda_version) CUDA runtime 12.1" for cuda_version in (4, 5)
162+
with_environment(cuda_runtime_version = "12.1") do env
163+
install_script = """
164+
using Pkg
165+
Pkg.develop(path = "$(package_path)")
166+
Pkg.add(name = "CUDA", version = "$(cuda_version)")
167+
Pkg.add(name = "cuDNN")
168+
"""
169+
@test success(run(`julia +1.9 --project=$(env) -e "$(install_script)"`))
170+
# Correct dependencies for :cuda but fails due to bad version
171+
# of CUDA runtime.
172+
test_script = """
173+
using ONNXRunTime, CUDA, cuDNN
174+
load_inference("$(onnx_path)", execution_provider = :cuda)
175+
"""
176+
@test_throws ProcessFailedException run(`julia +1.9 --project=$(env) -e "$(test_script)"`)
177+
end
178+
end

0 commit comments

Comments
 (0)