Skip to content

Commit 986f167

Browse files
committed
Use a function decorator macro to transition back into GC unsafe domain.
1 parent a82b963 commit 986f167

8 files changed

Lines changed: 58 additions & 64 deletions

File tree

lib/cublas/CUBLAS.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,7 @@ function log_message(ptr)
210210
return
211211
end
212212

213-
function _log_message(blob)
214-
# see @gcsafe_ccall documentation
215-
@static if VERSION < v"1.9"
216-
GC.safepoint()
217-
end
218-
213+
@gcunsafe_callback function _log_message(blob)
219214
# the message format isn't documented, but it looks like a message starts with a capital
220215
# and the severity (e.g. `I!`), and subsequent lines start with a lowercase mark (`!i`)
221216
#

lib/cudadrv/occupancy.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,7 @@ end
3636
# HACK: callback function for `launch_configuration` on platforms without support for
3737
# trampolines as used by `@cfunction` (JuliaLang/julia#27174, JuliaLang/julia#32154)
3838
_shmem_cb = nothing
39-
function _shmem_cint_cb(x::Cint)
40-
# see @gcsafe_ccall documentation
41-
@static if VERSION < v"1.9"
42-
GC.safepoint()
43-
end
39+
@gcunsafe_callback function _shmem_cint_cb(x::Cint)
4440
Cint(something(_shmem_cb)(x))
4541
end
4642
_shmem_cb_lock = Threads.ReentrantLock()
@@ -64,11 +60,7 @@ function launch_configuration(fun::CuFunction; shmem::Union{Integer,Base.Callabl
6460
if isa(shmem, Integer)
6561
cuOccupancyMaxPotentialBlockSize(blocks_ref, threads_ref, fun, C_NULL, shmem, max_threads)
6662
elseif Sys.ARCH == :x86 || Sys.ARCH == :x86_64
67-
function shmem_cint(threads)
68-
# see @gcsafe_ccall documentation
69-
@static if VERSION < v"1.9"
70-
GC.safepoint()
71-
end
63+
@gcunsafe_callback function shmem_cint(threads)
7264
Cint(shmem(threads))
7365
end
7466
cb = @cfunction($shmem_cint, Cint, (Cint,))

lib/cudnn/src/cuDNN.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,7 @@ function log_message(sev, udata, dbg_ptr, ptr)
137137
return
138138
end
139139

140-
function _log_message(sev, dbg, str)
141-
# see @gcsafe_ccall documentation
142-
@static if VERSION < v"1.9"
143-
GC.safepoint()
144-
end
145-
140+
@gcunsafe_callback function _log_message(sev, dbg, str)
146141
lines = split(str, '\0')
147142
msg = join(lines, '\n')
148143
if sev == CUDNN_SEV_INFO

lib/cupti/wrappers.jl

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,8 @@ end
1212
# multiple subscribers aren't supported, so make sure we only call CUPTI once
1313
const callback_lock = ReentrantLock()
1414

15-
function callback(userdata::Ptr{Cvoid}, domain::CUpti_CallbackDomain,
16-
id::CUpti_CallbackId, data_ptr::Ptr{Cvoid})
17-
# see @gcsafe_ccall documentation
18-
@static if VERSION < v"1.9"
19-
GC.safepoint()
20-
end
21-
15+
@gcunsafe_callback function callback(userdata::Ptr{Cvoid}, domain::CUpti_CallbackDomain,
16+
id::CUpti_CallbackId, data_ptr::Ptr{Cvoid})
2217
cfg = Base.unsafe_pointer_to_objref(userdata)::CallbackConfig
2318

2419
# decode the callback data
@@ -131,15 +126,10 @@ end
131126
const activity_lock = ReentrantLock()
132127
const activity_config = Ref{Union{Nothing,ActivityConfig}}(nothing)
133128

134-
function request_buffer(dest_ptr, sz_ptr, max_num_records_ptr)
129+
@gcunsafe_callback function request_buffer(dest_ptr, sz_ptr, max_num_records_ptr)
135130
# this function is called by CUPTI, but directly from the application, so it should be
136131
# fine to perform I/O or allocate memory here.
137132

138-
# see @gcsafe_ccall documentation
139-
@static if VERSION < v"1.9"
140-
GC.safepoint()
141-
end
142-
143133
dest = Base.unsafe_wrap(Array, dest_ptr, 1)
144134
sz = Base.unsafe_wrap(Array, sz_ptr, 1)
145135
max_num_records = Base.unsafe_wrap(Array, max_num_records_ptr, 1)
@@ -167,7 +157,7 @@ function request_buffer(dest_ptr, sz_ptr, max_num_records_ptr)
167157
return
168158
end
169159

170-
function complete_buffer(ctx_handle, stream_id, buf_ptr, sz, valid_sz)
160+
@gcunsafe_callback function complete_buffer(ctx_handle, stream_id, buf_ptr, sz, valid_sz)
171161
# this function is called by a CUPTI worker thread while our application may be waiting
172162
# for `cuptiActivityFlushAll` to complete. that means we cannot do I/O here, or we could
173163
# yield while the application cannot make any progress.
@@ -176,11 +166,6 @@ function complete_buffer(ctx_handle, stream_id, buf_ptr, sz, valid_sz)
176166
# to prevent this, we call `sizehint!` in `request_buffer`.
177167
# XXX: `sizehint!` isn't a guarantee; use `resize!` and a cursor?
178168

179-
# see @gcsafe_ccall documentation
180-
@static if VERSION < v"1.9"
181-
GC.safepoint()
182-
end
183-
184169
cfg = activity_config[]
185170
if cfg !== nothing
186171
push!(cfg.results, (ctx_handle, stream_id, buf_ptr, sz, valid_sz))

lib/custatevec/src/cuStateVec.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,7 @@ end
8787

8888
## logging
8989

90-
function log_message(log_level, function_name, message)
91-
# see @gcsafe_ccall documentation
92-
@static if VERSION < v"1.9"
93-
GC.safepoint()
94-
end
95-
90+
@gcunsafe_callback function log_message(log_level, function_name, message)
9691
function_name = unsafe_string(function_name)
9792
message = unsafe_string(message)
9893
output = if isempty(message)

lib/cutensor/src/cuTENSOR.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,12 +70,7 @@ end
7070

7171
## logging
7272

73-
function log_message(log_level, function_name, message)
74-
# see @gcsafe_ccall documentation
75-
@static if VERSION < v"1.9"
76-
GC.safepoint()
77-
end
78-
73+
@gcunsafe_callback function log_message(log_level, function_name, message)
7974
function_name = unsafe_string(function_name)
8075
message = unsafe_string(message)
8176
output = if isempty(message)

lib/cutensornet/src/cuTensorNet.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,7 @@ end
8888

8989
## logging
9090

91-
function log_message(log_level, function_name, message)
92-
# see @gcsafe_ccall documentation
93-
@static if VERSION < v"1.9"
94-
GC.safepoint()
95-
end
96-
91+
@gcunsafe_callback function log_message(log_level, function_name, message)
9792
function_name = unsafe_string(function_name)
9893
message = unsafe_string(message)
9994
output = if isempty(message)

lib/utils/call.jl

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# utilities for calling foreign functionality more conveniently
22

3-
export @checked, with_workspace, with_workspaces, @debug_ccall, @gcsafe_ccall
3+
export @checked, with_workspace, with_workspaces,
4+
@debug_ccall, @gcsafe_ccall, @gcunsafe_callback
45

56

67
## function wrapper for checking the return value of a function
@@ -166,10 +167,8 @@ render_arg(io, arg::Union{<:Base.RefValue, AbstractArray}) = summary(io, arg)
166167
# TODO: replace with JuliaLang/julia#49933 once merged
167168

168169
# note that this is generally only safe with functions that do not call back into Julia.
169-
# when callbacks occur, the code should ensure the GC is not running:
170-
# - on 1.10 and later, everything is fine because of safepoint_on_entry
171-
# - on 1.9, @cfunction-based callbacks are fine because they transition to gc_unsafe
172-
# - on 1.8 and earlier, the code should explicitly call GC.safepoint()!
170+
# when callbacks occur, the code should ensure the GC is not running by wrapping the code
171+
# in the `@gcunsafe_callback` macro
173172

174173
function ccall_macro_lower(func, rettype, types, args, nreq)
175174
# instead of re-using ccall or Expr(:foreigncall) to perform argument conversion,
@@ -213,7 +212,50 @@ function ccall_macro_lower(func, rettype, types, args, nreq)
213212
end
214213
end
215214

215+
"""
216+
@gcsafe_ccall ...
217+
218+
Call a foreign function just like `@ccall`, but marking it safe for the GC to run. This is
219+
useful for functions that may block, so that the GC isn't blocked from running, but may also
220+
be required to prevent deadlocks (see JuliaGPU/CUDA.jl#2261).
221+
222+
Note that this is generally only safe with non-Julia C functions that do not call back into
223+
Julia. When using callbacks, the code should make sure to transition back into GC-unsafe
224+
mode using the `@gcunsafe_callback` macro.
225+
"""
216226
macro gcsafe_ccall(expr)
217227
# reuse Base's `@ccall` parser, but lower the parsed call with `ccall_macro_lower`
ccall_macro_lower(Base.ccall_macro_parse(expr)...)
218228
end
219229

# Rewrite a long-form function definition so that its body runs between
# `jl_gc_unsafe_enter` and `jl_gc_unsafe_leave`. Returns the new, not yet escaped,
# function definition; throws an `ArgumentError` for unsupported input.
function gcunsafe_callback_lower(ex)
    Meta.isexpr(ex, :function) ||
        throw(ArgumentError("@gcunsafe_callback expects a long-form function definition"))
    sig = ex.args[1]
    Meta.isexpr(sig, :call) ||
        throw(ArgumentError("@gcunsafe_callback expects a plain call signature, got `$sig`"))
    body = ex.args[2]
    Meta.isexpr(body, :block) ||
        throw(ArgumentError("@gcunsafe_callback expects a block as the function body"))

    # the result is escaped as a whole, so use a unique name that cannot clash with any
    # local variable of the callback itself
    gc_state = gensym("gc_state")

    gcunsafe_body = quote
        $gc_state = @ccall(jl_gc_unsafe_enter()::Int8)
        try
            # splice in the original *body*: splicing the whole definition here would only
            # define a nested function and never execute the callback's code
            $body
        finally
            # always restore the previous GC state, also when the callback throws
            @ccall(jl_gc_unsafe_leave($gc_state::Int8)::Cvoid)
        end
    end

    return Expr(:function, sig, gcunsafe_body)
end

"""
    @gcunsafe_callback function callback(...)
        ...
    end

Mark a callback function as unsafe for the GC to run. This is normally the default for
Julia code, and is meant to be used in combination with `@gcsafe_ccall`.

On Julia 1.9 and later the definition is returned unchanged. On earlier versions the
function body is wrapped so that it enters GC-unsafe mode on entry and restores the
previous GC state on exit; there, only long-form definitions with a plain call signature
(`function f(args...) ... end`) are accepted.
"""
macro gcunsafe_callback(ex)
    if VERSION >= v"1.9"
        # on 1.9+, `@cfunction` already transitions to GC-unsafe mode
        return esc(ex)
    end

    return esc(gcunsafe_callback_lower(ex))
end

0 commit comments

Comments
 (0)