Skip to content

Check default value type of Parameter #405

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 8 additions & 21 deletions gokart/mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
PlaceholderNode,
RefExpr,
Statement,
TempNode,
TypeInfo,
Var,
)
Expand Down Expand Up @@ -118,15 +117,15 @@ class TaskOnKartAttribute:
def __init__(
self,
name: str,
has_default: bool,
default_value: Optional[Expression],
line: int,
column: int,
type: Type | None,
info: TypeInfo,
api: SemanticAnalyzerPluginInterface,
) -> None:
self.name = name
self.has_default = has_default
self.default_value = default_value
self.line = line
self.column = column
self.type = type # Type as __init__ argument
Expand All @@ -141,7 +140,7 @@ def to_argument(self, current_info: TypeInfo, *, of: Literal['__init__',]) -> Ar
return Argument(
variable=self.to_var(current_info),
type_annotation=self.expand_type(current_info),
initializer=EllipsisExpr() if self.has_default else None, # Only used by stubgen
initializer=self.default_value,
kind=arg_kind,
)

Expand All @@ -162,7 +161,7 @@ def serialize(self) -> JsonDict:
assert self.type
return {
'name': self.name,
'has_default': self.has_default,
'default_value': self.default_value,
'line': self.line,
'column': self.column,
'type': self.type.serialize(),
Expand Down Expand Up @@ -302,23 +301,11 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:
assert isinstance(node, Var)

has_parameter_call, parameter_args = self._collect_parameter_args(stmt.rvalue)
has_default = False
default_value: Optional[Expression] = None
# Ensure that something like x: int = field() is rejected
# after an attribute with a default.
if has_parameter_call:
has_default = 'default' in parameter_args

# All other assignments are already type checked.
elif not isinstance(stmt.rvalue, TempNode):
has_default = True

if not has_default:
# Make all non-default task_on_kart attributes implicit because they are de-facto
# set on self in the generated __init__(), not in the class body. On the other
# hand, we don't know how custom task_on_kart transforms initialize attributes,
# so we don't treat them as implicit. This is required to support descriptors
# (https://github.com/python/mypy/issues/14868).
sym.implicit = True
if has_parameter_call and 'default' in parameter_args:
default_value = parameter_args['default']

current_attr_names.add(lhs.name)
with state.strict_optional_set(self._api.options.strict_optional):
Expand All @@ -330,7 +317,7 @@ def collect_attributes(self) -> Optional[list[TaskOnKartAttribute]]:

found_attrs[lhs.name] = TaskOnKartAttribute(
name=lhs.name,
has_default=has_default,
default_value=default_value,
line=stmt.line,
column=stmt.column,
type=init_type,
Expand Down
17 changes: 15 additions & 2 deletions test/test_mypy.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class MyEnum(enum.Enum):
FOO = enum.auto()

class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

foo = luigi.IntParameter()
bar = luigi.DateParameter()
baz = gokart.TaskInstanceParameter()
Expand Down Expand Up @@ -110,7 +109,6 @@ def test_parameter_has_default_type_no_issue_pattern(self):
import gokart

class MyTask(gokart.TaskOnKart):
# NOTE: mypy shows attr-defined error for the following lines, so we need to ignore it.
foo = luigi.IntParameter()
bar = luigi.DateParameter()
baz = gokart.TaskInstanceParameter()
Expand All @@ -122,3 +120,18 @@ class MyTask(gokart.TaskOnKart):
test_file.flush()
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
self.assertIn('Success: no issues found', result[0])

def test_parameter_has_uncorrect_default_value(self):
test_code = """
import luigi
import gokart

class MyTask(gokart.TaskOnKart):
foo = luigi.IntParameter(default='s')
"""
with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
test_file.write(test_code.encode('utf-8'))
test_file.flush()
result = api.run(['--show-traceback', '--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
print(result[0])
self.assertIn('Incompatible default for argument "foo" (default has type "str", argument has type "int")', result[0])
Loading