Skip to content

Commit 5476bc3

Browse files
[Backport maintenance/4.0.x] Wrong inference with default argument values (#2924)
Fix overzealous filtering of `IfExp` inference (#2914) (cherry picked from commit 178a796) Co-authored-by: jkmnt <git@firewood.fastmail.com>
1 parent ae761dc commit 5476bc3

3 files changed

Lines changed: 120 additions & 18 deletions

File tree

ChangeLog

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,16 @@ What's New in astroid 4.0.3?
1313
============================
1414
Release date: TBA
1515

16+
* Fix inference of ``IfExp`` (ternary expression) nodes to avoid prematurely narrowing
17+
results in the face of inference ambiguity.
18+
19+
Closes #2899
20+
1621
* Fix base class inference for dataclasses using the PEP 695 typing syntax.
1722

1823
Refs pylint-dev/pylint#10788
1924

2025

21-
2226
What's New in astroid 4.0.2?
2327
============================
2428
Release date: 2025-11-09

astroid/nodes/node_classes.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3108,31 +3108,37 @@ def _infer(
31083108
to inferring both branches. Otherwise, we infer either branch
31093109
depending on the condition.
31103110
"""
3111-
both_branches = False
3111+
31123112
# We use two separate contexts for evaluating lhs and rhs because
31133113
# evaluating lhs may leave some undesired entries in context.path
31143114
# which may not let us infer right value of rhs.
3115-
31163115
context = context or InferenceContext()
31173116
lhs_context = copy_context(context)
31183117
rhs_context = copy_context(context)
3118+
3119+
# Infer bool condition. Stop inferring if in doubt and fallback to
3120+
# evaluating both branches.
3121+
condition: bool | None = None
31193122
try:
3120-
test = next(self.test.infer(context=context.clone()))
3121-
except (InferenceError, StopIteration):
3122-
both_branches = True
3123-
else:
3124-
test_bool_value = test.bool_value()
3125-
if not isinstance(test, util.UninferableBase) and not isinstance(
3126-
test_bool_value, util.UninferableBase
3127-
):
3128-
if test_bool_value:
3129-
yield from self.body.infer(context=lhs_context)
3130-
else:
3131-
yield from self.orelse.infer(context=rhs_context)
3132-
else:
3133-
both_branches = True
3134-
if both_branches:
3123+
for test in self.test.infer(context=context.clone()):
3124+
if isinstance(test, util.UninferableBase):
3125+
condition = None
3126+
break
3127+
test_bool_value = test.bool_value()
3128+
if isinstance(test_bool_value, util.UninferableBase):
3129+
condition = None
3130+
break
3131+
if condition is None:
3132+
condition = test_bool_value
3133+
elif test_bool_value != condition:
3134+
condition = None
3135+
break
3136+
except InferenceError:
3137+
condition = None
3138+
3139+
if condition is True or condition is None:
31353140
yield from self.body.infer(context=lhs_context)
3141+
if condition is False or condition is None:
31363142
yield from self.orelse.infer(context=rhs_context)
31373143

31383144

tests/test_inference.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6442,6 +6442,98 @@ def both_branches():
64426442
assert [third[0].value, third[1].value] == [1, 2]
64436443

64446444

6445+
def test_ifexp_with_default_arguments() -> None:
6446+
code = """
6447+
def with_default(foo: str | None = None):
6448+
a = 1 if foo else "bar" #@
6449+
6450+
def without_default(foo: str):
6451+
a = 1 if foo else "bar" #@
6452+
6453+
def some_ifexps(foo: str | None = None):
6454+
a = 1 if foo else 2
6455+
b = 3 if a else 4 #@
6456+
c = 4 if b else 5 #@
6457+
d = 5 if not foo else foo #@
6458+
e = d if not foo else foo #@
6459+
"""
6460+
6461+
ast_nodes = extract_node(code)
6462+
6463+
first = ast_nodes[0].value.inferred()
6464+
second = ast_nodes[1].value.inferred()
6465+
third = ast_nodes[2].value.inferred()
6466+
fourth = ast_nodes[3].value.inferred()
6467+
fifth = ast_nodes[4].value.inferred()
6468+
sixth = ast_nodes[5].value.inferred()
6469+
6470+
assert len(first) == 2
6471+
assert [first[0].value, first[1].value] == [1, "bar"]
6472+
6473+
assert len(second) == 2
6474+
assert [second[0].value, second[1].value] == [1, "bar"]
6475+
6476+
assert len(third) == 1
6477+
assert third[0].value == 3
6478+
6479+
assert len(fourth) == 1
6480+
assert fourth[0].value == 4
6481+
6482+
assert len(fifth) == 2
6483+
assert [fifth[0].value, fifth[1].value] == [5, Uninferable]
6484+
6485+
assert len(sixth) == 3
6486+
assert [sixth[0].value, sixth[1].value, sixth[2].value] == [
6487+
5,
6488+
Uninferable,
6489+
Uninferable,
6490+
]
6491+
6492+
6493+
def test_ifexp_with_uninferables() -> None:
6494+
code = """
6495+
def truthy_and_falsy():
6496+
return False if unknown() else True
6497+
6498+
def truthy_and_uninferable():
6499+
return False if unknown() else unknown()
6500+
6501+
def calls_truthy_and_falsy():
6502+
return 1 if truthy_and_falsy() else 2
6503+
6504+
def calls_truthy_and_uninferable():
6505+
return 1 if range(10) else truthy_and_uninferable()
6506+
6507+
truthy_and_falsy() #@
6508+
truthy_and_uninferable() #@
6509+
calls_truthy_and_falsy() #@
6510+
calls_truthy_and_uninferable() #@
6511+
"""
6512+
6513+
ast_nodes = extract_node(code)
6514+
6515+
first = ast_nodes[0].inferred()
6516+
second = ast_nodes[1].inferred()
6517+
third = ast_nodes[2].inferred()
6518+
fourth = ast_nodes[3].inferred()
6519+
6520+
assert len(first) == 2
6521+
assert [first[0].value, first[1].value] == [False, True]
6522+
6523+
assert len(second) == 2
6524+
assert [second[0].value, second[1].value] == [False, Uninferable]
6525+
6526+
assert len(third) == 2
6527+
assert [third[0].value, third[1].value] == [1, 2]
6528+
6529+
assert len(fourth) == 3
6530+
assert [fourth[0].value, fourth[1].value, fourth[2].value] == [
6531+
1,
6532+
False,
6533+
Uninferable,
6534+
]
6535+
6536+
64456537
def test_assert_last_function_returns_none_on_inference() -> None:
64466538
code = """
64476539
def check_equal(a, b):

0 commit comments

Comments
 (0)