Skip to content

Commit 86a8989

Browse files
committed
bring async_.Routine back and support passing an optional task group
1 parent fc19af4 commit 86a8989

File tree

11 files changed

+156
-126
lines changed

11 files changed

+156
-126
lines changed

examples/mqtt/pub_sub.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
async def main() -> None:
88
config = enapter.mqtt.Config(host="127.0.0.1", port=1883)
9-
async with asyncio.TaskGroup() as tg:
10-
client = enapter.mqtt.Client(tg, config=config)
11-
tg.create_task(subscriber(client))
12-
tg.create_task(publisher(client))
9+
async with enapter.mqtt.Client(config=config) as client:
10+
async with asyncio.TaskGroup() as tg:
11+
tg.create_task(subscriber(client))
12+
tg.create_task(publisher(client))
1313

1414

1515
async def subscriber(client: enapter.mqtt.Client) -> None:

examples/standalone/zigbee2mqtt/script.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,15 @@ async def run(self):
3636
tg.create_task(self.properties_sender())
3737

3838
async def consumer(self, tg):
39-
client = enapter.mqtt.Client(tg, self.mqtt_client_config)
40-
async with client.subscribe(self.mqtt_topic) as messages:
41-
async for msg in messages:
42-
try:
43-
self.telemetry = json.loads(msg.payload)
44-
except json.JSONDecodeError as e:
45-
await self.log.error(f"failed to decode json payload: {e}")
39+
async with enapter.mqtt.Client(
40+
self.mqtt_client_config, task_group=tg
41+
) as client:
42+
async with client.subscribe(self.mqtt_topic) as messages:
43+
async for msg in messages:
44+
try:
45+
self.telemetry = json.loads(msg.payload)
46+
except json.JSONDecodeError as e:
47+
await self.log.error(f"failed to decode json payload: {e}")
4648

4749
async def telemetry_sender(self):
4850
while True:

src/enapter/async_/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .generator import generator
2+
from .routine import Routine
23

3-
__all__ = ["generator"]
4+
__all__ = ["generator", "Routine"]

src/enapter/async_/routine.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import abc
2+
import asyncio
3+
import contextlib
4+
from typing import Self
5+
6+
7+
class Routine(abc.ABC):
8+
9+
def __init__(self, task_group: asyncio.TaskGroup | None) -> None:
10+
self._task_group = task_group
11+
self._task: asyncio.Task | None = None
12+
13+
@abc.abstractmethod
14+
async def _run(self) -> None:
15+
pass
16+
17+
async def __aenter__(self) -> Self:
18+
await self.start()
19+
return self
20+
21+
async def __aexit__(self, *_) -> None:
22+
await self.stop()
23+
24+
async def start(self) -> None:
25+
if self._task is not None:
26+
raise RuntimeError("already started")
27+
if self._task_group is None:
28+
self._task = asyncio.create_task(self._run())
29+
else:
30+
self._task = self._task_group.create_task(self._run())
31+
32+
async def stop(self) -> None:
33+
if self._task is None:
34+
raise RuntimeError("not started yet")
35+
self._task.cancel()
36+
with contextlib.suppress(asyncio.CancelledError):
37+
await self._task

src/enapter/mqtt/client.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,18 @@
1515
LOGGER = logging.getLogger(__name__)
1616

1717

18-
class Client:
18+
class Client(async_.Routine):
1919

20-
def __init__(self, task_group: asyncio.TaskGroup, config: Config) -> None:
20+
def __init__(
21+
self, config: Config, task_group: asyncio.TaskGroup | None = None
22+
) -> None:
23+
super().__init__(task_group=task_group)
2124
self._logger = self._new_logger(config)
2225
self._config = config
2326
self._mdns_resolver = mdns.Resolver()
2427
self._tls_context = self._new_tls_context(config)
2528
self._publisher: aiomqtt.Client | None = None
2629
self._publisher_connected = asyncio.Event()
27-
self._task = task_group.create_task(self._run())
28-
29-
def cancel(self) -> None:
30-
self._task.cancel()
3130

3231
@staticmethod
3332
def _new_logger(config: Config) -> logging.LoggerAdapter:

src/enapter/standalone/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from .app import App, run
21
from .config import Config
32
from .device import Device
43
from .device_protocol import (
@@ -10,9 +9,9 @@
109
Telemetry,
1110
)
1211
from .logger import Logger
12+
from .run import run
1313

1414
__all__ = [
15-
"App",
1615
"CommandArgs",
1716
"CommandResult",
1817
"Config",

src/enapter/standalone/app.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

src/enapter/standalone/device_driver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,22 @@
33
import time
44
import traceback
55

6-
from enapter import mqtt
6+
from enapter import async_, mqtt
77

88
from .device_protocol import DeviceProtocol
99

1010

11-
class DeviceDriver:
11+
class DeviceDriver(async_.Routine):
1212

1313
def __init__(
1414
self,
15-
task_group: asyncio.TaskGroup,
1615
device_channel: mqtt.api.DeviceChannel,
1716
device: DeviceProtocol,
17+
task_group: asyncio.TaskGroup | None,
1818
) -> None:
19+
super().__init__(task_group=task_group)
1920
self._device_channel = device_channel
2021
self._device = device
21-
self._task = task_group.create_task(self._run())
2222

2323
async def _run(self) -> None:
2424
async with asyncio.TaskGroup() as tg:

src/enapter/standalone/run.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import asyncio
2+
import contextlib
3+
4+
from enapter import log, mqtt
5+
6+
from .config import Config
7+
from .device_driver import DeviceDriver
8+
from .device_protocol import DeviceProtocol
9+
from .ucm import UCM
10+
11+
12+
async def run(device: DeviceProtocol) -> None:
13+
log.configure(level=log.LEVEL or "info")
14+
config = Config.from_env()
15+
async with contextlib.AsyncExitStack() as stack:
16+
task_group = await stack.enter_async_context(asyncio.TaskGroup())
17+
mqtt_client = await stack.enter_async_context(
18+
mqtt.Client(config=config.communication.mqtt, task_group=task_group)
19+
)
20+
_ = await stack.enter_async_context(
21+
DeviceDriver(
22+
device_channel=mqtt.api.DeviceChannel(
23+
client=mqtt_client,
24+
hardware_id=config.communication.hardware_id,
25+
channel_id=config.communication.channel_id,
26+
),
27+
device=device,
28+
task_group=task_group,
29+
)
30+
)
31+
if config.communication.ucm_needed:
32+
_ = await stack.enter_async_context(
33+
DeviceDriver(
34+
device_channel=mqtt.api.DeviceChannel(
35+
client=mqtt_client,
36+
hardware_id=config.communication.hardware_id,
37+
channel_id="ucm",
38+
),
39+
device=UCM(),
40+
task_group=task_group,
41+
)
42+
)
43+
await asyncio.Event().wait()

tests/integration/conftest.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,8 @@ async def fixture_enapter_mqtt_client(mosquitto_container):
2020
port=int(ports[0]["HostPort"]),
2121
)
2222
async with asyncio.TaskGroup() as tg:
23-
mqtt_client = enapter.mqtt.Client(tg, config)
24-
yield mqtt_client
25-
mqtt_client.cancel()
23+
async with enapter.mqtt.Client(config, task_group=tg) as mqtt_client:
24+
yield mqtt_client
2625

2726

2827
@pytest.fixture(scope="session", name="mosquitto_container")

0 commit comments

Comments
 (0)