Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
6cac55b
The runtime rewrite (without debugging)
AnsonYeung Jan 13, 2026
315e6d2
Inline unwind
AnsonYeung Jan 14, 2026
d1f3ae6
Implement stack trace
AnsonYeung Jan 31, 2026
40cbcae
Remove unused debug function
AnsonYeung Feb 3, 2026
e4012a6
Change stack safety heuristic to include local count and runtime stac…
AnsonYeung Feb 9, 2026
8b7a03c
revert stack safety heuristic
AnsonYeung Feb 11, 2026
894fe8c
Merge remote-tracking branch 'upstream/hkmc2' into handler-overhaul2
AnsonYeung Feb 13, 2026
0f08646
revert some changes
AnsonYeung Feb 13, 2026
3a9f512
Don't transform unneeded code
AnsonYeung Feb 16, 2026
f6360d3
Merge branch 'hkmc2' into handler-overhaul2
LPTK Feb 16, 2026
2f50f3f
Lifter fix
AnsonYeung Feb 17, 2026
f0858a1
Merge remote-tracking branch 'upstream/hkmc2' into handler-overhaul2
AnsonYeung Feb 17, 2026
eb04e20
fix
CAG2Mark Feb 17, 2026
fd61bc7
fix
CAG2Mark Feb 17, 2026
7f0d66d
Rerun test
AnsonYeung Feb 17, 2026
a16aca8
Benchmark fixes
AnsonYeung Feb 17, 2026
d203d50
Fix benchmark
AnsonYeung Feb 18, 2026
1bc9b26
get dSym directly from Block
AnsonYeung Feb 23, 2026
4042942
Merge branch 'hkmc2' into handler-overhaul2
LPTK Feb 23, 2026
ad650d1
Fix path
AnsonYeung Feb 23, 2026
46be32c
Use sync so the file is always written
AnsonYeung Feb 23, 2026
6720606
Add new option to HandlerLowering
AnsonYeung Feb 23, 2026
03f8826
skip mod ctor option for stack safety
AnsonYeung Feb 23, 2026
2fafca6
important change
AnsonYeung Feb 23, 2026
d991897
change mod ctor code
AnsonYeung Feb 24, 2026
734eabb
Update comment
AnsonYeung Feb 24, 2026
94dad59
Fix short circuit with effect
AnsonYeung Feb 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,16 @@ object Config:
// Whether we check `Instantiate` nodes for effects. Currently, effects cannot be raised in constructors.
checkInstantiateEffect: Bool = false,
// A debug option that allows codegen to continue even if an unlifted definition is encountered.
softLifterError: Bool = false
softLifterError: Bool = false,
// Skips instrumenting module constructors, this can be used when the file is statically known to not
// raise any effect and cannot use the Runtime.mls module during module construction due to cyclic dependency.
doNotInstrumentTopLevelModCtor: Bool = false,
)

case class StackSafety(stackLimit: Int)
object StackSafety:
val default: StackSafety = StackSafety(
stackLimit = 500,
stackLimit = 1000,
)

case class LiftDefns() // there may be other settings in the future, having it as a case class now
Expand Down
30 changes: 18 additions & 12 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/HandlerLowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,18 @@ object HandlerLowering:
private enum HandlerCtx:
case FunctionLike(ctx: FunctionCtx)
case Ctor
case ModCtor
case ModCtor(trulyNested: Bool)
case TopLevel

def inCtor = this === Ctor || this === ModCtor
def inCtor = this === Ctor || this.isInstanceOf[ModCtor]
def inTopLevel = this === TopLevel
def allowDefn = inTopLevel || this === ModCtor
def allowDefn = inTopLevel || this.isInstanceOf[ModCtor]
def innerDefIsTrulyNested = this match
case FunctionLike(_) => true
case Ctor => true
case ModCtor(trulyNested) => trulyNested
case TopLevel => false


// currentFun: path to the current function for resumption
// thisPath: path to `this` binding if the function is a method, `this` will be rebinded on resumption
Expand Down Expand Up @@ -125,8 +131,6 @@ import HandlerLowering.*

