Skip to content

Commit 32ad9e6

Browse files
authored
change compiler to be stackless (#55575)
This change ensures the compiler uses very little stack, making it compatible with running on any arbitrary system stack size and depths much more reliably. It also could be further modified now to easily add various forms of pause-able/resumable inference, since there is no implicit state on the stack--everything is local and explicit now. Before, fewer than 900 frames would crash in less than a second: ``` $ time ./julia -e 'f(::Val{N}) where {N} = N <= 0 ? 0 : f(Val(N - 1)); f(Val(1000))' Warning: detected a stack overflow; program state may be corrupted, so further execution might be unreliable. Internal error: during type inference of f(Base.Val{1000}) Encountered stack overflow. This might be caused by recursion over very long tuples or argument lists. [23763] signal 6: Abort trap: 6 in expression starting at none:1 __pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line) Allocations: 1 (Pool: 1; Big: 0); GC: 0 Abort trap: 6 real 0m0.233s user 0m0.165s sys 0m0.049s ``` Now: it is effectively unlimited, as long as you are willing to wait for it: ``` $ time ./julia -e 'f(::Val{N}) where {N} = N <= 0 ? 0 : f(Val(N - 1)); f(Val(50000))' info: inference of f(Base.Val{50000}) from f(Base.Val{N}) where {N} exceeding 2500 frames (may be slow). info: inference of f(Base.Val{50000}) from f(Base.Val{N}) where {N} exceeding 5000 frames (may be slow). info: inference of f(Base.Val{50000}) from f(Base.Val{N}) where {N} exceeding 10000 frames (may be slow). info: inference of f(Base.Val{50000}) from f(Base.Val{N}) where {N} exceeding 20000 frames (may be slow). info: inference of f(Base.Val{50000}) from f(Base.Val{N}) where {N} exceeding 40000 frames (may be slow). real 7m4.988s $ time ./julia -e 'f(::Val{N}) where {N} = N <= 0 ? 0 : f(Val(N - 1)); f(Val(1000))' real 0m0.214s user 0m0.164s sys 0m0.044s $ time ./julia -e '@noinline f(::Val{N}) where {N} = N <= 0 ? 
GC.safepoint() : f(Val(N - 1)); f(Val(5000))' info: inference of f(Base.Val{5000}) from f(Base.Val{N}) where {N} exceeding 2500 frames (may be slow). info: inference of f(Base.Val{5000}) from f(Base.Val{N}) where {N} exceeding 5000 frames (may be slow). real 0m8.609s user 0m8.358s sys 0m0.240s ```
1 parent a7c5056 commit 32ad9e6

11 files changed

Lines changed: 1048 additions & 965 deletions

File tree

base/compiler/abstractinterpretation.jl

Lines changed: 782 additions & 620 deletions
Large diffs are not rendered by default.

base/compiler/inferencestate.jl

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,7 @@ mutable struct InferenceState
251251
stmt_info::Vector{CallInfo}
252252

253253
#= intermediate states for interprocedural abstract interpretation =#
254+
tasks::Vector{WorkThunk}
254255
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
255256
limitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on return
256257
cycle_backedges::Vector{Tuple{InferenceState, Int}} # call-graph backedges connecting from callee to caller
@@ -328,6 +329,7 @@ mutable struct InferenceState
328329
limitations = IdSet{InferenceState}()
329330
cycle_backedges = Vector{Tuple{InferenceState,Int}}()
330331
callstack = AbsIntState[]
332+
tasks = WorkThunk[]
331333

332334
valid_worlds = WorldRange(1, get_world_counter())
333335
bestguess = Bottom
@@ -351,7 +353,7 @@ mutable struct InferenceState
351353
this = new(
352354
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
353355
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
354-
pclimitations, limitations, cycle_backedges, callstack, 0, 0, 0,
356+
tasks, pclimitations, limitations, cycle_backedges, callstack, 0, 0, 0,
355357
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
356358
restrict_abstract_call_sites, cache_mode, insert_coverage,
357359
interp)
@@ -800,6 +802,7 @@ mutable struct IRInterpretationState
800802
const ssa_refined::BitSet
801803
const lazyreachability::LazyCFGReachability
802804
valid_worlds::WorldRange
805+
const tasks::Vector{WorkThunk}
803806
const edges::Vector{Any}
804807
callstack #::Vector{AbsIntState}
805808
frameid::Int
@@ -825,10 +828,11 @@ mutable struct IRInterpretationState
825828
ssa_refined = BitSet()
826829
lazyreachability = LazyCFGReachability(ir)
827830
valid_worlds = WorldRange(min_world, max_world == typemax(UInt) ? get_world_counter() : max_world)
831+
tasks = WorkThunk[]
828832
edges = Any[]
829833
callstack = AbsIntState[]
830834
return new(method_info, ir, mi, world, curridx, argtypes_refined, ir.sptypes, tpdum,
831-
ssa_refined, lazyreachability, valid_worlds, edges, callstack, 0, 0)
835+
ssa_refined, lazyreachability, valid_worlds, tasks, edges, callstack, 0, 0)
832836
end
833837
end
834838

