Skip to content

Commit 18e6da8

Browse files
committed
Fix type checker literal union upcasting issues
- Extend infer_arithmetic's base_ty to handle Ty::Union by folding over members with a promote helper, fixing InvalidBinaryOp for union-of-literal arithmetic (e.g. `1 + if(true){2} else {3}`)
- Make widen_fresh() recursive into Union, List, Map, Optional so container literals at unannotated let-bindings widen correctly (e.g. `[1,2,3]` infers as int[] instead of (1|2|3)[])
- Fix propagate_copies in MIR cleanup to guard against inlining phi-like temps that have multiple definitions
- Add subtype unit tests for union-of-literals scenarios
- Un-ignore 6 tests across soundness, if_else, and match_basics
1 parent ccd9110 commit 18e6da8

38 files changed

Lines changed: 2944 additions & 1410 deletions

File tree

baml_language/crates/baml_compiler2_mir/src/cleanup.rs

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,51 @@ fn collect_place_index_locals(body: &MirFunctionBody) -> HashSet<Local> {
275275
set
276276
}
277277

278+
/// Count definition sites (assignments) for each local across all blocks.
279+
///
280+
/// A local is "single-definition" if it appears as an assignment destination
281+
/// in exactly one statement. Locals defined in multiple branches (e.g., a temp
282+
/// that is assigned in both arms of an if-else) have a count > 1 and must not
283+
/// be constant-propagated.
284+
fn count_local_defs(body: &MirFunctionBody) -> Vec<usize> {
285+
let mut defs = vec![0usize; body.locals.len()];
286+
287+
for block in &body.blocks {
288+
for stmt in &block.statements {
289+
if let crate::StatementKind::Assign { destination, .. } = &stmt.kind {
290+
// Walk the place to find the root local being defined.
291+
let mut place = destination;
292+
loop {
293+
match place {
294+
Place::Local(l) => {
295+
defs[l.0] += 1;
296+
break;
297+
}
298+
Place::Field { base, .. } => place = base,
299+
Place::Index { base, .. } => place = base,
300+
}
301+
}
302+
}
303+
}
304+
// Call destinations also count as definitions.
305+
if let Some(Terminator::Call { destination, .. }) = &block.terminator {
306+
let mut place = destination;
307+
loop {
308+
match place {
309+
Place::Local(l) => {
310+
defs[l.0] += 1;
311+
break;
312+
}
313+
Place::Field { base, .. } => place = base,
314+
Place::Index { base, .. } => place = base,
315+
}
316+
}
317+
}
318+
}
319+
320+
defs
321+
}
322+
278323
fn count_local_uses(body: &MirFunctionBody) -> Vec<usize> {
279324
let mut uses = vec![0usize; body.locals.len()];
280325

@@ -436,6 +481,7 @@ fn count_in_terminator(term: &Terminator, uses: &mut [usize]) {
436481
fn propagate_copies(body: &mut MirFunctionBody, arity: usize) {
437482
// Build substitution map: Local -> replacement Operand
438483
let uses = count_local_uses(body);
484+
let defs = count_local_defs(body);
439485
// Locals used as the `index` field of a `Place::Index` cannot be replaced
440486
// with constants — that field is typed `Local`, not `Operand`. Collect them
441487
// so we can exclude them from constant inlining below.
@@ -448,8 +494,13 @@ fn propagate_copies(body: &mut MirFunctionBody, arity: usize) {
448494
// SAFETY: Only propagate unnamed locals (compiler temporaries from
449495
// lower_to_operand / builder.temp()). Named locals (user variables from
450496
// AstStmt::Let) can be reassigned via AstStmt::Assign or AstStmt::AssignOp,
451-
// making propagation unsound. Unnamed temps are always fresh single-definition
452-
// locals, so this is safe.
497+
// making propagation unsound.
498+
//
499+
// We additionally require defs[dest] == 1 to guard against phi-like temps
500+
// that are assigned in multiple branches (e.g., the result temp of an
501+
// if-else used directly as an arithmetic operand). Such temps have a
502+
// single use-site but two definition-sites; propagating the last-seen
503+
// constant would silently use the wrong branch value.
453504
for block in &body.blocks {
454505
for stmt in &block.statements {
455506
if let crate::StatementKind::Assign {
@@ -462,6 +513,11 @@ fn propagate_copies(body: &mut MirFunctionBody, arity: usize) {
462513
continue;
463514
}
464515

516+
// Skip locals with multiple definition sites (phi-like).
517+
if defs[dest.0] != 1 {
518+
continue;
519+
}
520+
465521
match operand {
466522
Operand::Copy(Place::Local(src)) if src.0 >= 1 && src.0 <= arity => {
467523
// Copy of param — substitute
@@ -470,9 +526,9 @@ fn propagate_copies(body: &mut MirFunctionBody, arity: usize) {
470526
Operand::Constant(c)
471527
if uses[dest.0] == 1 && !used_as_place_index.contains(dest) =>
472528
{
473-
// Single-use constant — inline. Skip locals that appear
474-
// as a Place::Index index, since that position can only
475-
// hold a Local, not a Constant.
529+
// Single-use, single-definition constant — inline. Skip
530+
// locals that appear as a Place::Index index, since that
531+
// position can only hold a Local, not a Constant.
476532
subst.insert(*dest, Operand::Constant(c.clone()));
477533
}
478534
_ => {}

baml_language/crates/baml_compiler2_tir/src/builder.rs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3644,13 +3644,36 @@ impl<'db> TypeInferenceBuilder<'db> {
36443644
/// String concatenation is only valid for `Add`; other arithmetic ops on
36453645
/// strings are invalid and return `Unknown` (triggering an error upstream).
36463646
fn infer_arithmetic(op: baml_compiler2_ast::BinaryOp, lhs: &Ty, rhs: &Ty) -> Ty {
3647-
let base_ty = |ty: &Ty| -> Option<PrimitiveType> {
3647+
fn promote(a: PrimitiveType, b: &PrimitiveType) -> Option<PrimitiveType> {
3648+
if a == *b {
3649+
return Some(a);
3650+
}
3651+
match (&a, &b) {
3652+
(PrimitiveType::Int, PrimitiveType::Float)
3653+
| (PrimitiveType::Float, PrimitiveType::Int) => Some(PrimitiveType::Float),
3654+
_ => None,
3655+
}
3656+
}
3657+
3658+
fn base_ty(ty: &Ty) -> Option<PrimitiveType> {
36483659
match ty {
36493660
Ty::Primitive(p, _) => Some(p.clone()),
36503661
Ty::Literal(lit, _, _) => Some(PrimitiveType::from_literal(lit)),
3662+
Ty::Union(members, _) => {
3663+
let mut result: Option<PrimitiveType> = None;
3664+
for m in members {
3665+
let p = base_ty(m)?;
3666+
result = Some(match result {
3667+
None => p,
3668+
Some(existing) => promote(existing, &p)?,
3669+
});
3670+
}
3671+
result
3672+
}
36513673
_ => None,
36523674
}
3653-
};
3675+
}
3676+
36543677
match (base_ty(lhs), base_ty(rhs)) {
36553678
(Some(PrimitiveType::Float), _) | (_, Some(PrimitiveType::Float)) => {
36563679
Ty::Primitive(PrimitiveType::Float, TyAttr::default())

baml_language/crates/baml_compiler2_tir/src/normalize.rs

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,4 +1543,59 @@ mod tests {
15431543
"map<string, int> (evolving)"
15441544
);
15451545
}
1546+
1547+
// ── Literal-union subtype tests ─────────────────────────────────────────
1548+
1549+
#[test]
fn test_union_of_int_literals_subtype_of_int() {
    // The union `1 | 2 | 3` of fresh int literals must be a subtype of `int`.
    let fresh_int =
        |value| Ty::Literal(LiteralValue::Int(value), Freshness::Fresh, TyAttr::default());

    let sub = Ty::Union(
        vec![fresh_int(1), fresh_int(2), fresh_int(3)],
        TyAttr::default(),
    );
    let sup = Ty::Primitive(PrimitiveType::Int, TyAttr::default());

    let aliases = HashMap::new();
    assert!(is_subtype_of(&sub, &sup, &aliases));
}
1563+
1564+
#[test]
fn test_list_of_literal_union_subtype_of_list_int() {
    // A list whose element type is the fresh-literal union `1 | 2` must be a
    // subtype of a list of `int`.
    let fresh_int =
        |value| Ty::Literal(LiteralValue::Int(value), Freshness::Fresh, TyAttr::default());
    let list_of = |element: Ty| Ty::List(Box::new(element), TyAttr::default());

    let literal_union = Ty::Union(vec![fresh_int(1), fresh_int(2)], TyAttr::default());
    let sub = list_of(literal_union);
    let sup = list_of(Ty::Primitive(PrimitiveType::Int, TyAttr::default()));

    let aliases = HashMap::new();
    assert!(is_subtype_of(&sub, &sup, &aliases));
}
1583+
1584+
#[test]
fn test_union_of_int_and_float_literals_subtype_of_float() {
    // A union mixing a fresh int literal (`1`) with a fresh float literal
    // (`2.0`) must be a subtype of `float`.
    let int_one = Ty::Literal(LiteralValue::Int(1), Freshness::Fresh, TyAttr::default());
    let float_two = Ty::Literal(
        LiteralValue::Float("2.0".to_string()),
        Freshness::Fresh,
        TyAttr::default(),
    );

    let sub = Ty::Union(vec![int_one, float_two], TyAttr::default());
    let sup = Ty::Primitive(PrimitiveType::Float, TyAttr::default());

    let aliases = HashMap::new();
    assert!(is_subtype_of(&sub, &sup, &aliases));
}
15461601
}

baml_language/crates/baml_compiler2_tir/src/ty.rs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,41 @@ pub enum Freshness {
282282
/// Re-export `baml_base::Literal` as `LiteralValue` for backward compatibility.
283283
pub type LiteralValue = baml_base::Literal;
284284

285+
/// Flatten, deduplicate, and collapse a vec of widened types into a single `Ty`.
286+
///
287+
/// After `widen_fresh()` has run on each union member, multiple members may
288+
/// have widened to the same primitive (e.g. `[Literal(1,Fresh), Literal(2,Fresh)]`
289+
/// both become `Primitive(Int)`). This helper deduplicates and collapses:
290+
/// - Flattens nested unions one level
291+
/// - Deduplicates by `PartialEq`
292+
/// - Unwraps singletons
293+
fn dedup_and_collapse(types: Vec<Ty>) -> Ty {
294+
let mut members: Vec<Ty> = Vec::new();
295+
for ty in types {
296+
match ty {
297+
Ty::Union(inner, _) => {
298+
for m in inner {
299+
if !members.contains(&m) {
300+
members.push(m);
301+
}
302+
}
303+
}
304+
_ => {
305+
if !members.contains(&ty) {
306+
members.push(ty);
307+
}
308+
}
309+
}
310+
}
311+
match members.len() {
312+
0 => Ty::Never {
313+
attr: TyAttr::default(),
314+
},
315+
1 => members.into_iter().next().unwrap(),
316+
_ => Ty::Union(members, TyAttr::default()),
317+
}
318+
}
319+
285320
impl Ty {
286321
/// Access the `TyAttr` on this type.
287322
pub fn attr(&self) -> &TyAttr {
@@ -343,12 +378,26 @@ impl Ty {
343378
///
344379
/// Called at mutable binding sites (`let` without annotation).
345380
/// Regular (non-fresh) literals pass through unchanged.
381+
///
382+
/// Recurses into `Union`, `List`, `Map`, and `Optional` so that compound
383+
/// types like `(1 | 2 | 3)[]` widen to `int[]` at unannotated bindings.
346384
#[must_use]
347385
pub fn widen_fresh(self) -> Ty {
348386
match self {
349387
Ty::Literal(lit, Freshness::Fresh, attr) => {
350388
Ty::Primitive(PrimitiveType::from_literal(&lit), attr)
351389
}
390+
Ty::Union(members, _attr) => {
391+
let widened: Vec<Ty> = members.into_iter().map(Ty::widen_fresh).collect();
392+
dedup_and_collapse(widened)
393+
}
394+
Ty::List(inner, attr) => Ty::List(Box::new((*inner).widen_fresh()), attr),
395+
Ty::Map(k, v, attr) => Ty::Map(
396+
Box::new((*k).widen_fresh()),
397+
Box::new((*v).widen_fresh()),
398+
attr,
399+
),
400+
Ty::Optional(inner, attr) => Ty::Optional(Box::new((*inner).widen_fresh()), attr),
352401
other => other,
353402
}
354403
}
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Fixture: adds the literal 1 to an if/else expression whose branches are the literals 2 and 3.
function ifElseAdd() -> int {
2+
1 + if (true) { 2 } else { 3 }
3+
}
4+
5+
// Fixture: a match on `x` whose arms are the literals 10, 20 and 0, used as the left operand of `+ 1`.
function matchAdd(x: int) -> int {
6+
match (x) {
7+
1 => 10,
8+
2 => 20,
9+
_ => 0
10+
} + 1
11+
}
12+
13+
// Fixture: addition where both operands are if/else expressions with literal branches.
function bothSidesUnion() -> int {
14+
(if (true) { 1 } else { 2 }) + (if (false) { 3 } else { 4 })
15+
}
16+
17+
// Fixture: subtraction (rather than addition) with an if/else of literals as the right operand.
function subtractUnion() -> int {
18+
100 - if (true) { 7 } else { 3 }
19+
}
20+
21+
// Fixture: loops while `i < 3`, each time adding a match-on-`i` result (10, 20 or 30) to `sum`,
// so the match expression appears as an arithmetic operand inside a loop body.
function loopMatchAccum() -> int {
22+
let sum = 0;
23+
let i = 0;
24+
while (i < 3) {
25+
sum = sum + match (i) {
26+
0 => 10,
27+
1 => 20,
28+
_ => 30
29+
};
30+
i = i + 1;
31+
}
32+
return sum;
33+
}
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Fixture: an unannotated `let` bound to an array of int literals, returned as `int[]`.
function arrayOfInts() -> int[] {
2+
let x = [1, 2, 3];
3+
return x;
4+
}
5+
6+
// Fixture: an unannotated `let` bound to a map literal with int-literal values, returned as `map<string, int>`.
function mapOfInts() -> map<string, int> {
7+
let m = {"a": 1, "b": 2};
8+
return m;
9+
}
10+
11+
// Fixture: an unannotated `let` bound to an array of arrays of int literals, returned as `int[][]`.
function nestedArray() -> int[][] {
12+
let x = [[1, 2], [3, 4]];
13+
return x;
14+
}
15+
16+
// Fixture: a `let` explicitly annotated with the literal union `1 | 2`, returned where `int` is declared.
function annotatedLiteralUnion() -> int {
17+
let x: 1 | 2 = 1;
18+
return x;
19+
}

baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__04_5_mir.snap

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

baml_language/crates/baml_tests/snapshots/catch_throw/baml_tests__catch_throw__04_tir.snap

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)