class HandlerPaths(using Elaborator.State):
val runtimePath: Path = State.runtimeSymbol.asPath
val effectSigPath: Path = runtimePath.selSN("EffectSig").selSN("class")
val effectSigSym: ClassSymbol = State.effectSigSymbol
val contClsPath: Path = runtimePath.selSN("FunctionContFrame").selSN("class")
val mkEffectPath: Path = runtimePath.selSN("mkEffect")
val handleBlockImplPath: Path = runtimePath.selSN("handleBlockImpl")
Expand All @@ -138,14 +142,14 @@ class HandlerPaths(using Elaborator.State):
val stackDepthPath: Path = runtimePath.selN(stackDepthIdent)
val fnLocalsPath: Path = runtimePath.selSN("FnLocalsInfo").selSN("class")
val localVarInfoPath: Path = runtimePath.selSN("LocalVarInfo").selSN("class")
val unwindPath: Path = runtimePath.selSN("unwind")
val curEffect: Path = runtimePath.selSN("curEffect")
val unwindPath: Path = runtimePath.selSN("unwind")
val resumePc: Path = runtimePath.selSN("resumePc")
val resumeIdx: Path = runtimePath.selSN("resumeIdx")
val resumeValueIdent = new Tree.Ident("resumeValue")
val resumeValue: Path = runtimePath.selN(resumeValueIdent)

type StackSafetyMap = collection.Map[FnOrCls, Block]
type StackSafetyMap = collection.Map[FnOrCls, (Int, Block)]

class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise, Elaborator.State, Elaborator.Ctx):

Expand Down Expand Up @@ -474,7 +478,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
.toList
.map(locals(_))

val stackSafetyMap: mutable.Map[FnOrCls, Block] = mutable.HashMap.empty
val stackSafetyMap: mutable.Map[FnOrCls, (Int, Block)] = mutable.HashMap.empty

private def lifterReport(using Line, FileName)(msgs: Ls[Message -> Opt[Loc]])(using Name) =
if opt.softLifterError then
Expand Down Expand Up @@ -524,13 +528,13 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
raise(lifterReport(msg"Unexpected nested class: lambdas may not function correctly." -> isym.toLoc :: Nil))
val debugInfos = mutable.ArrayBuffer.empty[(Local, List[Arg])]
val newMtds = methods.map: f =>
val (debugInfoSym, debugInfo, fun2) = translateFunLike(f, Value.Ref(isym).sel(new Tree.Ident(f.sym.nme), f.sym.asTrm.get),
val (debugInfoSym, debugInfo, fun2) = translateFunLike(f, Value.Ref(isym).sel(new Tree.Ident(f.sym.nme), f.dSym),
S(Value.Ref(isym)), s"${sym.nme}#${f.sym.nme}")
debugInfos += debugInfoSym -> debugInfo
fun2
val companion2 = companion.map: bod =>
val newMtds = bod.methods.map: f =>
val (debugInfoSym, debugInfo, fun2) = translateFunLike(f, Value.Ref(bod.isym).sel(new Tree.Ident(f.sym.nme), f.sym.asTrm.get),
val (debugInfoSym, debugInfo, fun2) = translateFunLike(f, Value.Ref(bod.isym).sel(new Tree.Ident(f.sym.nme), f.dSym),
S(Value.Ref(bod.isym)), s"${sym.nme}.${f.sym.nme}")
debugInfos += debugInfoSym -> debugInfo
fun2
Expand All @@ -539,7 +543,8 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
// TODO: Companion's ctor is more well behaved so it is possible to handle it
// However, JSBuilder inserts extra statements between preCtor and ctor and it's not possible to replicate the exact behavior
// without many special handling.
val newCtor = translateCtorLike(bod.ctor, bod.isym.asPath, true)
val newCtor = if opt.doNotInstrumentTopLevelModCtor && !h.innerDefIsTrulyNested then bod.ctor else
translateCtorLike(bod.ctor, bod.isym.asPath, true)
tl.log(s"companion name: ${bod.isym.nme}")
ClsLikeBody(bod.isym, newMtds, bod.privateFields, bod.publicFields, newCtor)
val c2 = ClsLikeDefn(owner, isym, sym, ctorSym, kind, paramsOpt, auxParams, parentPath, newMtds, privateFields, publicFields,
Expand All @@ -561,6 +566,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
val parts = partitionBlock(b)
stackSafetyMap += ctx.resumeInfo.currentStackSafetySym ->
(
1,
ctx.doUnwind(ctx.resumeInfo.currentStackSafetySym.fold(_.toLoc, _.toLoc).fold(unit)(locToStr(_)), -1, Nil)(using paths)
)
if parts.states.size <= 1 then
Expand Down Expand Up @@ -616,7 +622,7 @@ class HandlerLowering(paths: HandlerPaths, opt: EffectHandlers)(using TL, Raise,
mainLoop))

private def translateCtorLike(b: Block, thisPath: Path, isModCtor: Bool)(using h: HandlerCtx): Block =
translateBlock(b, if isModCtor then HandlerCtx.ModCtor else HandlerCtx.Ctor, Set.empty)
translateBlock(b, if isModCtor then HandlerCtx.ModCtor(h.innerDefIsTrulyNested) else HandlerCtx.Ctor, Set.empty)

private def translateIllegalEffectCtx(b: Block, onEffect: Call)(using HandlerCtx): Block =
def effectCheck(l: Local, r: Result, rst: Block): Block =
Expand Down
12 changes: 8 additions & 4 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lifter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,9 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
def join2: Block =
resolveDefnRef(l, d, r) match
case Some(value) => k(c.copy(fun = value, args = newArgs)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall).withLoc(c.toLoc))
case None => super.applyResult(c)(k)
case None => super.applyPath(c.fun): fun2 =>
if (fun2 is c.fun) && (args is newArgs) then k(c)
else k(c.copy(fun = fun2, args = newArgs)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall).withLoc(c.toLoc))
r match
// function call
case f: LiftedFunc => k(f.rewriteCall(c, newArgs))
Expand Down Expand Up @@ -978,7 +980,9 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
private val aux = Lazy[Defn](mkAuxDefn)

