Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
294 changes: 202 additions & 92 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def __init__(self) -> None:
from mypy.expandtype import expand_type
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash
from mypy.maptype import map_instance_to_supertype
from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types
from mypy.meet import is_overlapping_types, meet_types
from mypy.message_registry import ErrorMessage
from mypy.messages import (
SUGGESTED_TEST_FIXTURES,
Expand Down Expand Up @@ -6720,22 +6720,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa
narrowable_indices={0},
)

# TODO: This remove_optional code should no longer be needed. The only
# thing it does is paper over a pre-existing deficiency in equality
# narrowing w.r.t to enums.
# We only try and narrow away 'None' for now
if (
not is_unreachable_map(if_map)
and is_overlapping_none(item_type)
and not is_overlapping_none(collection_item_type)
and not (
isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == "builtins.object"
)
and is_overlapping_erased_types(item_type, collection_item_type)
):
if_map[operands[left_index]] = remove_optional(item_type)

if right_index in narrowable_operand_index_to_hash:
if_type, else_type = self.conditional_types_for_iterable(
item_type, iterable_type
Expand Down Expand Up @@ -6820,17 +6804,15 @@ def narrow_type_by_identity_equality(
# have to be more careful about what narrowing we can conclude from a successful comparison
custom_eq_indices: set[int]

# enum_comparison_is_ambiguous:
# `if x is Fruits.APPLE` we know `x` is `Fruits.APPLE`, but `if x == Fruits.APPLE: ...`
# it could e.g. be an int or str if Fruits is an IntEnum or StrEnum.
# See ambiguous_enum_equality_keys for more details
enum_comparison_is_ambiguous: bool
# Equality can use value semantics, so `if x == Fruits.APPLE: ...` may also
# match non-enum values for IntEnum/StrEnum-like enums. Identity checks don't
# have this ambiguity.
is_identity_comparison = operator in {"is", "is not"}

if operator in {"is", "is not"}:
if is_identity_comparison:
is_target_for_value_narrowing = is_singleton_identity_type
should_coerce_literals = True
custom_eq_indices = set()
enum_comparison_is_ambiguous = False

elif operator in {"==", "!="}:
is_target_for_value_narrowing = is_singleton_equality_type
Expand All @@ -6843,7 +6825,6 @@ def narrow_type_by_identity_equality(
break

custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks(operand_types[i])}
enum_comparison_is_ambiguous = True
else:
raise AssertionError

Expand All @@ -6859,8 +6840,6 @@ def narrow_type_by_identity_equality(
continue

expr_type = operand_types[i]
expr_enum_keys = ambiguous_enum_equality_keys(expr_type)
expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None)
for j in expr_indices:
if i == j:
continue
Expand All @@ -6872,18 +6851,30 @@ def narrow_type_by_identity_equality(
if should_coerce_literals:
target_type = coerce_to_literal(target_type)

if (
# See comments in ambiguous_enum_equality_keys
enum_comparison_is_ambiguous
and len(expr_enum_keys | ambiguous_enum_equality_keys(target_type)) > 1
):
continue
narrowable_expr_type, ambiguous_expr_type = partition_equality_ambiguous_types(
expr_type, target_type, is_identity=is_identity_comparison
)

target = TypeRange(target_type, is_upper_bound=False)
if narrowable_expr_type is None:
if_type = else_type = ambiguous_expr_type
else:
narrowable_expr_type = try_expanding_sum_type_to_union(
coerce_to_literal(narrowable_expr_type), None
)
if_type, else_type = conditional_types(
narrowable_expr_type,
[TypeRange(target_type, is_upper_bound=False)],
from_equality=True,
)
if ambiguous_expr_type is not None:
if_type = make_simplified_union(
[if_type or narrowable_expr_type, ambiguous_expr_type]
)
else_type = make_simplified_union(
[else_type or narrowable_expr_type, ambiguous_expr_type]
)

if_map, else_map = conditional_types_to_typemaps(
operands[i], *conditional_types(expr_type, [target], from_equality=True)
)
if_map, else_map = conditional_types_to_typemaps(operands[i], if_type, else_type)
if is_target_for_value_narrowing(get_proper_type(target_type)):
all_if_maps.append(if_map)
all_else_maps.append(else_map)
Expand Down Expand Up @@ -6964,13 +6955,29 @@ def narrow_type_by_identity_equality(
target_type = operand_types[j]
if should_coerce_literals:
target_type = coerce_to_literal(target_type)
target = TypeRange(target_type, is_upper_bound=False)

narrowable_expr_type, ambiguous_expr_type = partition_equality_ambiguous_types(
expr_type, target_type, is_identity=is_identity_comparison
)

if narrowable_expr_type is None:
if_type = else_type = ambiguous_expr_type
else:
narrowable_expr_type = coerce_to_literal(
try_expanding_sum_type_to_union(narrowable_expr_type, None)
)
if_type, else_type = conditional_types(
narrowable_expr_type,
[TypeRange(target_type, is_upper_bound=False)],
default=narrowable_expr_type,
from_equality=True,
)
if ambiguous_expr_type is not None:
if_type = make_simplified_union([if_type, ambiguous_expr_type])
else_type = make_simplified_union([else_type, ambiguous_expr_type])

if_map, else_map = conditional_types_to_typemaps(
operands[i],
*conditional_types(
expr_type, [target], default=expr_type, from_equality=True
),
operands[i], if_type, else_type
)
or_if_maps.append(if_map)
if is_target_for_value_narrowing(get_proper_type(target_type)):
Expand Down Expand Up @@ -8564,17 +8571,10 @@ def conditional_types(
# We erase generic args because values with different generic types can compare equal
# For instance, cast(list[str], []) and cast(list[int], [])
proposed_type = shallow_erase_type_for_equality(proposed_type)
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=False):
# Equality narrowing is one of the places at runtime where subtyping with promotion
# does happen to match runtime semantics
# Expression is never of any type in proposed_type_ranges
return UninhabitedType(), default
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
return default, default
else:
if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
# Expression is never of any type in proposed_type_ranges
return UninhabitedType(), default

if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True):
# Expression is never of any type in proposed_type_ranges
return UninhabitedType(), default

# we can only restrict when the type is precise, not bounded
proposed_precise_type = UnionType.make_union(
Expand Down Expand Up @@ -8844,8 +8844,6 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty


BUILTINS_CUSTOM_EQ_CHECKS: Final = {
"builtins.bytearray",
"builtins.memoryview",
"builtins.frozenset",
"_collections_abc.dict_keys",
"_collections_abc.dict_items",
Expand All @@ -8857,9 +8855,8 @@ def has_custom_eq_checks(t: Type) -> bool:
custom_special_method(t, "__eq__", check_all=False)
or custom_special_method(t, "__ne__", check_all=False)
# custom_special_method has special casing for builtins.* and typing.* that make the
# above always return False. So here we return True if the a value of a builtin type
# will ever compare equal to value of another type, e.g. a bytes value can compare equal
# to a bytearray value.
# above always return False. Some builtin collections still have equality behavior that
# crosses nominal type boundaries and isn't captured by VALUE_EQUALITY_TYPE_DOMAINS.
or (
isinstance(pt := get_proper_type(t), Instance)
and pt.type.fullname in BUILTINS_CUSTOM_EQ_CHECKS
Expand Down Expand Up @@ -9637,45 +9634,158 @@ def visit_starred_pattern(self, p: StarredPattern) -> None:
self.lvalue = False


def ambiguous_enum_equality_keys(t: Type) -> set[str]:
"""
Used when narrowing types based on equality.
# Open domains also block cross-type narrowing for known domain members, but they
# don't provide an exhaustive union to narrow top types to.
OPEN_VALUE_EQUALITY_DOMAINS: Final = {
"builtins.str": "builtins.str",
"builtins.bool": "builtins.numeric",
"builtins.int": "builtins.numeric",
"builtins.float": "builtins.numeric",
"builtins.complex": "builtins.numeric",
}
OPEN_VALUE_EQUALITY_DOMAIN_NAMES: Final = frozenset(OPEN_VALUE_EQUALITY_DOMAINS.values())

# Closed domains also block ordinary cross-type narrowing within the domain.
CLOSED_VALUE_EQUALITY_DOMAINS: Final = {
"builtins.bytes": "builtins.bytes",
"builtins.bytearray": "builtins.bytes",
"builtins.memoryview": "builtins.bytes",
}

VALUE_EQUALITY_DOMAINS: Final = {**OPEN_VALUE_EQUALITY_DOMAINS, **CLOSED_VALUE_EQUALITY_DOMAINS}


Certain kinds of enums can compare equal to values of other types, so doing type math
the way `conditional_types` does will be misleading if you expect it to correspond to
conditions based on equality comparisons.
class EqualityDomainInfo(NamedTuple):
type_names: set[str]
enum_type_names: set[str]

For example, StrEnum classes can compare equal to str values. So if we see
`val: StrEnum; if val == "foo": ...` we currently avoid narrowing.
Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...`

class EqualityValueInfo(NamedTuple):
domains: dict[str, EqualityDomainInfo]
is_top: bool


def closed_equality_domain_type_names(info: EqualityValueInfo) -> list[str]:
return [
fullname
for fullname, domain in CLOSED_VALUE_EQUALITY_DOMAINS.items()
if domain in info.domains
]


def partition_equality_ambiguous_types(
current_type: Type, target_type: Type, *, is_identity: bool
) -> tuple[Type | None, Type | None]:
"""Split current_type into ordinary-narrowable and equality-ambiguous pieces.

Some values compare equal through a value domain broader than their nominal type. For
example, an IntEnum member can compare equal to an int, and a StrEnum member can compare
equal to a str. When narrowing `x: MyStrEnum | str` against `MyStrEnum.MEMBER`, we can
still narrow the enum portion of the union, but we must keep the str portion in both
branches.
"""
# We need these things for this to be ambiguous:
# (1) an IntEnum or StrEnum type or enum subclass of int or str
# (2) either a different IntEnum/StrEnum type or a non-enum type ("<other>")
result = set()
if is_identity:
return current_type, None

typ = get_proper_type(current_type)
items = typ.relevant_items() if isinstance(typ, UnionType) else [current_type]
narrowable_items = []
ambiguous_items = []
for item in items:
if is_equality_ambiguous_for_narrowing(item, target_type):
ambiguous_items.append(item)
else:
narrowable_items.append(item)
return (
UnionType.make_union(narrowable_items) if narrowable_items else None,
UnionType.make_union(ambiguous_items) if ambiguous_items else None,
)


def is_equality_ambiguous_for_narrowing(left: Type, right: Type) -> bool:
"""Can left compare equal to right through a value domain outside nominal overlap?"""
left_info = equality_value_info(left)
right_info = equality_value_info(right)

if left_info.is_top or right_info.is_top:
# Only open-domain enum values can make a top-like type ambiguous.
# Closed domains can be narrowed to their complete known set instead.
other_info = right_info if left_info.is_top else left_info
return any(
domain in OPEN_VALUE_EQUALITY_DOMAIN_NAMES and domain_info.enum_type_names
for domain, domain_info in other_info.domains.items()
)

shared_domains = left_info.domains.keys() & right_info.domains.keys()
if not shared_domains:
return False

for domain in shared_domains:
left_domain = left_info.domains[domain]
right_domain = right_info.domains[domain]
# Equality between two values from the same enum can still narrow by literal member.
if (
left_domain.enum_type_names
and left_domain.enum_type_names == right_domain.enum_type_names
and left_domain.type_names == left_domain.enum_type_names
and right_domain.type_names == right_domain.enum_type_names
):
continue
# Different domain-member types may compare equal, but nominal narrowing would
# otherwise treat them as disjoint.
if left_domain.type_names != right_domain.type_names:
return True
# Same domain-member types are only ambiguous if an enum value may compare equal to
# its underlying value type.
if left_domain.enum_type_names or right_domain.enum_type_names:
return True

return False


def equality_value_info(t: Type) -> EqualityValueInfo:
t = get_proper_type(t)
if isinstance(t, UnionType):
for item in t.items:
result.update(ambiguous_enum_equality_keys(item))
elif isinstance(t, Instance):
if t.last_known_value:
result.update(ambiguous_enum_equality_keys(t.last_known_value))
elif t.type.is_enum and any(
base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int")
for base in t.type.mro
):
result.add(t.type.fullname)
elif not t.type.is_enum:
# These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so
# let's be conservative
result.add("<other>")
elif isinstance(t, LiteralType):
result.update(ambiguous_enum_equality_keys(t.fallback))
elif isinstance(t, NoneType):
pass
else:
result.add("<other>")
return result
return combine_equality_value_info(equality_value_info(item) for item in t.items)
if isinstance(t, TypeVarType):
if t.values:
return combine_equality_value_info(equality_value_info(item) for item in t.values)
return equality_value_info(t.upper_bound)
if isinstance(t, Instance) and t.last_known_value is not None:
return equality_value_info(t.last_known_value)
if isinstance(t, LiteralType):
return equality_value_info(t.fallback)
if isinstance(t, Instance):
if t.type.fullname == "builtins.object":
return EqualityValueInfo({}, is_top=True)

enum_type_names = {t.type.fullname} if t.type.is_enum else set()
domains = {}
for base in t.type.mro:
if domain := VALUE_EQUALITY_DOMAINS.get(base.fullname):
domains[domain] = EqualityDomainInfo({t.type.fullname}, enum_type_names)

return EqualityValueInfo(domains, is_top=False)
if isinstance(t, AnyType):
return EqualityValueInfo({}, is_top=True)
return EqualityValueInfo({}, is_top=False)


def combine_equality_value_info(infos: Iterable[EqualityValueInfo]) -> EqualityValueInfo:
domains: dict[str, EqualityDomainInfo] = {}
is_top = False
for info in infos:
for domain, domain_info in info.domains.items():
existing_domain_info = domains.get(domain)
if existing_domain_info is None:
domains[domain] = EqualityDomainInfo(
set(domain_info.type_names), set(domain_info.enum_type_names)
)
else:
existing_domain_info.type_names.update(domain_info.type_names)
existing_domain_info.enum_type_names.update(domain_info.enum_type_names)
is_top = is_top or info.is_top
return EqualityValueInfo(domains, is_top)


def is_typeddict_type_context(lvalue_type: Type) -> bool:
Expand Down
Loading
Loading