File tree Expand file tree Collapse file tree 11 files changed +156
-126
lines changed Expand file tree Collapse file tree 11 files changed +156
-126
lines changed Original file line number Diff line number Diff line change 66
77async def main () -> None :
88 config = enapter .mqtt .Config (host = "127.0.0.1" , port = 1883 )
9- async with asyncio . TaskGroup ( ) as tg :
10- client = enapter . mqtt . Client ( tg , config = config )
11- tg .create_task (subscriber (client ))
12- tg .create_task (publisher (client ))
9+ async with enapter . mqtt . Client ( config = config ) as client :
10+ async with asyncio . TaskGroup () as tg :
11+ tg .create_task (subscriber (client ))
12+ tg .create_task (publisher (client ))
1313
1414
1515async def subscriber (client : enapter .mqtt .Client ) -> None :
Original file line number Diff line number Diff line change @@ -36,13 +36,15 @@ async def run(self):
3636 tg .create_task (self .properties_sender ())
3737
3838 async def consumer (self , tg ):
39- client = enapter .mqtt .Client (tg , self .mqtt_client_config )
40- async with client .subscribe (self .mqtt_topic ) as messages :
41- async for msg in messages :
42- try :
43- self .telemetry = json .loads (msg .payload )
44- except json .JSONDecodeError as e :
45- await self .log .error (f"failed to decode json payload: { e } " )
39+ async with enapter .mqtt .Client (
40+ self .mqtt_client_config , task_group = tg
41+ ) as client :
42+ async with client .subscribe (self .mqtt_topic ) as messages :
43+ async for msg in messages :
44+ try :
45+ self .telemetry = json .loads (msg .payload )
46+ except json .JSONDecodeError as e :
47+ await self .log .error (f"failed to decode json payload: { e } " )
4648
4749 async def telemetry_sender (self ):
4850 while True :
Original file line number Diff line number Diff line change 11from .generator import generator
2+ from .routine import Routine
23
3- __all__ = ["generator" ]
4+ __all__ = ["generator" , "Routine" ]
Original file line number Diff line number Diff line change 1+ import abc
2+ import asyncio
3+ import contextlib
4+ from typing import Self
5+
6+
7+ class Routine (abc .ABC ):
8+
9+ def __init__ (self , task_group : asyncio .TaskGroup | None ) -> None :
10+ self ._task_group = task_group
11+ self ._task : asyncio .Task | None = None
12+
13+ @abc .abstractmethod
14+ async def _run (self ) -> None :
15+ pass
16+
17+ async def __aenter__ (self ) -> Self :
18+ await self .start ()
19+ return self
20+
21+ async def __aexit__ (self , * _ ) -> None :
22+ await self .stop ()
23+
24+ async def start (self ) -> None :
25+ if self ._task is not None :
26+ raise RuntimeError ("already started" )
27+ if self ._task_group is None :
28+ self ._task = asyncio .create_task (self ._run ())
29+ else :
30+ self ._task = self ._task_group .create_task (self ._run ())
31+
32+ async def stop (self ) -> None :
33+ if self ._task is None :
34+ raise RuntimeError ("not started yet" )
35+ self ._task .cancel ()
36+ with contextlib .suppress (asyncio .CancelledError ):
37+ await self ._task
Original file line number Diff line number Diff line change 1515LOGGER = logging .getLogger (__name__ )
1616
1717
18- class Client :
18+ class Client ( async_ . Routine ) :
1919
20- def __init__ (self , task_group : asyncio .TaskGroup , config : Config ) -> None :
20+ def __init__ (
21+ self , config : Config , task_group : asyncio .TaskGroup | None = None
22+ ) -> None :
23+ super ().__init__ (task_group = task_group )
2124 self ._logger = self ._new_logger (config )
2225 self ._config = config
2326 self ._mdns_resolver = mdns .Resolver ()
2427 self ._tls_context = self ._new_tls_context (config )
2528 self ._publisher : aiomqtt .Client | None = None
2629 self ._publisher_connected = asyncio .Event ()
27- self ._task = task_group .create_task (self ._run ())
28-
29- def cancel (self ) -> None :
30- self ._task .cancel ()
3130
3231 @staticmethod
3332 def _new_logger (config : Config ) -> logging .LoggerAdapter :
Original file line number Diff line number Diff line change 1- from .app import App , run
21from .config import Config
32from .device import Device
43from .device_protocol import (
109 Telemetry ,
1110)
1211from .logger import Logger
12+ from .run import run
1313
1414__all__ = [
15- "App" ,
1615 "CommandArgs" ,
1716 "CommandResult" ,
1817 "Config" ,
Load Diff This file was deleted.
Original file line number Diff line number Diff line change 33import time
44import traceback
55
6- from enapter import mqtt
6+ from enapter import async_ , mqtt
77
88from .device_protocol import DeviceProtocol
99
1010
11- class DeviceDriver :
11+ class DeviceDriver ( async_ . Routine ) :
1212
1313 def __init__ (
1414 self ,
15- task_group : asyncio .TaskGroup ,
1615 device_channel : mqtt .api .DeviceChannel ,
1716 device : DeviceProtocol ,
17+ task_group : asyncio .TaskGroup | None ,
1818 ) -> None :
19+ super ().__init__ (task_group = task_group )
1920 self ._device_channel = device_channel
2021 self ._device = device
21- self ._task = task_group .create_task (self ._run ())
2222
2323 async def _run (self ) -> None :
2424 async with asyncio .TaskGroup () as tg :
Original file line number Diff line number Diff line change 1+ import asyncio
2+ import contextlib
3+
4+ from enapter import log , mqtt
5+
6+ from .config import Config
7+ from .device_driver import DeviceDriver
8+ from .device_protocol import DeviceProtocol
9+ from .ucm import UCM
10+
11+
12+ async def run (device : DeviceProtocol ) -> None :
13+ log .configure (level = log .LEVEL or "info" )
14+ config = Config .from_env ()
15+ async with contextlib .AsyncExitStack () as stack :
16+ task_group = await stack .enter_async_context (asyncio .TaskGroup ())
17+ mqtt_client = await stack .enter_async_context (
18+ mqtt .Client (config = config .communication .mqtt , task_group = task_group )
19+ )
20+ _ = await stack .enter_async_context (
21+ DeviceDriver (
22+ device_channel = mqtt .api .DeviceChannel (
23+ client = mqtt_client ,
24+ hardware_id = config .communication .hardware_id ,
25+ channel_id = config .communication .channel_id ,
26+ ),
27+ device = device ,
28+ task_group = task_group ,
29+ )
30+ )
31+ if config .communication .ucm_needed :
32+ _ = await stack .enter_async_context (
33+ DeviceDriver (
34+ device_channel = mqtt .api .DeviceChannel (
35+ client = mqtt_client ,
36+ hardware_id = config .communication .hardware_id ,
37+ channel_id = "ucm" ,
38+ ),
39+ device = UCM (),
40+ task_group = task_group ,
41+ )
42+ )
43+ await asyncio .Event ().wait ()
Original file line number Diff line number Diff line change @@ -20,9 +20,8 @@ async def fixture_enapter_mqtt_client(mosquitto_container):
2020 port = int (ports [0 ]["HostPort" ]),
2121 )
2222 async with asyncio .TaskGroup () as tg :
23- mqtt_client = enapter .mqtt .Client (tg , config )
24- yield mqtt_client
25- mqtt_client .cancel ()
23+ async with enapter .mqtt .Client (config , task_group = tg ) as mqtt_client :
24+ yield mqtt_client
2625
2726
2827@pytest .fixture (scope = "session" , name = "mosquitto_container" )
You can’t perform that action at this time.
0 commit comments