def rewriteCall(c: Call, args: List[Arg])(using ctx: LifterCtxNew): Call =
if isTrivial then c
if isTrivial then
if args is c.args then c
else c.copy(args = args)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall).withLocOf(c)
else
Call(
Value.Ref(mainSym, S(mainDsym)),
Expand Down Expand Up @@ -1112,7 +1116,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
if isTrivial then
val path = Value.Ref(cls.sym, S(cls.isym))
if (inst.cls === path) && (inst.args is args) then inst
else inst.copy(cls = path, args = args)
else inst.copy(cls = path, args = args).withLocOf(inst)
else
flat.force // force computation
Call(
Expand All @@ -1124,7 +1128,7 @@ class Lifter(topLevelBlk: Block)(using State, Raise, Config):
if obj.isObj then lastWords("tried to rewrite instantiate for an object")
if isTrivial then
if c.args is args then c
else c.copy(args = args)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall)
else c.copy(args = args)(c.isMlsFun, c.mayRaiseEffects, c.explicitTailCall).withLocOf(c)
else
flat.force // force computation
Call(
Expand Down
4 changes: 2 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -535,7 +535,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
k(Call(
Value.Ref(State.runtimeSymbol).selN(Tree.Ident(if isAnd then "short_and" else "short_or")),
Arg(N, ar1) :: Arg(N, lamDef.asPath) :: Nil
)(true, false, false)))
)(true, true, false)))
else
subTerm_nonTail(arg2): ar2 =>
val target = wasmIntrinsicPath(sym, unary = false)
Expand Down Expand Up @@ -1063,7 +1063,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):

val stackSafe = config.stackSafety match
case N => withHandlers2
case S(sts) => StackSafeTransform(sts.stackLimit, handlerPaths, stackSafetyInfo).transformTopLevel(withHandlers2)
case S(sts) => StackSafeTransform(sts.stackLimit, config.effectHandlers.exists(_.doNotInstrumentTopLevelModCtor), handlerPaths, stackSafetyInfo).transformTopLevel(withHandlers2)

val flattened = stackSafe.flattened

Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/ScopeData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ object ScopeData:
clsBody.privateFields.toSet + clsBody.isym
case _: ClassCtor => Set.empty
case Func(fun, _) => fun.params.flatMap: p =>
p.params.map(_.sym)
p.restParam.map(_.sym) ++ p.params.map(_.sym)
.toSet
case ScopedBlock(_, block) => block.syms.toSet
case _: Loop => Set.empty
Expand Down
18 changes: 10 additions & 8 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/StackSafeTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import hkmc2.semantics.*
import hkmc2.syntax.Tree
import hkmc2.codegen.HandlerLowering.FnOrCls

