diff --git a/mypy/checker.py b/mypy/checker.py index 59571954e0f7..255920e802d8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -108,7 +108,7 @@ def __init__(self) -> None: from mypy.expandtype import expand_type from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash from mypy.maptype import map_instance_to_supertype -from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types +from mypy.meet import is_overlapping_types, meet_types from mypy.message_registry import ErrorMessage from mypy.messages import ( SUGGESTED_TEST_FIXTURES, @@ -6720,22 +6720,6 @@ def comparison_type_narrowing_helper(self, node: ComparisonExpr) -> tuple[TypeMa narrowable_indices={0}, ) - # TODO: This remove_optional code should no longer be needed. The only - # thing it does is paper over a pre-existing deficiency in equality - # narrowing w.r.t to enums. - # We only try and narrow away 'None' for now - if ( - not is_unreachable_map(if_map) - and is_overlapping_none(item_type) - and not is_overlapping_none(collection_item_type) - and not ( - isinstance(collection_item_type, Instance) - and collection_item_type.type.fullname == "builtins.object" - ) - and is_overlapping_erased_types(item_type, collection_item_type) - ): - if_map[operands[left_index]] = remove_optional(item_type) - if right_index in narrowable_operand_index_to_hash: if_type, else_type = self.conditional_types_for_iterable( item_type, iterable_type @@ -6820,17 +6804,15 @@ def narrow_type_by_identity_equality( # have to be more careful about what narrowing we can conclude from a successful comparison custom_eq_indices: set[int] - # enum_comparison_is_ambiguous: - # `if x is Fruits.APPLE` we know `x` is `Fruits.APPLE`, but `if x == Fruits.APPLE: ...` - # it could e.g. be an int or str if Fruits is an IntEnum or StrEnum. - # See ambiguous_enum_equality_keys for more details - enum_comparison_is_ambiguous: bool + # Equality can use value semantics, so `if x == Fruits.APPLE: ...` may also + # match non-enum values for IntEnum/StrEnum-like enums. Identity checks don't + # have this ambiguity. + is_identity_comparison = operator in {"is", "is not"} - if operator in {"is", "is not"}: + if is_identity_comparison: is_target_for_value_narrowing = is_singleton_identity_type should_coerce_literals = True custom_eq_indices = set() - enum_comparison_is_ambiguous = False elif operator in {"==", "!="}: is_target_for_value_narrowing = is_singleton_equality_type @@ -6843,7 +6825,6 @@ def narrow_type_by_identity_equality( break custom_eq_indices = {i for i in expr_indices if has_custom_eq_checks(operand_types[i])} - enum_comparison_is_ambiguous = True else: raise AssertionError @@ -6859,8 +6840,6 @@ def narrow_type_by_identity_equality( continue expr_type = operand_types[i] - expr_enum_keys = ambiguous_enum_equality_keys(expr_type) - expr_type = try_expanding_sum_type_to_union(coerce_to_literal(expr_type), None) for j in expr_indices: if i == j: continue @@ -6872,18 +6851,30 @@ def narrow_type_by_identity_equality( if should_coerce_literals: target_type = coerce_to_literal(target_type) - if ( - # See comments in ambiguous_enum_equality_keys - enum_comparison_is_ambiguous - and len(expr_enum_keys | ambiguous_enum_equality_keys(target_type)) > 1 - ): - continue + narrowable_expr_type, ambiguous_expr_type = partition_equality_ambiguous_types( + expr_type, target_type, is_identity=is_identity_comparison + ) - target = TypeRange(target_type, is_upper_bound=False) + if narrowable_expr_type is None: + if_type = else_type = ambiguous_expr_type + else: + narrowable_expr_type = try_expanding_sum_type_to_union( + coerce_to_literal(narrowable_expr_type), None + ) + if_type, else_type = conditional_types( + narrowable_expr_type, + [TypeRange(target_type, is_upper_bound=False)], + from_equality=True, + ) + if ambiguous_expr_type is not None: + if_type = make_simplified_union( + [if_type or narrowable_expr_type, ambiguous_expr_type] + ) + else_type = make_simplified_union( + [else_type or narrowable_expr_type, ambiguous_expr_type] + ) - if_map, else_map = conditional_types_to_typemaps( - operands[i], *conditional_types(expr_type, [target], from_equality=True) - ) + if_map, else_map = conditional_types_to_typemaps(operands[i], if_type, else_type) if is_target_for_value_narrowing(get_proper_type(target_type)): all_if_maps.append(if_map) all_else_maps.append(else_map) @@ -6964,13 +6955,29 @@ def narrow_type_by_identity_equality( target_type = operand_types[j] if should_coerce_literals: target_type = coerce_to_literal(target_type) - target = TypeRange(target_type, is_upper_bound=False) + + narrowable_expr_type, ambiguous_expr_type = partition_equality_ambiguous_types( + expr_type, target_type, is_identity=is_identity_comparison + ) + + if narrowable_expr_type is None: + if_type = else_type = ambiguous_expr_type + else: + narrowable_expr_type = coerce_to_literal( + try_expanding_sum_type_to_union(narrowable_expr_type, None) + ) + if_type, else_type = conditional_types( + narrowable_expr_type, + [TypeRange(target_type, is_upper_bound=False)], + default=narrowable_expr_type, + from_equality=True, + ) + if ambiguous_expr_type is not None: + if_type = make_simplified_union([if_type, ambiguous_expr_type]) + else_type = make_simplified_union([else_type, ambiguous_expr_type]) if_map, else_map = conditional_types_to_typemaps( - operands[i], - *conditional_types( - expr_type, [target], default=expr_type, from_equality=True - ), + operands[i], if_type, else_type ) or_if_maps.append(if_map) if is_target_for_value_narrowing(get_proper_type(target_type)): @@ -8564,17 +8571,10 @@ def conditional_types( # We erase generic args because values with different generic types can compare equal # For instance, cast(list[str], []) and cast(list[int], []) proposed_type = shallow_erase_type_for_equality(proposed_type) - if not is_overlapping_types(current_type, proposed_type, ignore_promotions=False): - # Equality narrowing is one of the places at runtime where subtyping with promotion - # does happen to match runtime semantics - # Expression is never of any type in proposed_type_ranges - return UninhabitedType(), default - if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True): - return default, default - else: - if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True): - # Expression is never of any type in proposed_type_ranges - return UninhabitedType(), default + + if not is_overlapping_types(current_type, proposed_type, ignore_promotions=True): + # Expression is never of any type in proposed_type_ranges + return UninhabitedType(), default # we can only restrict when the type is precise, not bounded proposed_precise_type = UnionType.make_union( @@ -8844,8 +8844,6 @@ def reduce_and_conditional_type_maps(ms: list[TypeMap], *, use_meet: bool) -> Ty BUILTINS_CUSTOM_EQ_CHECKS: Final = { - "builtins.bytearray", - "builtins.memoryview", "builtins.frozenset", "_collections_abc.dict_keys", "_collections_abc.dict_items", @@ -8857,9 +8855,8 @@ def has_custom_eq_checks(t: Type) -> bool: custom_special_method(t, "__eq__", check_all=False) or custom_special_method(t, "__ne__", check_all=False) # custom_special_method has special casing for builtins.* and typing.* that make the - # above always return False. So here we return True if the a value of a builtin type - # will ever compare equal to value of another type, e.g. a bytes value can compare equal - # to a bytearray value. + # above always return False. Some builtin collections still have equality behavior that + # crosses nominal type boundaries and isn't captured by VALUE_EQUALITY_TYPE_DOMAINS. or ( isinstance(pt := get_proper_type(t), Instance) and pt.type.fullname in BUILTINS_CUSTOM_EQ_CHECKS @@ -9637,45 +9634,158 @@ def visit_starred_pattern(self, p: StarredPattern) -> None: self.lvalue = False -def ambiguous_enum_equality_keys(t: Type) -> set[str]: - """ - Used when narrowing types based on equality. +# Open domains also block cross-type narrowing for known domain members, but they +# don't provide an exhaustive union to narrow top types to. +OPEN_VALUE_EQUALITY_DOMAINS: Final = { + "builtins.str": "builtins.str", + "builtins.bool": "builtins.numeric", + "builtins.int": "builtins.numeric", + "builtins.float": "builtins.numeric", + "builtins.complex": "builtins.numeric", +} +OPEN_VALUE_EQUALITY_DOMAIN_NAMES: Final = frozenset(OPEN_VALUE_EQUALITY_DOMAINS.values()) + +# Closed domains also block ordinary cross-type narrowing within the domain. +CLOSED_VALUE_EQUALITY_DOMAINS: Final = { + "builtins.bytes": "builtins.bytes", + "builtins.bytearray": "builtins.bytes", + "builtins.memoryview": "builtins.bytes", +} + +VALUE_EQUALITY_DOMAINS: Final = {**OPEN_VALUE_EQUALITY_DOMAINS, **CLOSED_VALUE_EQUALITY_DOMAINS} + - Certain kinds of enums can compare equal to values of other types, so doing type math - the way `conditional_types` does will be misleading if you expect it to correspond to - conditions based on equality comparisons. +class EqualityDomainInfo(NamedTuple): + type_names: set[str] + enum_type_names: set[str] - For example, StrEnum classes can compare equal to str values. So if we see - `val: StrEnum; if val == "foo": ...` we currently avoid narrowing. - Note that we do wish to continue narrowing for `if val == StrEnum.MEMBER: ...` + +class EqualityValueInfo(NamedTuple): + domains: dict[str, EqualityDomainInfo] + is_top: bool + + +def closed_equality_domain_type_names(info: EqualityValueInfo) -> list[str]: + return [ + fullname + for fullname, domain in CLOSED_VALUE_EQUALITY_DOMAINS.items() + if domain in info.domains + ] + + +def partition_equality_ambiguous_types( + current_type: Type, target_type: Type, *, is_identity: bool +) -> tuple[Type | None, Type | None]: + """Split current_type into ordinary-narrowable and equality-ambiguous pieces. + + Some values compare equal through a value domain broader than their nominal type. For + example, an IntEnum member can compare equal to an int, and a StrEnum member can compare + equal to a str. When narrowing `x: MyStrEnum | str` against `MyStrEnum.MEMBER`, we can + still narrow the enum portion of the union, but we must keep the str portion in both + branches. """ - # We need these things for this to be ambiguous: - # (1) an IntEnum or StrEnum type or enum subclass of int or str - # (2) either a different IntEnum/StrEnum type or a non-enum type ("") - result = set() + if is_identity: + return current_type, None + + typ = get_proper_type(current_type) + items = typ.relevant_items() if isinstance(typ, UnionType) else [current_type] + narrowable_items = [] + ambiguous_items = [] + for item in items: + if is_equality_ambiguous_for_narrowing(item, target_type): + ambiguous_items.append(item) + else: + narrowable_items.append(item) + return ( + UnionType.make_union(narrowable_items) if narrowable_items else None, + UnionType.make_union(ambiguous_items) if ambiguous_items else None, + ) + + +def is_equality_ambiguous_for_narrowing(left: Type, right: Type) -> bool: + """Can left compare equal to right through a value domain outside nominal overlap?""" + left_info = equality_value_info(left) + right_info = equality_value_info(right) + + if left_info.is_top or right_info.is_top: + # Only open-domain enum values can make a top-like type ambiguous. + # Closed domains can be narrowed to their complete known set instead. + other_info = right_info if left_info.is_top else left_info + return any( + domain in OPEN_VALUE_EQUALITY_DOMAIN_NAMES and domain_info.enum_type_names + for domain, domain_info in other_info.domains.items() + ) + + shared_domains = left_info.domains.keys() & right_info.domains.keys() + if not shared_domains: + return False + + for domain in shared_domains: + left_domain = left_info.domains[domain] + right_domain = right_info.domains[domain] + # Equality between two values from the same enum can still narrow by literal member. + if ( + left_domain.enum_type_names + and left_domain.enum_type_names == right_domain.enum_type_names + and left_domain.type_names == left_domain.enum_type_names + and right_domain.type_names == right_domain.enum_type_names + ): + continue + # Different domain-member types may compare equal, but nominal narrowing would + # otherwise treat them as disjoint. + if left_domain.type_names != right_domain.type_names: + return True + # Same domain-member types are only ambiguous if an enum value may compare equal to + # its underlying value type. + if left_domain.enum_type_names or right_domain.enum_type_names: + return True + + return False + + +def equality_value_info(t: Type) -> EqualityValueInfo: t = get_proper_type(t) if isinstance(t, UnionType): - for item in t.items: - result.update(ambiguous_enum_equality_keys(item)) - elif isinstance(t, Instance): - if t.last_known_value: - result.update(ambiguous_enum_equality_keys(t.last_known_value)) - elif t.type.is_enum and any( - base.fullname in ("enum.IntEnum", "enum.StrEnum", "builtins.str", "builtins.int") - for base in t.type.mro - ): - result.add(t.type.fullname) - elif not t.type.is_enum: - # These might compare equal to IntEnum/StrEnum types (e.g. Decimal), so - # let's be conservative - result.add("") - elif isinstance(t, LiteralType): - result.update(ambiguous_enum_equality_keys(t.fallback)) - elif isinstance(t, NoneType): - pass - else: - result.add("") - return result + return combine_equality_value_info(equality_value_info(item) for item in t.items) + if isinstance(t, TypeVarType): + if t.values: + return combine_equality_value_info(equality_value_info(item) for item in t.values) + return equality_value_info(t.upper_bound) + if isinstance(t, Instance) and t.last_known_value is not None: + return equality_value_info(t.last_known_value) + if isinstance(t, LiteralType): + return equality_value_info(t.fallback) + if isinstance(t, Instance): + if t.type.fullname == "builtins.object": + return EqualityValueInfo({}, is_top=True) + + enum_type_names = {t.type.fullname} if t.type.is_enum else set() + domains = {} + for base in t.type.mro: + if domain := VALUE_EQUALITY_DOMAINS.get(base.fullname): + domains[domain] = EqualityDomainInfo({t.type.fullname}, enum_type_names) + + return EqualityValueInfo(domains, is_top=False) + if isinstance(t, AnyType): + return EqualityValueInfo({}, is_top=True) + return EqualityValueInfo({}, is_top=False) + + +def combine_equality_value_info(infos: Iterable[EqualityValueInfo]) -> EqualityValueInfo: + domains: dict[str, EqualityDomainInfo] = {} + is_top = False + for info in infos: + for domain, domain_info in info.domains.items(): + existing_domain_info = domains.get(domain) + if existing_domain_info is None: + domains[domain] = EqualityDomainInfo( + set(domain_info.type_names), set(domain_info.enum_type_names) + ) + else: + existing_domain_info.type_names.update(domain_info.type_names) + existing_domain_info.enum_type_names.update(domain_info.enum_type_names) + is_top = is_top or info.is_top + return EqualityValueInfo(domains, is_top) def is_typeddict_type_context(lvalue_type: Type) -> bool: diff --git a/test-data/unit/check-enum.test b/test-data/unit/check-enum.test index 51d982f35a44..55a20e2b3fa2 100644 --- a/test-data/unit/check-enum.test +++ b/test-data/unit/check-enum.test @@ -2786,17 +2786,16 @@ from typing import Literal # https://github.com/python/mypy/issues/18029 class Foo(enum.StrEnum): - FOO = 'a' + A = 'a' def f1(a: Foo | Literal['foo']) -> Foo: if a == 'foo': # Ideally this is narrowed to just Literal['foo'] (if we learn to narrow based on enum value) reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']" - return Foo.FOO + return Foo.A - # Ideally this passes - reveal_type(a) # N: Revealed type is "__main__.Foo | Literal['foo']" - return a # E: Incompatible return value type (got "Foo | Literal['foo']", expected "Foo") + reveal_type(a) # N: Revealed type is "__main__.Foo" + return a [builtins fixtures/primitives.pyi] [case testStrEnumEqualityAlias] @@ -2891,22 +2890,24 @@ def f2(x: int | Custom | E): def f3_simple(x: str | SE): if x == SE.A: - reveal_type(x) # N: Revealed type is "builtins.str | __main__.SE" + # Note the union gets simplified + reveal_type(x) # N: Revealed type is "builtins.str" + else: + reveal_type(x) # N: Revealed type is "builtins.str" def f3(x: int | str | SE): - # Ideally we filter out some of these ints if x == SE.A: - reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE" + reveal_type(x) # N: Revealed type is "builtins.str" else: - reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE" + reveal_type(x) # N: Revealed type is "builtins.int | builtins.str" if x in cast(list[SE], []): - reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE" + reveal_type(x) # N: Revealed type is "builtins.str" else: reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE" if x == str(): - reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE" + reveal_type(x) # N: Revealed type is "builtins.str" else: reveal_type(x) # N: Revealed type is "builtins.int | builtins.str | __main__.SE" @@ -2928,12 +2929,12 @@ def f4(x: int | Custom | SE): def f5(x: str | Custom | SE): if x == SE.A: - reveal_type(x) # N: Revealed type is "Literal[__main__.SE.A] | __main__.Custom" + reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom" else: reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom" if x in cast(list[SE], []): - reveal_type(x) # N: Revealed type is "__main__.SE | __main__.Custom" + reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom" else: reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom | __main__.SE" @@ -2943,3 +2944,41 @@ def f5(x: str | Custom | SE): reveal_type(x) # N: Revealed type is "builtins.str | __main__.Custom | __main__.SE" [builtins fixtures/primitives.pyi] + + +[case testNarrowingWithIntEnumAndStrEnumUnionAmbiguous] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from enum import IntEnum, StrEnum + +class IE(IntEnum): + X = 1 + Y = 2 + +class SE(StrEnum): + A = "a" + B = "b" + +def f(x: IE | SE) -> None: + if x == IE.X: + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]" + else: + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.Y] | __main__.SE" +[builtins fixtures/primitives.pyi] + + +[case testNarrowingWithStrEnumUnionNoneAmbiguous] +# flags: --strict-equality --warn-unreachable +from __future__ import annotations +from enum import StrEnum + +class SE(StrEnum): + A = "a" + B = "b" + +def f(x: SE | None) -> None: + if x == "a": + reveal_type(x) # N: Revealed type is "__main__.SE" + else: + reveal_type(x) # N: Revealed type is "__main__.SE | None" +[builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 8afcb682712e..048acf41fbf5 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2470,7 +2470,7 @@ def f3(x: object) -> None: def f4(x: int | Any) -> None: if x == IE.X: - reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X] | Any" + reveal_type(x) # N: Revealed type is "builtins.int | Any" else: reveal_type(x) # N: Revealed type is "builtins.int | Any" @@ -2513,9 +2513,9 @@ class E(Enum): def f1(x: IE | MyDecimal) -> None: if x == IE.X: - reveal_type(x) # N: Revealed type is "__main__.IE | __main__.MyDecimal" + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X]" else: - reveal_type(x) # N: Revealed type is "__main__.IE | __main__.MyDecimal" + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.Y] | __main__.MyDecimal" def f2(x: E | bytes) -> None: if x == E.X: @@ -2525,9 +2525,9 @@ def f2(x: E | bytes) -> None: def f3(x: IE | IE2) -> None: if x == IE.X: - reveal_type(x) # N: Revealed type is "__main__.IE | __main__.IE2" + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.X] | __main__.IE2" else: - reveal_type(x) # N: Revealed type is "__main__.IE | __main__.IE2" + reveal_type(x) # N: Revealed type is "Literal[__main__.IE.Y] | __main__.IE2" def f4(x: IE | E) -> None: if x == IE.X: @@ -2598,6 +2598,39 @@ def f4(x: SE) -> None: reveal_type(x) # N: Revealed type is "Literal[__main__.SE.B]" [builtins fixtures/primitives.pyi] +[case testNarrowingWithBytesEnum] +# flags: --strict-equality --warn-unreachable +from enum import Enum + +class BE(bytes, Enum): + A = b'a' + B = b'b' + +def f1(x: object) -> None: + if x == BE.A: + reveal_type(x) # N: Revealed type is "Literal[__main__.BE.A]" + else: + reveal_type(x) # N: Revealed type is "builtins.object" + +def f2(x: BE) -> None: + if x == BE.A: + reveal_type(x) # N: Revealed type is "Literal[__main__.BE.A]" + else: + reveal_type(x) # N: Revealed type is "Literal[__main__.BE.B]" + +def f3(x: bytearray | BE) -> None: + if x == BE.A: + reveal_type(x) # N: Revealed type is "builtins.bytearray | Literal[__main__.BE.A]" + else: + reveal_type(x) # N: Revealed type is "builtins.bytearray | Literal[__main__.BE.B]" + +def f4(x: memoryview | BE) -> None: + if x == BE.A: + reveal_type(x) # N: Revealed type is "builtins.memoryview | Literal[__main__.BE.A]" + else: + reveal_type(x) # N: Revealed type is "builtins.memoryview | Literal[__main__.BE.B]" +[builtins fixtures/primitives.pyi] + [case testNarrowingWithEnumStrSubclass] # flags: --strict-equality --warn-unreachable from enum import Enum @@ -3731,14 +3764,13 @@ from typing import TypeVar, Any, Type TargetType = TypeVar("TargetType", int, float, str) -# TODO: this behaviour is incorrect, it will be fixed by improving reachability def convert_type(target_type: Type[TargetType]) -> TargetType: if target_type == str: return str() if target_type == int: return int() if target_type == float: - return float() # E: Incompatible return value type (got "float", expected "int") + return float() raise @@ -3762,7 +3794,7 @@ def f2(number: float, five: Literal[5]): def f3(number: float | int, five: Literal[5]): if number == five: - reveal_type(number) # N: Revealed type is "builtins.float | Literal[5]" + reveal_type(number) # N: Revealed type is "Literal[5] | builtins.float" reveal_type(five) # N: Revealed type is "Literal[5]" def f8(number: float | Literal[5], five: Literal[5]): @@ -3829,6 +3861,7 @@ def f(x: bytes | None): [case testNarrowingBytesLikeWithPromotion] # flags: --strict-equality --warn-unreachable --strict-bytes from __future__ import annotations +from typing import Any def check_test(x: bytes) -> None: ... check_test(bytearray(b"asdf")) # E: Argument 1 to "check_test" has incompatible type "bytearray"; expected "bytes" @@ -3838,6 +3871,8 @@ def main( v_bytearray: bytearray, v_memoryview: memoryview, v_all: bytes | bytearray | memoryview, + v_object: object, + v_any: Any, ) -> None: if v_bytes == v_bytearray: reveal_type(v_bytes) # N: Revealed type is "builtins.bytes" @@ -3850,7 +3885,7 @@ def main( reveal_type(v_memoryview) # N: Revealed type is "builtins.memoryview" if v_all == v_bytes: - reveal_type(v_all) # N: Revealed type is "builtins.bytes" + reveal_type(v_all) # N: Revealed type is "builtins.bytes | builtins.bytearray | builtins.memoryview" reveal_type(v_bytes) # N: Revealed type is "builtins.bytes" if v_all == v_bytearray: reveal_type(v_all) # N: Revealed type is "builtins.bytes | builtins.bytearray | builtins.memoryview" @@ -3858,11 +3893,21 @@ def main( if v_all == v_memoryview: reveal_type(v_all) # N: Revealed type is "builtins.bytes | builtins.bytearray | builtins.memoryview" reveal_type(v_memoryview) # N: Revealed type is "builtins.memoryview" + + if v_object == v_bytes: + reveal_type(v_object) # N: Revealed type is "builtins.bytes" + if v_object == b"asdf": + reveal_type(v_object) # N: Revealed type is "Literal[b'asdf']?" + if v_any == v_bytes: + reveal_type(v_any) # N: Revealed type is "Any" + if v_any == b"asdf": + reveal_type(v_any) # N: Revealed type is "Any" [builtins fixtures/primitives.pyi] [case testNarrowingBytesLikeNoPromotion] # flags: --strict-equality --warn-unreachable --no-strict-bytes from __future__ import annotations +from typing import Any def check_test(x: bytes) -> None: ... check_test(bytearray(b"asdf")) @@ -3872,6 +3917,8 @@ def main( v_bytearray: bytearray, v_memoryview: memoryview, v_all: bytes | bytearray | memoryview, + v_object: object, + v_any: Any, ) -> None: if v_bytes == v_bytearray: reveal_type(v_bytes) # N: Revealed type is "builtins.bytes" @@ -3892,6 +3939,15 @@ def main( if v_all == v_memoryview: reveal_type(v_all) # N: Revealed type is "builtins.bytes | builtins.bytearray | builtins.memoryview" reveal_type(v_memoryview) # N: Revealed type is "builtins.memoryview" + + if v_object == v_bytes: + reveal_type(v_object) # N: Revealed type is "builtins.bytes" + if v_object == b"asdf": + reveal_type(v_object) # N: Revealed type is "Literal[b'asdf']?" + if v_any == v_bytes: + reveal_type(v_any) # N: Revealed type is "Any" + if v_any == b"asdf": + reveal_type(v_any) # N: Revealed type is "Any" [builtins fixtures/primitives.pyi] diff --git a/test-data/unit/check-python310.test b/test-data/unit/check-python310.test index b9dc59fd1ebf..99e647aff5a5 100644 --- a/test-data/unit/check-python310.test +++ b/test-data/unit/check-python310.test @@ -43,7 +43,7 @@ match m: def foo(m: bool, x: int): match m: case 1: - reveal_type(m) # N: Revealed type is "Literal[1]" + reveal_type(m) # N: Revealed type is "builtins.bool" match m: case x: @@ -3917,3 +3917,35 @@ def bar(cls: type[Types.A | Types.B]) -> None: case _: assert_never(cls) [builtins fixtures/tuple.pyi] + + +[case testMatchEnumOrdering] +# flags: --strict-equality --warn-unreachable +# Regression test for https://github.com/python/mypy/issues/21187 +import enum +from typing import Literal + +class DummyClass: ... + +class MyEnum(enum.StrEnum): + RELEVANT = "relevant" + IGNORED = "ignored" + +def dummy_class_then_enum(arg: DummyClass | Literal[MyEnum.RELEVANT]): + match arg: + case DummyClass(): + reveal_type(arg) # N: Revealed type is "__main__.DummyClass" + case MyEnum.RELEVANT: + reveal_type(arg) # N: Revealed type is "Literal[__main__.MyEnum.RELEVANT]" + case _: + pass # E: Statement is unreachable + +def enum_then_dummy_class(arg: DummyClass | Literal[MyEnum.RELEVANT]): + match arg: + case MyEnum.RELEVANT: + reveal_type(arg) # N: Revealed type is "Literal[__main__.MyEnum.RELEVANT]" + case DummyClass(): + reveal_type(arg) # N: Revealed type is "__main__.DummyClass" + case _: + pass # E: Statement is unreachable +[builtins fixtures/tuple.pyi]