Skip to content

Commit 0c48b00

Browse files
committed
.
1 parent 06a5f5d commit 0c48b00

File tree

5 files changed

+115
-36
lines changed

5 files changed

+115
-36
lines changed

mypy/checker.py

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
)
3333
from mypy.checkpattern import PatternChecker
3434
from mypy.constraints import SUPERTYPE_OF
35-
from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values
35+
from mypy.erasetype import (
36+
erase_type,
37+
erase_typevars,
38+
remove_instance_last_known_values,
39+
shallow_erase_type_for_equality,
40+
)
3641
from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode
3742
from mypy.errors import (
3843
ErrorInfo,
@@ -45,7 +50,7 @@
4550
from mypy.expandtype import expand_type
4651
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
4752
from mypy.maptype import map_instance_to_supertype
48-
from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types
53+
from mypy.meet import is_overlapping_types, meet_types
4954
from mypy.message_registry import ErrorMessage
5055
from mypy.messages import (
5156
SUGGESTED_TEST_FIXTURES,
@@ -6540,19 +6545,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
65406545
narrowable_indices={0},
65416546
)
65426547

6543-
# We only try and narrow away 'None' for now
6544-
if (
6545-
not is_unreachable_map(if_map)
6546-
and is_overlapping_none(item_type)
6547-
and not is_overlapping_none(collection_item_type)
6548-
and not (
6549-
isinstance(collection_item_type, Instance)
6550-
and collection_item_type.type.fullname == "builtins.object"
6551-
)
6552-
and is_overlapping_erased_types(item_type, collection_item_type)
6553-
):
6554-
if_map[operands[left_index]] = remove_optional(item_type)
6555-
65566548
if right_index in narrowable_operand_index_to_hash:
65576549
if_type, else_type = self.conditional_types_for_iterable(
65586550
item_type, iterable_type
@@ -6676,6 +6668,9 @@ def narrow_type_by_identity_equality(
66766668
target_type = operand_types[j]
66776669
if should_coerce_literals:
66786670
target_type = coerce_to_literal(target_type)
6671+
# Type A[T1] could compare equal to A[T2] even if T1 is disjoint from T2
6672+
# e.g. cast(list[int], []) == cast(list[str], [])
6673+
target_type = shallow_erase_type_for_equality(target_type)
66796674

66806675
if (
66816676
# See comments in ambiguous_enum_equality_keys
@@ -8609,13 +8604,7 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty
86098604
return result
86108605

86118606

8612-
BUILTINS_CUSTOM_EQ_CHECKS: Final = {
8613-
"builtins.bytearray",
8614-
"builtins.memoryview",
8615-
"builtins.list",
8616-
"builtins.dict",
8617-
"builtins.set",
8618-
}
8607+
BUILTINS_CUSTOM_EQ_CHECKS: Final = {"builtins.bytearray", "builtins.memoryview"}
86198608

86208609

86218610
def has_custom_eq_checks(t: Type) -> bool:

mypy/erasetype.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,3 +285,26 @@ def visit_union_type(self, t: UnionType) -> Type:
285285
merged.append(orig_item)
286286
return UnionType.make_union(merged)
287287
return new
288+
289+
290+
def shallow_erase_type_for_equality(typ: Type) -> ProperType:
291+
"""Erase type variables from Instance's inside a type."""
292+
p_typ = get_proper_type(typ)
293+
if isinstance(p_typ, Instance):
294+
args = erased_vars(p_typ.type.defn.type_vars, TypeOfAny.special_form)
295+
return Instance(p_typ.type, args, p_typ.line)
296+
if isinstance(p_typ, UnionType):
297+
items = [shallow_erase_type_for_equality(item) for item in p_typ.items]
298+
return UnionType.make_union(items)
299+
return p_typ
300+
301+
302+
class EraseTypeForEqualityVisitor(TypeTranslator):
303+
def visit_instance(self, t: Instance) -> ProperType:
304+
args = erased_vars(t.type.defn.type_vars, TypeOfAny.special_form)
305+
return Instance(t.type, args, t.line)
306+
307+
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
308+
if has_recursive_types(t):
309+
return t
310+
return get_proper_type(t).accept(self)

mypy/meet.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from collections.abc import Callable
44

55
from mypy import join
6-
from mypy.erasetype import erase_type
76
from mypy.maptype import map_instance_to_supertype
87
from mypy.state import state
98
from mypy.subtypes import (
@@ -657,18 +656,6 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
657656
return False
658657

659658

660-
def is_overlapping_erased_types(
661-
left: Type, right: Type, *, ignore_promotions: bool = False
662-
) -> bool:
663-
"""The same as 'is_overlapping_erased_types', except the types are erased first."""
664-
return is_overlapping_types(
665-
erase_type(left),
666-
erase_type(right),
667-
ignore_promotions=ignore_promotions,
668-
prohibit_none_typevar_overlap=True,
669-
)
670-
671-
672659
def are_typed_dicts_overlapping(
673660
left: TypedDictType, right: TypedDictType, is_overlapping: Callable[[Type, Type], bool]
674661
) -> bool:

test-data/unit/check-narrowing.test

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,6 +1065,84 @@ def f(x: Custom, y: CustomSub):
10651065
reveal_type(y) # N: Revealed type is "__main__.CustomSub"
10661066
[builtins fixtures/tuple.pyi]
10671067

1068+
[case testNarrowingCustomEqualityGeneric]
1069+
# flags: --strict-equality --warn-unreachable
1070+
from __future__ import annotations
1071+
from typing import Union
1072+
1073+
class Custom:
1074+
def __eq__(self, other: object) -> bool:
1075+
raise
1076+
1077+
class Default: ...
1078+
1079+
def f(x: list[Custom] | Default, y: list[int]):
1080+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int]")
1081+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1082+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
1083+
else:
1084+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1085+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
1086+
1087+
f([], [])
1088+
1089+
def g(x: list[Custom] | Default, y: list[int] | list[Default]):
1090+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]")
1091+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1092+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1093+
else:
1094+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1095+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1096+
1097+
listcustom_or_default = Union[list[Custom], Default]
1098+
listint_or_default = Union[list[int], list[Default]]
1099+
1100+
def h(x: listcustom_or_default, y: listint_or_default):
1101+
if x == y: # E: Non-overlapping equality check (left operand type: "list[Custom] | Default", right operand type: "list[int] | list[Default]")
1102+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom]"
1103+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1104+
else:
1105+
reveal_type(x) # N: Revealed type is "builtins.list[__main__.Custom] | __main__.Default"
1106+
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int] | builtins.list[__main__.Default]"
1107+
[builtins fixtures/list.pyi]
1108+
1109+
[case testNarrowingRecursiveCallable]
1110+
# flags: --strict-equality --warn-unreachable
1111+
from __future__ import annotations
1112+
from typing import Callable
1113+
1114+
class A: ...
1115+
class B: ...
1116+
1117+
T = Callable[[A], "S"]
1118+
S = Callable[[B], "T"]
1119+
1120+
def f(x: S, y: T):
1121+
if x == y: # E: Unsupported left operand type for == ("Callable[[B], T]")
1122+
reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..."
1123+
reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..."
1124+
else:
1125+
reveal_type(x) # N: Revealed type is "def (__main__.B) -> def (__main__.A) -> ..."
1126+
reveal_type(y) # N: Revealed type is "def (__main__.A) -> def (__main__.B) -> ..."
1127+
[builtins fixtures/tuple.pyi]
1128+
1129+
[case testNarrowingRecursiveUnion]
1130+
# flags: --strict-equality --warn-unreachable
1131+
from __future__ import annotations
1132+
from typing import Union
1133+
1134+
class A: ...
1135+
class B: ...
1136+
1137+
T = Union[A, "S"]
1138+
S = Union[B, "T"] # E: Invalid recursive alias: a union item of itself
1139+
1140+
def f(x: S, y: T):
1141+
if x == y:
1142+
reveal_type(x) # N: Revealed type is "Any"
1143+
reveal_type(y) # N: Revealed type is "__main__.A | Any"
1144+
[builtins fixtures/tuple.pyi]
1145+
10681146
[case testNarrowingUnreachableCases]
10691147
# flags: --strict-equality --warn-unreachable
10701148
from typing import Literal, Union

test-data/unit/check-tuples.test

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1540,7 +1540,9 @@ class B: pass
15401540

15411541
def f1(possibles: Tuple[int, Tuple[A]], x: Optional[Tuple[B]]):
15421542
if x in possibles:
1543-
reveal_type(x) # N: Revealed type is "tuple[__main__.B]"
1543+
# TODO: this branch is actually unreachable
1544+
# This is an easy fix: https://github.com/python/mypy/pull/20660
1545+
reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"
15441546
else:
15451547
reveal_type(x) # N: Revealed type is "tuple[__main__.B] | None"
15461548

0 commit comments

Comments
 (0)