class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: StackSafetyMap)(using State):
class StackSafeTransform(depthLimit: Int, doNotInstrumentTopLevelModCtor: Bool, paths: HandlerPaths, stackSafetyMap: StackSafetyMap)(using State):
private val STACK_DEPTH_IDENT: Tree.Ident = Tree.Ident("stackDepth")

private val runtimePath: Path = State.runtimeSymbol.asPath
Expand Down Expand Up @@ -78,7 +78,7 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S

case _ => super.applyBlock(b)

override def applyHandler(hdr: Handler): Handler = lastWords("HandleBlock in stack safe transformation")
override def applyHandler(hdr: Handler): Handler = lastWords("HandleBlock in stack safe transformation")

override def applyResult(r: Result)(k: Result => Block): Block =
if usesStack(r) then
Expand Down Expand Up @@ -111,8 +111,8 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S
methods.map(rewriteFn),
privateFields,
publicFields,
rewriteBlk(preCtor, L(BlockMemberSymbol("TODO", Nil)), 1), // TODO: preCtor is not translated in handler lowering
if isTopLevel && (defn.k is syntax.Mod) then transformTopLevel(ctor) else rewriteBlk(ctor, R(isym), 1),
preCtor,
if isTopLevel && (defn.k is syntax.Mod) then transformTopLevel(ctor) else ctor,
mod.map(rewriteObjBody(_, isTopLevel)),
bufferable,
)
Expand All @@ -123,13 +123,15 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S
defn.methods.map(rewriteFn),
defn.privateFields,
defn.publicFields,
if isTopLevel then transformTopLevel(defn.ctor) else rewriteBlk(defn.ctor, R(defn.isym), 1),
if isTopLevel then
if doNotInstrumentTopLevelModCtor then defn.ctor else transformTopLevel(defn.ctor)
else rewriteBlk(defn.ctor, R(defn.isym)),
)

// fnOrCls points us to the doUnwind function
def rewriteBlk(blk: Block, fnOrCls: FnOrCls, increment: Int) =
def rewriteBlk(blk: Block, fnOrCls: FnOrCls) =
(stackSafetyMap.get(fnOrCls), isTrivial(blk)) match
case (S(doUnwindBlk), false) =>
case (S((increment, doUnwindBlk)), false) =>
var usedDepth = false
lazy val curDepth =
usedDepth = true
Expand All @@ -153,6 +155,6 @@ class StackSafeTransform(depthLimit: Int, paths: HandlerPaths, stackSafetyMap: S


def rewriteFn(defn: FunDefn) =
FunDefn(defn.owner, defn.sym, defn.dSym, defn.params, rewriteBlk(defn.body, L(defn.sym), 1))(defn.forceTailRec)
FunDefn(defn.owner, defn.sym, defn.dSym, defn.params, rewriteBlk(defn.body, L(defn.sym)))(defn.forceTailRec)

def transformTopLevel(b: Block) = transform(b, TempSymbol(N), true)
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,6 @@ object Elaborator:
val blockSymbol = TempSymbol(N, "Block")
val shapeSymbol = TempSymbol(N, "Shape")
val wasmSymbol = TempSymbol(N, "wasm")
val effectSigSymbol = ClassSymbol(DummyTypeDef(syntax.Cls), Ident("EffectSig"))
val nonLocalRetHandlerTrm =
val id = new Ident("NonLocalReturn")
val sym = ClassSymbol(DummyTypeDef(syntax.Cls), id)
Expand Down
3 changes: 2 additions & 1 deletion hkmc2/shared/src/test/mlscript-compile/Iter.mls
Original file line number Diff line number Diff line change
Expand Up @@ -277,4 +277,5 @@ fun fromStack(stack) = Iterable of () =>

fun toStack(xs) = xs rightFolded of Nil, Cons

fun isArrayLike = _ is Array | IterableBase
// Important: This function is used by runtime directly. Do not call any function or make any lambda here.
fun isArrayLike(xs) = xs is Array | IterableBase
Loading
Loading