Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion mypy/stubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,7 +968,21 @@ def from_funcitem(stub: nodes.FuncItem) -> Signature[nodes.Argument]:
elif stub_arg.kind == nodes.ARG_STAR:
stub_sig.varpos = stub_arg
elif stub_arg.kind == nodes.ARG_STAR2:
stub_sig.varkw = stub_arg
if stub_arg.variable.type is not None and isinstance(
(typed_dict_arg := mypy.types.get_proper_type(stub_arg.variable.type)),
mypy.types.TypedDictType,
):
for key_name, key_type in typed_dict_arg.items.items():
optional = key_name not in typed_dict_arg.required_keys
stub_sig.kwonly[key_name] = nodes.Argument(
nodes.Var(key_name, key_type),
type_annotation=key_type,
initializer=nodes.EllipsisExpr() if optional else None,
kind=nodes.ARG_NAMED_OPT if optional else nodes.ARG_NAMED,
pos_only=False,
)
else:
stub_sig.varkw = stub_arg
else:
raise AssertionError
return stub_sig
Expand Down
38 changes: 38 additions & 0 deletions mypy/test/teststubtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def __getitem__(self, typeargs: Any) -> object: ...
Literal = 0
NewType = 0
TypedDict = 0
Unpack = 0

class TypeVar:
def __init__(self, name, covariant: bool = ..., contravariant: bool = ...) -> None: ...
Expand Down Expand Up @@ -766,6 +767,43 @@ def test_varargs_varkwargs(self) -> Iterator[Case]:
error="k6",
)

@collect_cases
def test_kwargs_unpack_typeddict(self) -> Iterator[Case]:
yield Case(
stub="""
from typing import TypedDict, Unpack, type_check_only

@type_check_only
class _Args(TypedDict):
a: int
b: int

def f1(**kwargs: Unpack[_Args]) -> None: ...
""",
runtime="def f1(*, a, b): pass",
error=None,
)
yield Case(
stub="def f2(**kwargs: Unpack[_Args]) -> None: ...",
runtime="def f2(*, a, c): pass",
error="f2",
)
Comment thread
hauntsaninja marked this conversation as resolved.
yield Case(
stub="""
@type_check_only
class _OptionalArgs(TypedDict, total=False):
a: int

def f3(**kwargs: Unpack[_OptionalArgs]) -> None: ...
def f4(**kwargs: Unpack[_OptionalArgs]) -> None: ...
""",
runtime="""
def f3(*, a): pass
def f4(*, a=0): pass
""",
error="f3",
)

@collect_cases
def test_overload(self) -> Iterator[Case]:
yield Case(
Expand Down
Loading