@@ -870,6 +874,7 @@ function print_callstack(frame::AbsIntState)
870874
print(frame_instance(sv))
871875
is_cached(sv) || print(" [uncached]")
872876
sv.parentid == idx - 1 || print(" [parent=", sv.parentid, "]")
877+
isempty(callers_in_cycle(sv)) || print(" [cycle=", sv.cycleid, "]")
873878
println()
874879
@assert sv.frameid == idx
875880
end
@@ -994,7 +999,10 @@ of the same cycle, only if it is part of a cycle with multiple frames.
994999
function callers_in_cycle(sv::InferenceState)
9951000
callstack = sv.callstack::Vector{AbsIntState}
9961001
cycletop = cycleid = sv.cycleid
997-
while cycletop < length(callstack) && (callstack[cycletop + 1]::InferenceState).cycleid == cycleid
1002+
while cycletop < length(callstack)
1003+
frame = callstack[cycletop + 1]
1004+
frame isa InferenceState || break
1005+
frame.cycleid == cycleid || break
9981006
cycletop += 1
9991007
end
10001008
return AbsIntCycle(callstack, cycletop == cycleid ? 0 : cycleid, cycletop)
@@ -1054,6 +1062,7 @@ function merge_effects!(::AbstractInterpreter, caller::InferenceState, effects::
10541062
effects = Effects(effects; effect_free=ALWAYS_TRUE)
10551063
end
10561064
caller.ipo_effects = merge_effects(caller.ipo_effects, effects)
1065+
nothing
10571066
end
10581067
merge_effects!(::AbstractInterpreter, ::IRInterpretationState, ::Effects) = return
10591068

@@ -1116,3 +1125,90 @@ function get_max_methods_for_module(mod::Module)
11161125
max_methods < 0 && return nothing
11171126
return max_methods
11181127
end
1128+
1129+
"""
1130+
Future{T}
1131+
1132+
Delayed return value for a value of type `T`, similar to RefValue{T}, but
1133+
explicitly represents completion as a `Bool` rather than as `isdefined`.
1134+
Set once with `f[] = v` and accessed with `f[]` afterwards.
1135+
1136+
Can also be constructed with the `immediate` flag value and a closure to
1137+
produce `x`, as well as the additional arguments to avoid always capturing the
1138+
same couple of values.
1139+
"""
1140+
struct Future{T}
1141+
later::Union{Nothing,RefValue{T}}
1142+
now::Union{Nothing,T}
1143+
Future{T}() where {T} = new{T}(RefValue{T}(), nothing)
1144+
Future{T}(x) where {T} = new{T}(nothing, x)
1145+
Future(x::T) where {T} = new{T}(nothing, x)
1146+
end
1147+
isready(f::Future) = f.later === nothing
1148+
getindex(f::Future{T}) where {T} = (later = f.later; later === nothing ? f.now::T : later[])
1149+
setindex!(f::Future, v) = something(f.later)[] = v
1150+
convert(::Type{Future{T}}, x) where {T} = Future{T}(x) # support return type conversion
1151+
convert(::Type{Future{T}}, x::Future) where {T} = x::Future{T}
1152+
function Future{T}(f, immediate::Bool, interp::AbstractInterpreter, sv::AbsIntState) where {T}
1153+
if immediate
1154+
return Future{T}(f(interp, sv))
1155+
else
1156+
@assert applicable(f, interp, sv)
1157+
result = Future{T}()
1158+
push!(sv.tasks, function (interp, sv)
1159+
result[] = f(interp, sv)
1160+
return true
1161+
end)
1162+
return result
1163+
end
1164+
end
1165+
function Future{T}(f, prev::Future{S}, interp::AbstractInterpreter, sv::AbsIntState) where {T, S}
1166+
later = prev.later
1167+
if later === nothing
1168+
return Future{T}(f(prev[], interp, sv))
1169+
else
1170+
@assert Core._hasmethod(Tuple{Core.Typeof(f), S, typeof(interp), typeof(sv)})
1171+
result = Future{T}()
1172+
push!(sv.tasks, function (interp, sv)
1173+
result[] = f(later[], interp, sv) # capture just later, instead of all of prev
1174+
return true
1175+
end)
1176+
return result
1177+
end
1178+
end
1179+
1180+
1181+
"""
1182+
doworkloop(args...)
1183+
1184+
Run a task inside the abstract interpreter, returning `false` if there are none.
1185+
Tasks will be run in depth-first post-order, such that all child tasks will
1186+
be run in the order scheduled, prior to running any subsequent tasks. This
1187+
allows tasks to generate more child tasks, which will be run before anything else.
1188+
Each task will be run repeatedly when returning `false`, until it returns `true`.
1189+
"""
1190+
function doworkloop(interp::AbstractInterpreter, sv::AbsIntState)
1191+
tasks = sv.tasks
1192+
prev = length(tasks)
1193+
prev == 0 && return false
1194+
task = pop!(tasks)
1195+
completed = task(interp, sv)
1196+
tasks = sv.tasks # allow dropping gc root over the previous call
1197+
completed isa Bool || throw(TypeError(:return, "", Bool, task)) # print the task on failure as part of the error message, instead of just "@ workloop:line"
1198+
completed || push!(tasks, task)
1199+
# efficient post-order visitor: items pushed are executed in reverse post order such
1200+
# that later items are executed before earlier ones, but are fully executed
1201+
# (including any dependencies scheduled by them) before going on to the next item
1202+
reverse!(tasks, #=start=#prev)
1203+
return true
1204+
end
1205+
1206+
1207+
#macro workthunk(name::Symbol, body)
1208+
# name = esc(name)
1209+
# body = esc(body)
1210+
# return replace_linenums!(
1211+
# :(function $name($(esc(interp)), $(esc(sv)))
1212+
# $body
1213+
# end), __source__)
1214+
#end

base/compiler/ssair/ir.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1432,6 +1432,7 @@ function process_node!(compact::IncrementalCompact, result_idx::Int, inst::Instr
14321432
elseif isa(stmt, OldSSAValue)
14331433
ssa_rename[idx] = ssa_rename[stmt.id]
14341434
elseif isa(stmt, GotoNode) && cfg_transforms_enabled
1435+
stmt.label < 0 && (println(stmt); println(compact))
14351436
label = bb_rename_succ[stmt.label]
14361437
@assert label > 0
14371438
ssa_rename[idx] = SSAValue(result_idx)

base/compiler/ssair/irinterp.jl

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,11 @@ end
5151

5252
function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRInterpretationState)
5353
si = StmtInfo(true) # TODO better job here?
54-
call = abstract_call(interp, arginfo, si, irsv)
55-
irsv.ir.stmts[irsv.curridx][:info] = call.info
54+
call = abstract_call(interp, arginfo, si, irsv)::Future
55+
Future{Nothing}(call, interp, irsv) do call, interp, irsv
56+
irsv.ir.stmts[irsv.curridx][:info] = call.info
57+
nothing
58+
end
5659
return call
5760
end
5861

@@ -143,7 +146,19 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
143146
head = stmt.head
144147
if (head === :call || head === :foreigncall || head === :new || head === :splatnew ||
145148
head === :static_parameter || head === :isdefined || head === :boundscheck)
146-
(; rt, effects) = abstract_eval_statement_expr(interp, stmt, nothing, irsv)
149+
@assert isempty(irsv.tasks) # TODO: this whole function needs to be converted to a stackless design to be a valid AbsIntState, but this should work here for now
150+
result = abstract_eval_statement_expr(interp, stmt, nothing, irsv)
151+
reverse!(irsv.tasks)
152+
while true
153+
if length(irsv.callstack) > irsv.frameid
154+
typeinf(interp, irsv.callstack[irsv.frameid + 1])
155+
elseif !doworkloop(interp, irsv)
156+
break
157+
end
158+
end
159+
@assert length(irsv.callstack) == irsv.frameid && isempty(irsv.tasks)
160+
result isa Future && (result = result[])
161+
(; rt, effects) = result
147162
add_flag!(inst, flags_for_effects(effects))
148163
elseif head === :invoke
149164
rt, (nothrow, noub) = abstract_eval_invoke_inst(interp, inst, irsv)
@@ -293,7 +308,7 @@ function is_all_const_call(@nospecialize(stmt), interp::AbstractInterpreter, irs
293308
return true
294309
end
295310

296-
function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRInterpretationState;
311+
function ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRInterpretationState;
297312
externally_refined::Union{Nothing,BitSet} = nothing)
298313
(; ir, tpdum, ssa_refined) = irsv
299314

@@ -449,18 +464,3 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
449464

450465
return Pair{Any,Tuple{Bool,Bool}}(maybe_singleton_const(ultimate_rt), (nothrow, noub))
451466
end
452-
453-
function ir_abstract_constant_propagation(interp::NativeInterpreter, irsv::IRInterpretationState)
454-
if __measure_typeinf__[]
455-
inf_frame = Timings.InferenceFrameInfo(irsv.mi, irsv.world, VarState[], Any[], length(irsv.ir.argtypes))
456-
Timings.enter_new_timer(inf_frame)
457-
ret = _ir_abstract_constant_propagation(interp, irsv)
458-
append!(inf_frame.slottypes, irsv.ir.argtypes)
459-
Timings.exit_current_timer(inf_frame)
460-
return ret
461-
else
462-
return _ir_abstract_constant_propagation(interp, irsv)
463-
end
464-
end
465-
ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IRInterpretationState) =
466-
_ir_abstract_constant_propagation(interp, irsv)

base/compiler/ssair/verify.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

33
function maybe_show_ir(ir::IRCode)
4-
if isdefined(Core, :Main)
4+
if isdefined(Core, :Main) && isdefined(Core.Main, :Base)
55
# ensure we use I/O that does not yield, as this gets called during compilation
66
invokelatest(Core.Main.Base.show, Core.stdout, "text/plain", ir)
7+
else
8+
Core.show(ir)
79
end
810
end
911

@@ -25,6 +27,7 @@ is_toplevel_expr_head(head::Symbol) = head === :global || head === :method || he
2527
is_value_pos_expr_head(head::Symbol) = head === :static_parameter
2628
function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int, use_idx::Int, printed_use_idx::Int, print::Bool, isforeigncall::Bool, arg_idx::Int, allow_frontend_forms::Bool)
2729
if isa(op, SSAValue)
30+
op.id > 0 || @verify_error "Def ($(op.id)) is invalid in final IR"
2831
if op.id > length(ir.stmts)
2932
def_bb = block_for_inst(ir.cfg, ir.new_nodes.info[op.id - length(ir.stmts)].pos)
3033
else

0 commit comments

Comments
 (0)