From f544a6868466175900e2ec4145b0145ff9308b9c Mon Sep 17 00:00:00 2001 From: Jack Lashner Date: Mon, 16 Sep 2024 14:13:48 -0400 Subject: [PATCH 1/6] Initial mypy implementation --- mypy.ini | 23 ++++++++++++++++++ ocs/__init__.py | 4 ++-- ocs/agents/aggregator/agent.py | 2 +- ocs/agents/registry/agent.py | 44 +++++++++++++++++++++------------- ocs/ocs_agent.py | 39 +++++++++++++++++++++++++----- ocs/py.typed | 0 6 files changed, 86 insertions(+), 26 deletions(-) create mode 100644 mypy.ini create mode 100644 ocs/py.typed diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..1b737ae0 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,23 @@ +[mypy-txaio.*] +ignore_missing_imports = True + +[mypy-influxdb.*] +ignore_missing_imports = True + +[mypy-autobahn.*] +ignore_missing_imports = True + +[mypy-deprecation.*] +ignore_missing_imports = True + +[mypy-coverage.*] +ignore_missing_imports = True + +[mypy-so3g.*] +ignore_missing_imports = True + +[mypy-progress.*] +ignore_missing_imports = True + +[mypy-spt3g.*] +ignore_missing_imports = True \ No newline at end of file diff --git a/ocs/__init__.py b/ocs/__init__.py index 2cdec324..6baff5a7 100644 --- a/ocs/__init__.py +++ b/ocs/__init__.py @@ -8,7 +8,7 @@ try: # If setuptools_scm is installed (e.g. in a development environment with # an editable install), then use it to determine the version dynamically. - from setuptools_scm import get_version + from setuptools_scm import get_version # type: ignore # This will fail with LookupError if the package is not installed in # editable mode or if Git is not installed. @@ -16,7 +16,7 @@ except (ImportError, LookupError): # As a fallback, use the version that is hard-coded in the file. try: - from ocs._version import __version__ # noqa: F401 + from ocs._version import __version__ # type: ignore # noqa: F401 except ModuleNotFoundError: # The user is probably trying to run this without having installed # the package, so complain. diff --git a/ocs/agents/aggregator/agent.py b/ocs/agents/aggregator/agent.py index b1e2301a..1663efac 100644 --- a/ocs/agents/aggregator/agent.py +++ b/ocs/agents/aggregator/agent.py @@ -122,7 +122,7 @@ def record(self, session: ocs_agent.OpSession, params): except PermissionError: self.log.error("Unable to intialize Aggregator due to permission " "error, stopping twisted reactor") - reactor.callFromThread(reactor.stop) + reactor.callFromThread(reactor.stop) # type: ignore return False, "Aggregation not started" while self.aggregate: diff --git a/ocs/agents/registry/agent.py b/ocs/agents/registry/agent.py index 4ead257f..10915620 100644 --- a/ocs/agents/registry/agent.py +++ b/ocs/agents/registry/agent.py @@ -6,6 +6,8 @@ from ocs.ocs_feed import Feed import argparse +from typing import Dict, Any, Optional + class RegisteredAgent: """ @@ -28,15 +30,15 @@ class RegisteredAgent: docs from the ``ocs_agent`` module """ - def __init__(self, feed): + def __init__(self, feed: Dict[str, Any]) -> None: self.expired = False - self.time_expired = None + self.time_expired: Optional[float] = None self.last_updated = time.time() - self.op_codes = {} - self.agent_class = feed.get('agent_class') - self.agent_address = feed['agent_address'] + self.op_codes: Dict[str, int] = {} + self.agent_class: Optional[str] = feed.get('agent_class') + self.agent_address: str = feed['agent_address'] - def refresh(self, op_codes=None): + def refresh(self, op_codes: Optional[Dict[str, int]]=None) -> None: self.expired = False self.time_expired = None self.last_updated = time.time() @@ -44,13 +46,13 @@ def refresh(self, op_codes=None): if op_codes: self.op_codes.update(op_codes) - def expire(self): + def expire(self) -> None: self.expired = True self.time_expired = time.time() for k in self.op_codes: self.op_codes[k] = OpCode.EXPIRED.value - def encoded(self): + def encoded(self) -> Dict[str, Any]: return { 'expired': self.expired, 'time_expired': self.time_expired, @@ -85,7 +87,7 @@ class Registry: as expired. """ - def __init__(self, agent, args): + def __init__(self, agent: ocs_agent.OCSAgent, args: argparse.Namespace) -> None: self.log = agent.log self.agent = agent self.wait_time = args.wait_time @@ -94,7 +96,7 @@ def __init__(self, agent, args): self._run = False # Dict containing agent_data for each registered agent - self.registered_agents = {} + self.registered_agents: Dict[str, RegisteredAgent] = {} self.agent_timeout = 5.0 # Removes agent after 5 seconds of no heartbeat. self.agent.subscribe_on_start( @@ -108,7 +110,7 @@ def __init__(self, agent, args): self.agent.register_feed('agent_operations', record=True, agg_params=agg_params, buffer_time=0) - def _register_heartbeat(self, _data): + def _register_heartbeat(self, _data) -> None: """ Function that is called whenever a heartbeat is received from an agent. It will update that agent in the Registry's registered_agent dict. @@ -124,7 +126,7 @@ def _register_heartbeat(self, _data): if publish: self._publish_agent_ops(reg_agent) - def _publish_agent_ops(self, reg_agent): + def _publish_agent_ops(self, reg_agent: RegisteredAgent) -> None: """Publish a registered agent's OpCodes. Args: @@ -150,7 +152,11 @@ def _publish_agent_ops(self, reg_agent): @ocs_agent.param('test_mode', default=False, type=bool) @inlineCallbacks - def main(self, session: ocs_agent.OpSession, params): + def main( + self, + session: ocs_agent.OpSession, + params: Optional[Dict[str, Any]] + ) -> ocs_agent.InlineCallbackOpType: """main(test_mode=False) **Process** - Main run process for the Registry agent. This will loop @@ -214,13 +220,17 @@ def main(self, session: ocs_agent.OpSession, params): for agent in self.registered_agents.values(): self._publish_agent_ops(agent) - if params['test_mode']: + if params['test_mode']: # type: ignore break return True, "Stopped registry main process" @inlineCallbacks - def _stop_main(self, session, params): + def _stop_main( + self, + session: ocs_agent.OpSession, + params: Optional[Dict[str, Any]] + ) -> ocs_agent.InlineCallbackOpType: """Stop function for the 'main' process.""" yield if self._run: @@ -240,7 +250,7 @@ def _register_agent(self, session, agent_data): return True, "'register_agent' is deprecated" -def make_parser(parser=None): +def make_parser(parser: Optional[argparse.ArgumentParser]=None) -> argparse.ArgumentParser: if parser is None: parser = argparse.ArgumentParser() pgroup = parser.add_argument_group('Agent Options') @@ -249,7 +259,7 @@ def make_parser(parser=None): return parser -def main(args=None): +def main(args=None) -> None: parser = make_parser() args = site_config.parse_args(agent_class='RegistryAgent', parser=parser, diff --git a/ocs/ocs_agent.py b/ocs/ocs_agent.py index 5f17a002..2b1680d2 100644 --- a/ocs/ocs_agent.py +++ b/ocs/ocs_agent.py @@ -3,6 +3,8 @@ import txaio txaio.use_twisted() +from __future__ import annotations + from twisted.application.internet import backoffPolicy from twisted.internet import reactor, task, threads from twisted.internet.defer import inlineCallbacks, Deferred, DeferredList, FirstError, maybeDeferred @@ -18,6 +20,7 @@ from autobahn.exception import Disconnected from .ocs_twisted import in_reactor_context +import argparse import json import math import time @@ -28,9 +31,18 @@ from ocs import client_t from ocs import ocs_feed from ocs.base import OpCode +from typing import Tuple, Optional, Callable, Dict, Any, Union, TypeVar, Generator, Union + + +OpReturnType = Union[Tuple[bool, str], Deferred[Tuple[bool, str]]] +OpFuncType = Callable[["OpSession", Optional[Dict[str, Any]]], OpReturnType] +InlineCallbackOpType = Generator[Any, Any, OpReturnType] -def init_site_agent(args, address=None): +def init_site_agent( + args: argparse.Namespace, + address: Optional[str] = None +) -> Tuple[OCSAgent, ApplicationRunner]: """ Create ApplicationSession and ApplicationRunner instances, set up to communicate on the chosen WAMP realm. @@ -432,8 +444,15 @@ def _management_handler(self, q, **kwargs): if q == 'get_agent_class': return self.class_name - def register_task(self, name, func, aborter=None, blocking=True, - aborter_blocking=None, startup=False): + def register_task( + self, + name: str, + func: OpFuncType, + aborter: Optional[OpFuncType] = None, + blocking: bool = True, + aborter_blocking: Optional[bool] = None, + startup: Union[bool, Dict[str, Any]] = False + ) -> None: """Register a Task for this agent. Args: @@ -474,8 +493,15 @@ def register_task(self, name, func, aborter=None, blocking=True, if startup is not False: self.startup_ops.append(('task', name, startup)) - def register_process(self, name, start_func, stop_func, blocking=True, - stopper_blocking=None, startup=False): + def register_process( + self, + name: str, + start_func: OpFuncType, + stop_func: OpFuncType, + blocking: bool = True, + stopper_blocking: Optional[bool] = None, + startup: Union[bool, Dict[str, Any]] = False + ) -> None: """Register a Process for this agent. Args: @@ -1519,8 +1545,9 @@ def check_for_strays(self, ignore=[]): if len(weird_args): raise ParamError(f"params included unexpected values: {weird_args}") +F = TypeVar('F') -def param(key, **kwargs): +def param(key, **kwargs) -> Callable[[F], F]: """Decorator for Agent operation functions to assist with checking params prior to actually trying to execute the code. Example:: diff --git a/ocs/py.typed b/ocs/py.typed new file mode 100644 index 00000000..e69de29b From 315a294303bcc5f3b572aeee6d2054c7f6293e90 Mon Sep 17 00:00:00 2001 From: Jack Lashner Date: Mon, 16 Sep 2024 14:33:17 -0400 Subject: [PATCH 2/6] Fix OCS OpFuncType --- mypy.ini | 2 +- ocs/ocs_agent.py | 8 +++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mypy.ini b/mypy.ini index 1b737ae0..b1e5e09d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -20,4 +20,4 @@ ignore_missing_imports = True ignore_missing_imports = True [mypy-spt3g.*] -ignore_missing_imports = True \ No newline at end of file +ignore_missing_imports = True diff --git a/ocs/ocs_agent.py b/ocs/ocs_agent.py index 2b1680d2..20b52ea8 100644 --- a/ocs/ocs_agent.py +++ b/ocs/ocs_agent.py @@ -34,9 +34,11 @@ from typing import Tuple, Optional, Callable, Dict, Any, Union, TypeVar, Generator, Union -OpReturnType = Union[Tuple[bool, str], Deferred[Tuple[bool, str]]] -OpFuncType = Callable[["OpSession", Optional[Dict[str, Any]]], OpReturnType] -InlineCallbackOpType = Generator[Any, Any, OpReturnType] +OpFuncType = Union[ + Callable[["OpSession", Optional[Dict[str, Any]]], Tuple[bool, str]], + Callable[["OpSession", Optional[Dict[str, Any]]], Deferred[Tuple[bool, str]]], +] +InlineCallbackOpType = Generator[Any, Any, Tuple[bool, str]] def init_site_agent( From 5bc4ccd9a0fb4122c2a32a3720d5522bc09e0b74 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 16 Sep 2024 18:37:22 +0000 Subject: [PATCH 3/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- ocs/agents/registry/agent.py | 4 ++-- ocs/ocs_agent.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ocs/agents/registry/agent.py b/ocs/agents/registry/agent.py index 10915620..7ad19d98 100644 --- a/ocs/agents/registry/agent.py +++ b/ocs/agents/registry/agent.py @@ -38,7 +38,7 @@ def __init__(self, feed: Dict[str, Any]) -> None: self.agent_class: Optional[str] = feed.get('agent_class') self.agent_address: str = feed['agent_address'] - def refresh(self, op_codes: Optional[Dict[str, int]]=None) -> None: + def refresh(self, op_codes: Optional[Dict[str, int]] = None) -> None: self.expired = False self.time_expired = None self.last_updated = time.time() @@ -250,7 +250,7 @@ def _register_agent(self, session, agent_data): return True, "'register_agent' is deprecated" -def make_parser(parser: Optional[argparse.ArgumentParser]=None) -> argparse.ArgumentParser: +def make_parser(parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser: if parser is None: parser = argparse.ArgumentParser() pgroup = parser.add_argument_group('Agent Options') diff --git a/ocs/ocs_agent.py b/ocs/ocs_agent.py index 20b52ea8..c5d84c22 100644 --- a/ocs/ocs_agent.py +++ b/ocs/ocs_agent.py @@ -1547,8 +1547,10 @@ def check_for_strays(self, ignore=[]): if len(weird_args): raise ParamError(f"params included unexpected values: {weird_args}") + F = TypeVar('F') + def param(key, **kwargs) -> Callable[[F], F]: """Decorator for Agent operation functions to assist with checking params prior to actually trying to execute the code. Example:: From 841aca7b774b5ffef989c00bf53222469abe85d6 Mon Sep 17 00:00:00 2001 From: Jack Lashner Date: Mon, 16 Sep 2024 14:57:17 -0400 Subject: [PATCH 4/6] Move __futures__ import --- ocs/ocs_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocs/ocs_agent.py b/ocs/ocs_agent.py index c5d84c22..e2fabcdc 100644 --- a/ocs/ocs_agent.py +++ b/ocs/ocs_agent.py @@ -1,9 +1,9 @@ +from __future__ import annotations import ocs import txaio txaio.use_twisted() -from __future__ import annotations from twisted.application.internet import backoffPolicy from twisted.internet import reactor, task, threads From 43b27f80b36ea6264b130ce61caa01ec5b5eb887 Mon Sep 17 00:00:00 2001 From: Jack Lashner Date: Mon, 16 Sep 2024 14:58:27 -0400 Subject: [PATCH 5/6] Flake 8 --- ocs/ocs_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocs/ocs_agent.py b/ocs/ocs_agent.py index e2fabcdc..315b0fad 100644 --- a/ocs/ocs_agent.py +++ b/ocs/ocs_agent.py @@ -31,7 +31,7 @@ from ocs import client_t from ocs import ocs_feed from ocs.base import OpCode -from typing import Tuple, Optional, Callable, Dict, Any, Union, TypeVar, Generator, Union +from typing import Tuple, Optional, Callable, Dict, Any, Union, TypeVar, Generator OpFuncType = Union[ From 1b7ed323cf7906a9d076dcab9a8a532ec888fcf7 Mon Sep 17 00:00:00 2001 From: Brian Koopman Date: Mon, 16 Sep 2024 18:18:22 -0400 Subject: [PATCH 6/6] Add mypy to pre-commit --- .pre-commit-config.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe214d7d..430d0265 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,3 +14,11 @@ repos: rev: 7.1.1 hooks: - id: flake8 +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.11.2 + hooks: + - id: mypy + additional_dependencies: + - types-colorama + - types-PyYAML + - types-requests