Skip to content

Commit a43a150

Browse files
committed
[pocl] RNG support
Closes #641. Testing locally, I am running into #624.
1 parent d052112 commit a43a150

File tree

9 files changed

+612
-15
lines changed

9 files changed

+612
-15
lines changed

Project.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KernelAbstractions"
22
uuid = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
3-
authors = ["Valentin Churavy <v.churavy@gmail.com> and contributors"]
43
version = "0.10.0-dev"
4+
authors = ["Valentin Churavy <v.churavy@gmail.com> and contributors"]
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -13,6 +13,9 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
1313
OpenCL_jll = "6cb37087-e8b6-5417-8430-1f242f1e46e4"
1414
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1515
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
16+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
17+
Random123 = "74087812-796a-5b5d-8853-05524746bad3"
18+
RandomNumbers = "e6cf234a-135c-5ec9-84dd-332b85af5143"
1619
SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8"
1720
SPIRV_LLVM_Backend_jll = "4376b9bf-cff8-51b6-bb48-39421dff0d0c"
1821
SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
@@ -40,6 +43,9 @@ LLVM = "9.4.1"
4043
LinearAlgebra = "1.6"
4144
MacroTools = "0.5"
4245
PrecompileTools = "1"
46+
Random = "1.11.0"
47+
Random123 = "1.7.1"
48+
RandomNumbers = "1.6.0"
4349
SPIRVIntrinsics = "0.5"
4450
SPIRV_LLVM_Backend_jll = "20"
4551
SPIRV_Tools_jll = "2024.4, 2025.1"

src/pocl/compiler/compilation.jl

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,87 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
1919
in(fn, known_intrinsics) ||
2020
contains(fn, "__spirv_")
2121

22+
GPUCompiler.kernel_state_type(::OpenCLCompilerJob) = KernelState
23+
24+
function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
25+
mod::LLVM.Module, entry::LLVM.Function)
26+
entry = invoke(GPUCompiler.finish_module!,
27+
Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
28+
job, mod, entry)
29+
30+
# if this kernel uses our RNG, we should prime the shared state.
31+
# XXX: these transformations should really happen at the Julia IR level...
32+
if haskey(functions(mod), "julia.opencl.random_keys") && job.config.kernel
33+
# insert call to `initialize_rng_state`
34+
f = initialize_rng_state
35+
ft = typeof(f)
36+
tt = Tuple{}
37+
38+
# create a deferred compilation job for `initialize_rng_state`
39+
src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
40+
cfg = CompilerConfig(job.config; kernel=false, name=nothing)
41+
job = CompilerJob(src, cfg, job.world)
42+
id = length(GPUCompiler.deferred_codegen_jobs) + 1
43+
GPUCompiler.deferred_codegen_jobs[id] = job
44+
45+
# generate IR for calls to `deferred_codegen` and the resulting function pointer
46+
top_bb = first(blocks(entry))
47+
bb = BasicBlock(top_bb, "initialize_rng")
48+
@dispose builder=IRBuilder() begin
49+
position!(builder, bb)
50+
subprogram = LLVM.subprogram(entry)
51+
if subprogram !== nothing
52+
loc = DILocation(0, 0, subprogram)
53+
debuglocation!(builder, loc)
54+
end
55+
debuglocation!(builder, first(instructions(top_bb)))
56+
57+
# call the `deferred_codegen` marker function
58+
T_ptr = if LLVM.version() >= v"17"
59+
LLVM.PointerType()
60+
elseif VERSION >= v"1.12.0-DEV.225"
61+
LLVM.PointerType(LLVM.Int8Type())
62+
else
63+
LLVM.Int64Type()
64+
end
65+
T_id = convert(LLVMType, Int)
66+
deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_id])
67+
deferred_codegen = if haskey(functions(mod), "deferred_codegen")
68+
functions(mod)["deferred_codegen"]
69+
else
70+
LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
71+
end
72+
fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])
73+
74+
# call the `initialize_rng_state` function
75+
rt = Core.Compiler.return_type(f, tt)
76+
llvm_rt = convert(LLVMType, rt)
77+
llvm_ft = LLVM.FunctionType(llvm_rt)
78+
fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
79+
call!(builder, llvm_ft, fptr)
80+
br!(builder, top_bb)
81+
82+
# note the use of the device-side RNG in this kernel
83+
push!(function_attributes(entry), StringAttribute("julia.opencl.rng", ""))
84+
end
85+
86+
# XXX: put some of the above behind GPUCompiler abstractions
87+
# (e.g., a compile-time version of `deferred_codegen`)
88+
end
89+
return entry
90+
end
91+
92+
function GPUCompiler.finish_linked_module!(@nospecialize(job::OpenCLCompilerJob), mod::LLVM.Module)
93+
for f in GPUCompiler.kernels(mod)
94+
kernel_intrinsics = Dict(
95+
"julia.opencl.random_keys" => (; name = "random_keys", typ = LLVMPtr{UInt32, AS.Workgroup}),
96+
"julia.opencl.random_counters" => (; name = "random_counters", typ = LLVMPtr{UInt32, AS.Workgroup}),
97+
)
98+
GPUCompiler.add_input_arguments!(job, mod, f, kernel_intrinsics)
99+
end
100+
return
101+
end
102+
22103

