diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fe214d7d..430d0265 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,3 +14,11 @@ repos: rev: 7.1.1 hooks: - id: flake8 +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.11.2 + hooks: + - id: mypy + additional_dependencies: + - types-colorama + - types-PyYAML + - types-requests diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..b1e5e09d --- /dev/null +++ b/mypy.ini @@ -0,0 +1,23 @@ +[mypy-txaio.*] +ignore_missing_imports = True + +[mypy-influxdb.*] +ignore_missing_imports = True + +[mypy-autobahn.*] +ignore_missing_imports = True + +[mypy-deprecation.*] +ignore_missing_imports = True + +[mypy-coverage.*] +ignore_missing_imports = True + +[mypy-so3g.*] +ignore_missing_imports = True + +[mypy-progress.*] +ignore_missing_imports = True + +[mypy-spt3g.*] +ignore_missing_imports = True diff --git a/ocs/__init__.py b/ocs/__init__.py index 2cdec324..6baff5a7 100644 --- a/ocs/__init__.py +++ b/ocs/__init__.py @@ -8,7 +8,7 @@ try: # If setuptools_scm is installed (e.g. in a development environment with # an editable install), then use it to determine the version dynamically. - from setuptools_scm import get_version + from setuptools_scm import get_version # type: ignore # This will fail with LookupError if the package is not installed in # editable mode or if Git is not installed. @@ -16,7 +16,7 @@ except (ImportError, LookupError): # As a fallback, use the version that is hard-coded in the file. try: - from ocs._version import __version__ # noqa: F401 + from ocs._version import __version__ # type: ignore # noqa: F401 except ModuleNotFoundError: # The user is probably trying to run this without having installed # the package, so complain. diff --git a/ocs/agents/aggregator/agent.py b/ocs/agents/aggregator/agent.py index b1e2301a..1663efac 100644 --- a/ocs/agents/aggregator/agent.py +++ b/ocs/agents/aggregator/agent.py @@ -122,7 +122,7 @@ def record(self, session: ocs_agent.OpSession, params): except PermissionError: self.log.error("Unable to intialize Aggregator due to permission " "error, stopping twisted reactor") - reactor.callFromThread(reactor.stop) + reactor.callFromThread(reactor.stop) # type: ignore return False, "Aggregation not started" while self.aggregate: diff --git a/ocs/agents/registry/agent.py b/ocs/agents/registry/agent.py index 4ead257f..7ad19d98 100644 --- a/ocs/agents/registry/agent.py +++ b/ocs/agents/registry/agent.py @@ -6,6 +6,8 @@ from ocs.ocs_feed import Feed import argparse +from typing import Dict, Any, Optional + class RegisteredAgent: """ @@ -28,15 +30,15 @@ class RegisteredAgent: docs from the ``ocs_agent`` module """ - def __init__(self, feed): + def __init__(self, feed: Dict[str, Any]) -> None: self.expired = False - self.time_expired = None + self.time_expired: Optional[float] = None self.last_updated = time.time() - self.op_codes = {} - self.agent_class = feed.get('agent_class') - self.agent_address = feed['agent_address'] + self.op_codes: Dict[str, int] = {} + self.agent_class: Optional[str] = feed.get('agent_class') + self.agent_address: str = feed['agent_address'] - def refresh(self, op_codes=None): + def refresh(self, op_codes: Optional[Dict[str, int]] = None) -> None: self.expired = False self.time_expired = None self.last_updated = time.time() @@ -44,13 +46,13 @@ def refresh(self, op_codes=None): if op_codes: self.op_codes.update(op_codes) - def expire(self): + def expire(self) -> None: self.expired = True self.time_expired = time.time() for k in self.op_codes: self.op_codes[k] = OpCode.EXPIRED.value - def encoded(self): + def encoded(self) -> Dict[str, Any]: return { 'expired': self.expired, 'time_expired': self.time_expired, @@ -85,7 +87,7 @@ class Registry: as expired. """ - def __init__(self, agent, args): + def __init__(self, agent: ocs_agent.OCSAgent, args: argparse.Namespace) -> None: self.log = agent.log self.agent = agent self.wait_time = args.wait_time @@ -94,7 +96,7 @@ def __init__(self, agent, args): self._run = False # Dict containing agent_data for each registered agent - self.registered_agents = {} + self.registered_agents: Dict[str, RegisteredAgent] = {} self.agent_timeout = 5.0 # Removes agent after 5 seconds of no heartbeat. self.agent.subscribe_on_start( @@ -108,7 +110,7 @@ def __init__(self, agent, args): self.agent.register_feed('agent_operations', record=True, agg_params=agg_params, buffer_time=0) - def _register_heartbeat(self, _data): + def _register_heartbeat(self, _data) -> None: """ Function that is called whenever a heartbeat is received from an agent. It will update that agent in the Registry's registered_agent dict. @@ -124,7 +126,7 @@ def _register_heartbeat(self, _data): if publish: self._publish_agent_ops(reg_agent) - def _publish_agent_ops(self, reg_agent): + def _publish_agent_ops(self, reg_agent: RegisteredAgent) -> None: """Publish a registered agent's OpCodes. Args: @@ -150,7 +152,11 @@ def _publish_agent_ops(self, reg_agent): @ocs_agent.param('test_mode', default=False, type=bool) @inlineCallbacks - def main(self, session: ocs_agent.OpSession, params): + def main( + self, + session: ocs_agent.OpSession, + params: Optional[Dict[str, Any]] + ) -> ocs_agent.InlineCallbackOpType: """main(test_mode=False) **Process** - Main run process for the Registry agent. This will loop @@ -214,13 +220,17 @@ def main(self, session: ocs_agent.OpSession, params): for agent in self.registered_agents.values(): self._publish_agent_ops(agent) - if params['test_mode']: + if params['test_mode']: # type: ignore break return True, "Stopped registry main process" @inlineCallbacks - def _stop_main(self, session, params): + def _stop_main( + self, + session: ocs_agent.OpSession, + params: Optional[Dict[str, Any]] + ) -> ocs_agent.InlineCallbackOpType: """Stop function for the 'main' process.""" yield if self._run: @@ -240,7 +250,7 @@ def _register_agent(self, session, agent_data): return True, "'register_agent' is deprecated" -def make_parser(parser=None): +def make_parser(parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser: if parser is None: parser = argparse.ArgumentParser() pgroup = parser.add_argument_group('Agent Options') @@ -249,7 +259,7 @@ def make_parser(parser=None): return parser -def main(args=None): +def main(args=None) -> None: parser = make_parser() args = site_config.parse_args(agent_class='RegistryAgent', parser=parser, diff --git a/ocs/ocs_agent.py b/ocs/ocs_agent.py index 5f17a002..315b0fad 100644 --- a/ocs/ocs_agent.py +++ b/ocs/ocs_agent.py @@ -1,8 +1,10 @@ +from __future__ import annotations import ocs import txaio txaio.use_twisted() + from twisted.application.internet import backoffPolicy from twisted.internet import reactor, task, threads from twisted.internet.defer import inlineCallbacks, Deferred, DeferredList, FirstError, maybeDeferred @@ -18,6 +20,7 @@ from autobahn.exception import Disconnected from .ocs_twisted import in_reactor_context +import argparse import json import math import time @@ -28,9 +31,20 @@ from ocs import client_t from ocs import ocs_feed from ocs.base import OpCode +from typing import Tuple, Optional, Callable, Dict, Any, Union, TypeVar, Generator + + +OpFuncType = Union[ + Callable[["OpSession", Optional[Dict[str, Any]]], Tuple[bool, str]], + Callable[["OpSession", Optional[Dict[str, Any]]], Deferred[Tuple[bool, str]]], +] +InlineCallbackOpType = Generator[Any, Any, Tuple[bool, str]] -def init_site_agent(args, address=None): +def init_site_agent( + args: argparse.Namespace, + address: Optional[str] = None +) -> Tuple[OCSAgent, ApplicationRunner]: """ Create ApplicationSession and ApplicationRunner instances, set up to communicate on the chosen WAMP realm. @@ -432,8 +446,15 @@ def _management_handler(self, q, **kwargs): if q == 'get_agent_class': return self.class_name - def register_task(self, name, func, aborter=None, blocking=True, - aborter_blocking=None, startup=False): + def register_task( + self, + name: str, + func: OpFuncType, + aborter: Optional[OpFuncType] = None, + blocking: bool = True, + aborter_blocking: Optional[bool] = None, + startup: Union[bool, Dict[str, Any]] = False + ) -> None: """Register a Task for this agent. Args: @@ -474,8 +495,15 @@ def register_task(self, name, func, aborter=None, blocking=True, if startup is not False: self.startup_ops.append(('task', name, startup)) - def register_process(self, name, start_func, stop_func, blocking=True, - stopper_blocking=None, startup=False): + def register_process( + self, + name: str, + start_func: OpFuncType, + stop_func: OpFuncType, + blocking: bool = True, + stopper_blocking: Optional[bool] = None, + startup: Union[bool, Dict[str, Any]] = False + ) -> None: """Register a Process for this agent. Args: @@ -1520,7 +1548,10 @@ def check_for_strays(self, ignore=[]): raise ParamError(f"params included unexpected values: {weird_args}") -def param(key, **kwargs): +F = TypeVar('F') + + +def param(key, **kwargs) -> Callable[[F], F]: """Decorator for Agent operation functions to assist with checking params prior to actually trying to execute the code. Example:: diff --git a/ocs/py.typed b/ocs/py.typed new file mode 100644 index 00000000..e69de29b