@@ -19,6 +19,87 @@ GPUCompiler.isintrinsic(job::OpenCLCompilerJob, fn::String) =
1919 in (fn, known_intrinsics) ||
2020 contains (fn, " __spirv_" )
2121
# The OpenCL backend passes a `KernelState` as hidden state to every kernel;
# GPUCompiler queries this hook to know the state's type.
GPUCompiler.kernel_state_type(::OpenCLCompilerJob) = KernelState
23+
"""
    GPUCompiler.finish_module!(job::OpenCLCompilerJob, mod, entry)

Run the generic SPIR-V module finalization, then — if the module references the
device-side RNG placeholder (`julia.opencl.random_keys`) and we are compiling a
kernel — inject a prologue block into `entry` that calls `initialize_rng_state`
through a deferred-compilation function pointer, and tag the kernel with the
`julia.opencl.rng` attribute so the host side knows to seed the RNG.

Returns the (possibly replaced) entry function.
"""
function GPUCompiler.finish_module!(@nospecialize(job::OpenCLCompilerJob),
                                    mod::LLVM.Module, entry::LLVM.Function)
    # first apply the parent target's (SPIR-V) module finalization
    entry = invoke(GPUCompiler.finish_module!,
                   Tuple{CompilerJob{SPIRVCompilerTarget}, LLVM.Module, LLVM.Function},
                   job, mod, entry)

    # if this kernel uses our RNG, we should prime the shared state.
    # XXX: these transformations should really happen at the Julia IR level...
    if haskey(functions(mod), "julia.opencl.random_keys") && job.config.kernel
        # insert call to `initialize_rng_state`
        f = initialize_rng_state
        ft = typeof(f)
        tt = Tuple{}

        # create a deferred compilation job for `initialize_rng_state`
        src = methodinstance(ft, tt, GPUCompiler.tls_world_age())
        cfg = CompilerConfig(job.config; kernel=false, name=nothing)
        job = CompilerJob(src, cfg, job.world)
        id = length(GPUCompiler.deferred_codegen_jobs) + 1
        GPUCompiler.deferred_codegen_jobs[id] = job

        # generate IR for calls to `deferred_codegen` and the resulting function pointer;
        # the new block is spliced in front of the current entry block.
        top_bb = first(blocks(entry))
        bb = BasicBlock(top_bb, "initialize_rng")
        @dispose builder=IRBuilder() begin
            position!(builder, bb)
            # give the injected instructions a debug location so the verifier
            # accepts them when the entry function carries a subprogram
            subprogram = LLVM.subprogram(entry)
            if subprogram !== nothing
                loc = DILocation(0, 0, subprogram)
                debuglocation!(builder, loc)
            end
            # NOTE(review): this re-sets the location from the original first
            # instruction, overriding the DILocation above when present — looks
            # intentional as a fallback/refinement, but confirm the ordering.
            debuglocation!(builder, first(instructions(top_bb)))

            # call the `deferred_codegen` marker function; its return type is an
            # opaque pointer on LLVM 17+, a typed i8* on Julia 1.12+, and a
            # plain i64 (to be inttoptr'd) before that
            T_ptr = if LLVM.version() >= v"17"
                LLVM.PointerType()
            elseif VERSION >= v"1.12.0-DEV.225"
                LLVM.PointerType(LLVM.Int8Type())
            else
                LLVM.Int64Type()
            end
            T_id = convert(LLVMType, Int)
            deferred_codegen_ft = LLVM.FunctionType(T_ptr, [T_id])
            deferred_codegen = if haskey(functions(mod), "deferred_codegen")
                functions(mod)["deferred_codegen"]
            else
                LLVM.Function(mod, "deferred_codegen", deferred_codegen_ft)
            end
            fptr = call!(builder, deferred_codegen_ft, deferred_codegen, [ConstantInt(id)])

            # call the `initialize_rng_state` function through the pointer
            rt = Core.Compiler.return_type(f, tt)
            llvm_rt = convert(LLVMType, rt)
            llvm_ft = LLVM.FunctionType(llvm_rt)
            fptr = inttoptr!(builder, fptr, LLVM.PointerType(llvm_ft))
            call!(builder, llvm_ft, fptr)
            br!(builder, top_bb)

            # note the use of the device-side RNG in this kernel
            push!(function_attributes(entry), StringAttribute("julia.opencl.rng", ""))
        end

        # XXX: put some of the above behind GPUCompiler abstractions
        #      (e.g., a compile-time version of `deferred_codegen`)
    end
    return entry
end
91+
"""
    GPUCompiler.finish_linked_module!(job::OpenCLCompilerJob, mod)

After linking, rewrite every kernel in `mod` so that the RNG placeholder
intrinsics (`julia.opencl.random_keys`/`julia.opencl.random_counters`) become
hidden workgroup-local pointer arguments, via `GPUCompiler.add_input_arguments!`.
"""
function GPUCompiler.finish_linked_module!(@nospecialize(job::OpenCLCompilerJob), mod::LLVM.Module)
    # placeholder intrinsic => (argument name, argument type); loop-invariant,
    # so build it once instead of once per kernel
    kernel_intrinsics = Dict(
        "julia.opencl.random_keys" => (; name = "random_keys", typ = LLVMPtr{UInt32, AS.Workgroup}),
        "julia.opencl.random_counters" => (; name = "random_counters", typ = LLVMPtr{UInt32, AS.Workgroup}),
    )
    for f in GPUCompiler.kernels(mod)
        GPUCompiler.add_input_arguments!(job, mod, f, kernel_intrinsics)
    end
    return
end
102+
22103
## compiler implementation (cache, configure, compile, and link)
24105
"""
    compile(job::CompilerJob)

Compile `job` to a SPIR-V object and return a named tuple with:
- `obj`: the compiled object,
- `entry`: the name of the entry function,
- `device_rng`: whether the kernel was tagged with the `julia.opencl.rng`
  attribute (so the host must seed the device-side RNG before launch).
"""
function compile(@nospecialize(job::CompilerJob))
    # TODO: this creates a context; cache those.
    # NOTE: the do-block's named tuple is the function's return value; do not
    # destructure it here (a stale `obj, meta = ...` would bind `meta` to the
    # `entry` field and mislead readers).
    JuliaContext() do ctx
        obj, meta = GPUCompiler.compile(:obj, job)

        entry = LLVM.name(meta.entry)
        device_rng = StringAttribute("julia.opencl.rng", "") in
                     collect(function_attributes(meta.entry))

        (; obj, entry, device_rng)
    end
end
68152
# link into an executable kernel
@@ -74,5 +158,5 @@ function link(@nospecialize(job::CompilerJob), compiled)
74158 error (" Your device does not support SPIR-V, which is currently required for native execution." )
75159 end
76160 cl. build! (prog)
77- return cl. Kernel (prog, compiled. entry)
161+ (; kernel = cl. Kernel (prog, compiled. entry), compiled . device_rng )
78162end
0 commit comments