23104
## compiler implementation (cache, configure, compile, and link)
24105

@@ -60,10 +141,13 @@ end
60141
function compile(@nospecialize(job::CompilerJob))
61142
# TODO: this creates a context; cache those.
62143
obj, meta = JuliaContext() do ctx
63-
GPUCompiler.compile(:obj, job)
64-
end
144+
obj, meta = GPUCompiler.compile(:obj, job)
65145

66-
return (; obj, entry = LLVM.name(meta.entry))
146+
entry = LLVM.name(meta.entry)
147+
device_rng = StringAttribute("julia.opencl.rng", "") in collect(function_attributes(meta.entry))
148+
149+
(; obj, entry, device_rng)
150+
end
67151
end
68152

69153
# link into an executable kernel
@@ -74,5 +158,5 @@ function link(@nospecialize(job::CompilerJob), compiled)
74158
error("Your device does not support SPIR-V, which is currently required for native execution.")
75159
end
76160
cl.build!(prog)
77-
return cl.Kernel(prog, compiled.entry)
161+
(; kernel=cl.Kernel(prog, compiled.entry), compiled.device_rng)
78162
end

src/pocl/compiler/execution.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ end
146146

147147
abstract type AbstractKernel{F, TT} end
148148

149+
pass_arg(@nospecialize dt) = !(GPUCompiler.isghosttype(dt) || Core.Compiler.isconstType(dt))
150+
149151
@inline @generated function (kernel::AbstractKernel{F, TT})(
150152
args...;
151153
call_kwargs...
@@ -154,8 +156,7 @@ abstract type AbstractKernel{F, TT} end
154156
args = (:(kernel.f), (:(clconvert(args[$i], svm_pointers)) for i in 1:length(args))...)
155157

156158
# filter out ghost arguments that shouldn't be passed
157-
predicate = dt -> GPUCompiler.isghosttype(dt) || Core.Compiler.isconstType(dt)
158-
to_pass = map(!predicate, sig.parameters)
159+
to_pass = map(pass_arg, sig.parameters)
159160
call_t = Type[x[1] for x in zip(sig.parameters, to_pass) if x[2]]
160161
call_args = Union{Expr, Symbol}[x[1] for x in zip(args, to_pass) if x[2]]
161162

@@ -167,12 +168,15 @@ abstract type AbstractKernel{F, TT} end
167168
end
168169
end
169170

171+
pushfirst!(call_t, KernelState)
172+
pushfirst!(call_args, :(KernelState(kernel.rng_state ? Base.rand(UInt32) : UInt32(0))))
173+
170174
# finalize types
171175
call_tt = Base.to_tuple_type(call_t)
172176

173177
return quote
174178
svm_pointers = Ptr{Cvoid}[]
175-
$cl.clcall(kernel.fun, $call_tt, $(call_args...); svm_pointers, call_kwargs...)
179+
$cl.clcall(kernel.fun, $call_tt, $(call_args...); svm_pointers, kernel.rng_state, call_kwargs...)
176180
end
177181
end
178182

@@ -182,6 +186,7 @@ end
182186
struct HostKernel{F, TT} <: AbstractKernel{F, TT}
183187
f::F
184188
fun::cl.Kernel
189+
rng_state::Bool
185190
end
186191

187192

@@ -198,15 +203,15 @@ function clfunction(f::F, tt::TT = Tuple{}; kwargs...) where {F, TT}
198203
cache = compiler_cache(ctx)
199204
source = methodinstance(F, tt)
200205
config = compiler_config(dev; kwargs...)::OpenCLCompilerConfig
201-
fun = GPUCompiler.cached_compilation(cache, source, config, compile, link)
206+
linked = GPUCompiler.cached_compilation(cache, source, config, compile, link)
202207

203208
# create a callable object that captures the function instance. we don't need to think
204209
# about world age here, as GPUCompiler already does and will return a different object
205-
h = hash(fun, hash(f, hash(tt)))
210+
h = hash(linked.kernel, hash(f, hash(tt)))
206211
kernel = get(_kernel_instances, h, nothing)
207212
if kernel === nothing
208213
# create the kernel state object
209-
kernel = HostKernel{F, tt}(f, fun)
214+
kernel = HostKernel{F, tt}(f, linked.kernel, linked.device_rng)
210215
_kernel_instances[h] = kernel
211216
end
212217
return kernel::HostKernel{F, tt}

0 commit comments

Comments
 (0)