From 4440b1f50cfa4e0e5f679e072adf2dd97f3aa495 Mon Sep 17 00:00:00 2001 From: Ryo Kitagawa Date: Sun, 27 Oct 2024 19:30:02 +0900 Subject: [PATCH] feat: check default value type of Parameter --- gokart/mypy.py | 29 ++++++++--------------------- test/test_mypy.py | 17 +++++++++++++++-- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/gokart/mypy.py b/gokart/mypy.py index b275d11e..122e8714 100644 --- a/gokart/mypy.py +++ b/gokart/mypy.py @@ -27,7 +27,6 @@ PlaceholderNode, RefExpr, Statement, - TempNode, TypeInfo, Var, ) @@ -118,7 +117,7 @@ class TaskOnKartAttribute: def __init__( self, name: str, - has_default: bool, + default_value: Optional[Expression], line: int, column: int, type: Type | None, @@ -126,7 +125,7 @@ def __init__( api: SemanticAnalyzerPluginInterface, ) -> None: self.name = name - self.has_default = has_default + self.default_value = default_value self.line = line self.column = column self.type = type # Type as __init__ argument @@ -141,7 +140,7 @@ def to_argument(self, current_info: TypeInfo, *, of: Literal['__init__',]) -> Ar return Argument( variable=self.to_var(current_info), type_annotation=self.expand_type(current_info), - initializer=EllipsisExpr() if self.has_default else None, # Only used by stubgen + initializer=self.default_value, kind=arg_kind, ) @@ -162,7 +161,7 @@ def serialize(self) -> JsonDict: assert self.type return { 'name': self.name, - 'has_default': self.has_default, + 'default_value': self.default_value, 'line': self.line, 'column': self.column, 'type': self.type.serialize(), @@ -302,23 +301,11 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]: assert isinstance(node, Var) has_parameter_call, parameter_args = self._collect_parameter_args(stmt.rvalue) - has_default = False + default_value: Optional[Expression] = None # Ensure that something like x: int = field() is rejected # after an attribute with a default. - if has_parameter_call: - has_default = 'default' in parameter_args - - # All other assignments are already type checked. - elif not isinstance(stmt.rvalue, TempNode): - has_default = True - - if not has_default: - # Make all non-default task_on_kart attributes implicit because they are de-facto - # set on self in the generated __init__(), not in the class body. On the other - # hand, we don't know how custom task_on_kart transforms initialize attributes, - # so we don't treat them as implicit. This is required to support descriptors - # (https://github.com/python/mypy/issues/14868). - sym.implicit = True + if has_parameter_call and 'default' in parameter_args: + default_value = parameter_args['default'] current_attr_names.add(lhs.name) with state.strict_optional_set(self._api.options.strict_optional): @@ -330,7 +317,7 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]: found_attrs[lhs.name] = TaskOnKartAttribute( name=lhs.name, - has_default=has_default, + default_value=default_value, line=stmt.line, column=stmt.column, type=init_type, diff --git a/test/test_mypy.py b/test/test_mypy.py index 74b83a84..7f822a90 100644 --- a/test/test_mypy.py +++ b/test/test_mypy.py @@ -79,7 +79,6 @@ class MyEnum(enum.Enum): FOO = enum.auto() class MyTask(gokart.TaskOnKart): - # NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it. foo = luigi.IntParameter() bar = luigi.DateParameter() baz = gokart.TaskInstanceParameter() @@ -110,7 +109,6 @@ def test_parameter_has_default_type_no_issue_pattern(self): import gokart class MyTask(gokart.TaskOnKart): - # NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it. foo = luigi.IntParameter() bar = luigi.DateParameter() baz = gokart.TaskInstanceParameter() @@ -122,3 +120,18 @@ class MyTask(gokart.TaskOnKart): test_file.flush() result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) self.assertIn('Success: no issues found', result[0]) + + def test_parameter_has_uncorrect_default_value(self): + test_code = """ +import luigi +import gokart + +class MyTask(gokart.TaskOnKart): + foo = luigi.IntParameter(default='s') +""" + with tempfile.NamedTemporaryFile(suffix='.py') as test_file: + test_file.write(test_code.encode('utf-8')) + test_file.flush() + result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) + print(result[0]) + self.assertIn('Incompatible default for argument "foo" (default has type "str", argument has type "int")', result[0])