Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 77 additions & 22 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,7 +743,9 @@ def incompatible_argument_note(
context: Context,
code: ErrorCode | None,
) -> None:
if isinstance(original_caller_type, (Instance, TupleType, TypedDictType)):
if isinstance(
original_caller_type, (Instance, TupleType, TypedDictType, TypeType, CallableType)
):
if isinstance(callee_type, Instance) and callee_type.type.is_protocol:
self.report_protocol_problems(
original_caller_type, callee_type, context, code=code
Expand Down Expand Up @@ -1791,7 +1793,7 @@ def impossible_intersection(

def report_protocol_problems(
self,
subtype: Instance | TupleType | TypedDictType,
subtype: Instance | TupleType | TypedDictType | TypeType | CallableType,
supertype: Instance,
context: Context,
*,
Expand All @@ -1811,15 +1813,15 @@ def report_protocol_problems(
exclusions: dict[type, list[str]] = {
TypedDictType: ["typing.Mapping"],
TupleType: ["typing.Iterable", "typing.Sequence"],
Instance: [],
}
if supertype.type.fullname in exclusions[type(subtype)]:
if supertype.type.fullname in exclusions.get(type(subtype), []):
return
if any(isinstance(tp, UninhabitedType) for tp in get_proper_types(supertype.args)):
# We don't want to add notes for failed inference (e.g. Iterable[<nothing>]).
# This will be only confusing a user even more.
return

class_obj = False
if isinstance(subtype, TupleType):
if not isinstance(subtype.partial_fallback, Instance):
return
Expand All @@ -1828,6 +1830,21 @@ def report_protocol_problems(
if not isinstance(subtype.fallback, Instance):
return
subtype = subtype.fallback
elif isinstance(subtype, TypeType):
if not isinstance(subtype.item, Instance):
return
class_obj = True
subtype = subtype.item
elif isinstance(subtype, CallableType):
if not subtype.is_type_obj():
return
ret_type = get_proper_type(subtype.ret_type)
if isinstance(ret_type, TupleType):
ret_type = ret_type.partial_fallback
if not isinstance(ret_type, Instance):
return
class_obj = True
subtype = ret_type

# Report missing members
missing = get_missing_protocol_members(subtype, supertype)
Expand All @@ -1836,20 +1853,29 @@ def report_protocol_problems(
and len(missing) < len(supertype.type.protocol_members)
and len(missing) <= MAX_ITEMS
):
self.note(
'"{}" is missing following "{}" protocol member{}:'.format(
subtype.type.name, supertype.type.name, plural_s(missing)
),
context,
code=code,
)
self.note(", ".join(missing), context, offset=OFFSET, code=code)
if missing == ["__call__"] and class_obj:
self.note(
'"{}" has constructor incompatible with "__call__" of "{}"'.format(
subtype.type.name, supertype.type.name
),
context,
code=code,
)
else:
self.note(
'"{}" is missing following "{}" protocol member{}:'.format(
subtype.type.name, supertype.type.name, plural_s(missing)
),
context,
code=code,
)
self.note(", ".join(missing), context, offset=OFFSET, code=code)
elif len(missing) > MAX_ITEMS or len(missing) == len(supertype.type.protocol_members):
# This is an obviously wrong type: too many missing members
return

# Report member type conflicts
conflict_types = get_conflict_protocol_types(subtype, supertype)
conflict_types = get_conflict_protocol_types(subtype, supertype, class_obj=class_obj)
if conflict_types and (
not is_subtype(subtype, erase_type(supertype))
or not subtype.type.defn.type_vars
Expand All @@ -1875,16 +1901,30 @@ def report_protocol_problems(
else:
self.note("Expected:", context, offset=OFFSET, code=code)
if isinstance(exp, CallableType):
self.note(pretty_callable(exp), context, offset=2 * OFFSET, code=code)
self.note(
pretty_callable(exp, skip_self=class_obj),
context,
offset=2 * OFFSET,
code=code,
)
else:
assert isinstance(exp, Overloaded)
self.pretty_overload(exp, context, 2 * OFFSET, code=code)
self.pretty_overload(
exp, context, 2 * OFFSET, code=code, skip_self=class_obj
)
self.note("Got:", context, offset=OFFSET, code=code)
if isinstance(got, CallableType):
self.note(pretty_callable(got), context, offset=2 * OFFSET, code=code)
self.note(
pretty_callable(got, skip_self=class_obj),
context,
offset=2 * OFFSET,
code=code,
)
else:
assert isinstance(got, Overloaded)
self.pretty_overload(got, context, 2 * OFFSET, code=code)
self.pretty_overload(
got, context, 2 * OFFSET, code=code, skip_self=class_obj
)
self.print_more(conflict_types, context, OFFSET, MAX_ITEMS, code=code)

# Report flag conflicts (i.e. settable vs read-only etc.)
Expand Down Expand Up @@ -1930,6 +1970,7 @@ def pretty_overload(
add_class_or_static_decorator: bool = False,
allow_dups: bool = False,
code: ErrorCode | None = None,
skip_self: bool = False,
) -> None:
for item in tp.items:
self.note("@overload", context, offset=offset, allow_dups=allow_dups, code=code)
Expand All @@ -1940,7 +1981,11 @@ def pretty_overload(
self.note(decorator, context, offset=offset, allow_dups=allow_dups, code=code)

self.note(
pretty_callable(item), context, offset=offset, allow_dups=allow_dups, code=code
pretty_callable(item, skip_self=skip_self),
context,
offset=offset,
allow_dups=allow_dups,
code=code,
)

def print_more(
Expand Down Expand Up @@ -2373,10 +2418,14 @@ def pretty_class_or_static_decorator(tp: CallableType) -> str | None:
return None


def pretty_callable(tp: CallableType) -> str:
def pretty_callable(tp: CallableType, skip_self: bool = False) -> str:
"""Return a nice easily-readable representation of a callable type.
For example:
def [T <: int] f(self, x: int, y: T) -> None

If skip_self is True, print an actual callable type, as it would appear
when bound on an instance/class, rather than how it would appear in the
defining statement.
"""
s = ""
asterisk = False
Expand Down Expand Up @@ -2420,7 +2469,11 @@ def [T <: int] f(self, x: int, y: T) -> None
and hasattr(tp.definition, "arguments")
):
definition_arg_names = [arg.variable.name for arg in tp.definition.arguments]
if len(definition_arg_names) > len(tp.arg_names) and definition_arg_names[0]:
if (
len(definition_arg_names) > len(tp.arg_names)
and definition_arg_names[0]
and not skip_self
):
if s:
s = ", " + s
s = definition_arg_names[0] + s
Expand Down Expand Up @@ -2487,7 +2540,9 @@ def get_missing_protocol_members(left: Instance, right: Instance) -> list[str]:
return missing


def get_conflict_protocol_types(left: Instance, right: Instance) -> list[tuple[str, Type, Type]]:
def get_conflict_protocol_types(
left: Instance, right: Instance, class_obj: bool = False
) -> list[tuple[str, Type, Type]]:
"""Find members that are defined in 'left' but have incompatible types.
Return them as a list of ('member', 'got', 'expected').
"""
Expand All @@ -2498,7 +2553,7 @@ def get_conflict_protocol_types(left: Instance, right: Instance) -> list[tuple[s
continue
supertype = find_member(member, right, left)
assert supertype is not None
subtype = find_member(member, left, left)
subtype = find_member(member, left, left, class_obj=class_obj)
if not subtype:
continue
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
Expand Down
57 changes: 45 additions & 12 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,14 @@ def visit_callable_type(self, left: CallableType) -> bool:
assert call is not None
if self._is_subtype(left, call):
return True
if right.type.is_protocol and left.is_type_obj():
ret_type = get_proper_type(left.ret_type)
if isinstance(ret_type, TupleType):
ret_type = mypy.typeops.tuple_fallback(ret_type)
if isinstance(ret_type, Instance) and is_protocol_implementation(
ret_type, right, proper_subtype=self.proper_subtype, class_obj=True
):
return True
return self._is_subtype(left.fallback, right)
elif isinstance(right, TypeType):
# This is unsound, we don't check the __init__ signature.
Expand Down Expand Up @@ -897,6 +905,10 @@ def visit_type_type(self, left: TypeType) -> bool:
if isinstance(item, TypeVarType):
item = get_proper_type(item.upper_bound)
if isinstance(item, Instance):
if right.type.is_protocol and is_protocol_implementation(
item, right, proper_subtype=self.proper_subtype, class_obj=True
):
return True
metaclass = item.type.metaclass_type
return metaclass is not None and self._is_subtype(metaclass, right)
return False
Expand All @@ -916,7 +928,7 @@ def pop_on_exit(stack: list[tuple[T, T]], left: T, right: T) -> Iterator[None]:


def is_protocol_implementation(
left: Instance, right: Instance, proper_subtype: bool = False
left: Instance, right: Instance, proper_subtype: bool = False, class_obj: bool = False
) -> bool:
"""Check whether 'left' implements the protocol 'right'.

Expand Down Expand Up @@ -959,7 +971,19 @@ def f(self) -> A: ...
# We always bind self to the subtype. (Similarly to nominal types).
supertype = get_proper_type(find_member(member, right, left))
assert supertype is not None
subtype = get_proper_type(find_member(member, left, left))
if member == "__call__" and class_obj:
# Special case: class objects always have __call__ that is just the constructor.
# TODO: move this helper function to typeops.py?
import mypy.checkmember

def named_type(fullname: str) -> Instance:
return Instance(left.type.mro[-1], [])

subtype: ProperType | None = mypy.checkmember.type_object_type(
left.type, named_type
)
else:
subtype = get_proper_type(find_member(member, left, left, class_obj=class_obj))
# Useful for debugging:
# print(member, 'of', left, 'has type', subtype)
# print(member, 'of', right, 'has type', supertype)
Expand Down Expand Up @@ -1014,7 +1038,7 @@ def f(self) -> A: ...


def find_member(
name: str, itype: Instance, subtype: Type, is_operator: bool = False
name: str, itype: Instance, subtype: Type, is_operator: bool = False, class_obj: bool = False
) -> Type | None:
"""Find the type of member by 'name' in 'itype's TypeInfo.

Expand All @@ -1027,23 +1051,24 @@ def find_member(
method = info.get_method(name)
if method:
if isinstance(method, Decorator):
return find_node_type(method.var, itype, subtype)
return find_node_type(method.var, itype, subtype, class_obj=class_obj)
if method.is_property:
assert isinstance(method, OverloadedFuncDef)
dec = method.items[0]
assert isinstance(dec, Decorator)
return find_node_type(dec.var, itype, subtype)
return find_node_type(method, itype, subtype)
return find_node_type(dec.var, itype, subtype, class_obj=class_obj)
return find_node_type(method, itype, subtype, class_obj=class_obj)
else:
# don't have such method, maybe variable or decorator?
node = info.get(name)
v = node.node if node else None
if isinstance(v, Var):
return find_node_type(v, itype, subtype)
return find_node_type(v, itype, subtype, class_obj=class_obj)
if (
not v
and name not in ["__getattr__", "__setattr__", "__getattribute__"]
and not is_operator
and not class_obj
):
for method_name in ("__getattribute__", "__getattr__"):
# Normally, mypy assumes that instances that define __getattr__ have all
Expand Down Expand Up @@ -1107,7 +1132,9 @@ def get_member_flags(name: str, info: TypeInfo) -> set[int]:
return set()


def find_node_type(node: Var | FuncBase, itype: Instance, subtype: Type) -> Type:
def find_node_type(
node: Var | FuncBase, itype: Instance, subtype: Type, class_obj: bool = False
) -> Type:
"""Find type of a variable or method 'node' (maybe also a decorated method).
Apply type arguments from 'itype', and bind 'self' to 'subtype'.
"""
Expand All @@ -1129,10 +1156,16 @@ def find_node_type(node: Var | FuncBase, itype: Instance, subtype: Type) -> Type
and not node.is_staticmethod
):
assert isinstance(p_typ, FunctionLike)
signature = bind_self(
p_typ, subtype, is_classmethod=isinstance(node, Var) and node.is_classmethod
)
if node.is_property:
if class_obj and not (
node.is_class if isinstance(node, FuncBase) else node.is_classmethod
):
# Don't bind instance methods on class objects.
signature = p_typ
else:
signature = bind_self(
p_typ, subtype, is_classmethod=isinstance(node, Var) and node.is_classmethod
)
if node.is_property and not class_obj:
assert isinstance(signature, CallableType)
typ = signature.ret_type
else:
Expand Down
Loading