diff --git a/bumble/controller.py b/bumble/controller.py index 4f687956..a3c99bd6 100644 --- a/bumble/controller.py +++ b/bumble/controller.py @@ -25,12 +25,14 @@ import struct from typing import TYPE_CHECKING, Any, Optional, Union, cast -from bumble import hci, lmp +from bumble import hci +from bumble import link +from bumble import link as bumble_link +from bumble import ll, lmp from bumble.colors import color from bumble.core import PhysicalTransport if TYPE_CHECKING: - from bumble.link import LocalLink from bumble.transport.common import TransportSink # ----------------------------------------------------------------------------- @@ -56,6 +58,128 @@ class CisLink: data_paths: set[int] = dataclasses.field(default_factory=set) +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class LegacyAdvertiser: + controller: Controller + advertising_interval_min: int = 0 + advertising_interval_max: int = 0 + advertising_type: int = 0 + own_address_type: int = 0 + peer_address_type: int = 0 + peer_address: hci.Address = hci.Address.ANY + advertising_channel_map: int = 0 + advertising_filter_policy: int = 0 + + advertising_data: bytes = b'' + scan_response_data: bytes = b'' + + enabled: bool = False + timer_handle: Optional[asyncio.Handle] = None + + @property + def address(self) -> hci.Address: + '''Address used in advertising PDU.''' + if self.own_address_type == hci.Address.PUBLIC_DEVICE_ADDRESS: + return self.controller.public_address + else: + return self.controller.random_address + + def _on_timer_fired(self) -> None: + self.send_advertising_data() + self.timer_handle = asyncio.get_running_loop().call_later( + self.advertising_interval_min / 1000.0, self._on_timer_fired + ) + + def start(self) -> None: + # Stop any ongoing advertising before we start again + self.stop() + self.enabled = True + + # Advertise now + self.timer_handle = asyncio.get_running_loop().call_soon(self._on_timer_fired) + + def stop(self) -> None: + if self.timer_handle is not None: + self.timer_handle.cancel() + self.timer_handle = None + self.enabled = False + + def send_advertising_data(self) -> None: + if not self.enabled: + return + + if ( + self.advertising_type + == hci.HCI_LE_Set_Advertising_Parameters_Command.AdvertisingType.ADV_IND + ): + self.controller.send_advertising_pdu( + ll.AdvInd( + advertiser_address=self.address, + data=self.advertising_data, + ) + ) + + +# ----------------------------------------------------------------------------- +@dataclasses.dataclass +class AdvertisingSet: + controller: Controller + handle: int + parameters: Optional[hci.HCI_LE_Set_Extended_Advertising_Parameters_Command] = None + data: bytearray = dataclasses.field(default_factory=bytearray) + scan_response_data: bytearray = dataclasses.field(default_factory=bytearray) + enabled: bool = False + timer_handle: Optional[asyncio.Handle] = None + random_address: Optional[hci.Address] = None + + @property + def address(self) -> hci.Address | None: + '''Address used in advertising PDU.''' + if not self.parameters: + return None + if self.parameters.own_address_type == hci.Address.PUBLIC_DEVICE_ADDRESS: + return self.controller.public_address + else: + return self.random_address + + def _on_extended_advertising_timer_fired(self) -> None: + if not self.enabled: + return + + self.send_extended_advertising_data() + + interval = ( + self.parameters.primary_advertising_interval_min * 0.625 / 1000.0 + if self.parameters + else 1.0 + ) + self.timer_handle = asyncio.get_running_loop().call_later( + interval, self._on_extended_advertising_timer_fired + ) + + def start(self) -> None: + self.enabled = True + asyncio.get_running_loop().call_soon(self._on_extended_advertising_timer_fired) + + def stop(self) -> None: + self.enabled = False + if timer_handle := self.timer_handle: + timer_handle.cancel() + self.timer_handle = None + + def send_extended_advertising_data(self) -> None: + if self.controller.link: + properties = ( + self.parameters.advertising_event_properties if self.parameters else 0 + ) + + address = self.address + assert address + + self.controller.send_advertising_pdu(ll.AdvInd(address, bytes(self.data))) + + # ----------------------------------------------------------------------------- @dataclasses.dataclass class ScoLink: @@ -70,9 +194,10 @@ class Connection: controller: Controller handle: int role: hci.Role + self_address: hci.Address peer_address: hci.Address - link: Any - transport: int + link: link.LocalLink + transport: PhysicalTransport link_type: int classic_allow_role_switch: bool = False @@ -93,22 +218,27 @@ def on_acl_pdu(self, pdu: bytes) -> None: self.controller, self.peer_address, self.transport, pdu ) + def send_ll_control_pdu(self, packet: ll.ControlPdu) -> None: + if self.link: + self.link.send_ll_control_pdu( + sender_address=self.self_address, + receiver_address=self.peer_address, + packet=packet, + ) + # ----------------------------------------------------------------------------- class Controller: hci_sink: Optional[TransportSink] = None - central_connections: dict[ - hci.Address, Connection - ] # Connections where this controller is the central - peripheral_connections: dict[ - hci.Address, Connection - ] # Connections where this controller is the peripheral + le_connections: dict[hci.Address, Connection] # LE Connections classic_connections: dict[hci.Address, Connection] # Connections in BR/EDR classic_pending_commands: dict[hci.Address, dict[lmp.Opcode, asyncio.Future[int]]] sco_links: dict[hci.Address, ScoLink] # SCO links by address central_cis_links: dict[int, CisLink] # CIS links by handle peripheral_cis_links: dict[int, CisLink] # CIS links by handle + advertising_sets: dict[int, AdvertisingSet] # Advertising sets by handle + le_legacy_advertiser: LegacyAdvertiser hci_version: int = hci.HCI_VERSION_BLUETOOTH_CORE_5_0 hci_revision: int = 0 @@ -131,10 +261,20 @@ class Controller: '30f0f9ff01008004002000000000000000000000000000000000000000000000' ) le_event_mask: int = 0 - advertising_parameters: Optional[hci.HCI_LE_Set_Advertising_Parameters_Command] = ( - None + le_features: hci.LeFeatureMask = ( + hci.LeFeatureMask.LE_ENCRYPTION + | hci.LeFeatureMask.CONNECTION_PARAMETERS_REQUEST_PROCEDURE + | hci.LeFeatureMask.EXTENDED_REJECT_INDICATION + | hci.LeFeatureMask.PERIPHERAL_INITIATED_FEATURE_EXCHANGE + | hci.LeFeatureMask.LE_PING + | hci.LeFeatureMask.LE_DATA_PACKET_LENGTH_EXTENSION + | hci.LeFeatureMask.LL_PRIVACY + | hci.LeFeatureMask.EXTENDED_SCANNER_FILTER_POLICIES + | hci.LeFeatureMask.LE_2M_PHY + | hci.LeFeatureMask.LE_CODED_PHY + | hci.LeFeatureMask.CHANNEL_SELECTION_ALGORITHM_2 + | hci.LeFeatureMask.MINIMUM_NUMBER_OF_USED_CHANNELS_PROCEDURE ) - le_features: bytes = bytes.fromhex('ff49010000000000') le_states: bytes = bytes.fromhex('ffff3fffff030000') advertising_channel_tx_power: int = 0 filter_accept_list_size: int = 8 @@ -150,10 +290,9 @@ class Controller: le_scan_type: int = 0 le_scan_interval: int = 0x10 le_scan_window: int = 0x10 - le_scan_enable: int = 0 + le_scan_enable: bool = False le_scan_own_address_type: int = hci.Address.RANDOM_DEVICE_ADDRESS le_scanning_filter_policy: int = 0 - le_scan_response_data: Optional[bytes] = None le_address_resolution: bool = False le_rpa_timeout: int = 0 sync_flow_control: bool = False @@ -163,6 +302,11 @@ class Controller: advertising_timer_handle: Optional[asyncio.Handle] = None classic_scan_enable: int = 0 classic_allow_role_switch: bool = True + pending_le_connection: ( + hci.HCI_LE_Create_Connection_Command + | hci.HCI_LE_Extended_Create_Connection_Command + | None + ) = None _random_address: hci.Address = hci.Address('00:00:00:00:00:00') @@ -171,23 +315,24 @@ def __init__( name: str, host_source=None, host_sink: Optional[TransportSink] = None, - link: Optional[LocalLink] = None, + link: Optional[link.LocalLink] = None, public_address: Optional[Union[bytes, str, hci.Address]] = None, ) -> None: self.name = name - self.link = link - self.central_connections = {} - self.peripheral_connections = {} + self.link = link or bumble_link.LocalLink() + self.le_connections = {} self.classic_connections = {} self.sco_links = {} self.classic_pending_commands = {} self.central_cis_links = {} self.peripheral_cis_links = {} + self.advertising_sets = {} self.default_phy = { 'all_phys': 0, 'tx_phys': 0, 'rx_phys': 0, } + self.le_legacy_advertiser = LegacyAdvertiser(self) if isinstance(public_address, hci.Address): self._public_address = public_address @@ -320,8 +465,7 @@ def allocate_connection_handle(self) -> int: current_handles = set( cast(Connection | CisLink | ScoLink, link).handle for link in itertools.chain( - self.central_connections.values(), - self.peripheral_connections.values(), + self.le_connections.values(), self.classic_connections.values(), self.sco_links.values(), self.central_cis_links.values(), @@ -329,49 +473,20 @@ def allocate_connection_handle(self) -> int: ) ) return next( - handle for handle in range(0xEFF + 1) if handle not in current_handles - ) - - def find_le_connection_by_address( - self, address: hci.Address - ) -> Optional[Connection]: - return self.central_connections.get(address) or self.peripheral_connections.get( - address + handle + for handle in range(0x0001, 0xEFF + 1) + if handle not in current_handles ) - def find_classic_connection_by_address( - self, address: hci.Address - ) -> Optional[Connection]: - return self.classic_connections.get(address) - def find_connection_by_handle(self, handle: int) -> Optional[Connection]: for connection in itertools.chain( - self.central_connections.values(), - self.peripheral_connections.values(), + self.le_connections.values(), self.classic_connections.values(), ): if connection.handle == handle: return connection return None - def find_central_connection_by_handle(self, handle: int) -> Optional[Connection]: - for connection in self.central_connections.values(): - if connection.handle == handle: - return connection - return None - - def find_peripheral_connection_by_handle(self, handle: int) -> Optional[Connection]: - for connection in self.peripheral_connections.values(): - if connection.handle == handle: - return connection - return None - - def find_classic_connection_by_handle(self, handle: int) -> Optional[Connection]: - for connection in self.classic_connections.values(): - if connection.handle == handle: - return connection - return None - def find_classic_sco_link_by_handle(self, handle: int) -> Optional[ScoLink]: for connection in self.sco_links.values(): if connection.handle == handle: @@ -383,28 +498,81 @@ def find_iso_link_by_handle(self, handle: int) -> Optional[CisLink]: handle ) - def on_link_central_connected(self, central_address: hci.Address) -> None: + def send_advertising_pdu(self, packet: ll.AdvertisingPdu) -> None: + logger.debug("[%s] >>> Advertising PDU: %s", self.name, packet) + if self.link: + self.link.send_advertising_pdu(self, packet) + + def on_ll_control_pdu( + self, sender_address: hci.Address, packet: ll.ControlPdu + ) -> None: + logger.debug("[%s] >>> LL Control PDU: %s", self.name, packet) + if not (connection := self.le_connections.get(sender_address)): + logger.error("Cannot find a connection for %s", sender_address) + return + + if isinstance(packet, ll.TerminateInd): + self.on_le_disconnected(connection, packet.error_code) + elif isinstance(packet, ll.CisReq): + self.on_le_cis_request(connection, packet.cig_id, packet.cis_id) + elif isinstance(packet, ll.CisRsp): + self.on_le_cis_established(packet.cig_id, packet.cis_id) + connection.send_ll_control_pdu(ll.CisInd(packet.cig_id, packet.cis_id)) + elif isinstance(packet, ll.CisInd): + self.on_le_cis_established(packet.cig_id, packet.cis_id) + elif isinstance(packet, ll.CisTerminateInd): + self.on_le_cis_disconnected(packet.cig_id, packet.cis_id) + elif isinstance(packet, ll.EncReq): + self.on_le_encrypted(connection) + + def on_ll_advertising_pdu(self, packet: ll.AdvertisingPdu) -> None: + logger.debug("[%s] <<< Advertising PDU: %s", self.name, packet) + if isinstance(packet, ll.ConnectInd): + self.on_le_connect_ind(packet) + elif isinstance(packet, (ll.AdvInd, ll.AdvExtInd)): + self.on_advertising_pdu(packet) + + def on_le_connect_ind(self, packet: ll.ConnectInd) -> None: ''' Called when an incoming connection occurs from a central on the link ''' + advertiser: LegacyAdvertiser | AdvertisingSet | None + if ( + self.le_legacy_advertiser.address == packet.advertiser_address + and self.le_legacy_advertiser.enabled + ): + advertiser = self.le_legacy_advertiser + else: + advertiser = next( + ( + advertising_set + for advertising_set in self.advertising_sets.values() + if advertising_set.address == packet.advertiser_address + and advertising_set.enabled + ), + None, + ) + + if not advertiser: + # This is not send to us. + return # Allocate (or reuse) a connection handle - peer_address = central_address - peer_address_type = central_address.address_type - connection = self.peripheral_connections.get(peer_address) - if connection is None: - connection_handle = self.allocate_connection_handle() - connection = Connection( - controller=self, - handle=connection_handle, - role=hci.Role.PERIPHERAL, - peer_address=peer_address, - link=self.link, - transport=PhysicalTransport.LE, - link_type=hci.HCI_Connection_Complete_Event.LinkType.ACL, - ) - self.peripheral_connections[peer_address] = connection - logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}') + peer_address = packet.initiator_address + + connection_handle = self.allocate_connection_handle() + connection = Connection( + controller=self, + handle=connection_handle, + role=hci.Role.PERIPHERAL, + self_address=packet.advertiser_address, + peer_address=peer_address, + link=self.link, + transport=PhysicalTransport.LE, + link_type=hci.HCI_Connection_Complete_Event.LinkType.ACL, + ) + self.le_connections[peer_address] = connection + logger.debug(f'New PERIPHERAL connection handle: 0x{connection_handle:04X}') # Then say that the connection has completed self.send_hci_packet( @@ -412,7 +580,7 @@ def on_link_central_connected(self, central_address: hci.Address) -> None: status=hci.HCI_SUCCESS, connection_handle=connection.handle, role=connection.role, - peer_address_type=peer_address_type, + peer_address_type=peer_address.address_type, peer_address=peer_address, connection_interval=10, # FIXME peripheral_latency=0, # FIXME @@ -421,131 +589,113 @@ def on_link_central_connected(self, central_address: hci.Address) -> None: ) ) - def on_link_disconnected(self, peer_address: hci.Address, reason: int) -> None: - ''' - Called when an active disconnection occurs from a peer - ''' - - # Send a disconnection complete event - if connection := self.peripheral_connections.get(peer_address): + if isinstance(advertiser, AdvertisingSet): self.send_hci_packet( - hci.HCI_Disconnection_Complete_Event( + hci.HCI_LE_Advertising_Set_Terminated_Event( status=hci.HCI_SUCCESS, + advertising_handle=advertiser.handle, connection_handle=connection.handle, - reason=reason, + num_completed_extended_advertising_events=0, ) ) + advertiser.stop() - # Remove the connection - del self.peripheral_connections[peer_address] - elif connection := self.central_connections.get(peer_address): - self.send_hci_packet( - hci.HCI_Disconnection_Complete_Event( - status=hci.HCI_SUCCESS, - connection_handle=connection.handle, - reason=reason, - ) + def on_le_disconnected(self, connection: Connection, reason: int) -> None: + # Send a disconnection complete event + self.send_hci_packet( + hci.HCI_Disconnection_Complete_Event( + status=hci.HCI_SUCCESS, + connection_handle=connection.handle, + reason=reason, ) + ) - # Remove the connection - del self.central_connections[peer_address] - else: - logger.warning(f'!!! No peripheral connection found for {peer_address}') - - def on_link_peripheral_connection_complete( - self, - le_create_connection_command: hci.HCI_LE_Create_Connection_Command, - status: int, - ) -> None: + def create_le_connection(self, peer_address: hci.Address) -> None: ''' - Called by the link when a connection has been made or has failed to be made + Called when we receive advertisement matching connection filter. ''' + pending_le_connection = self.pending_le_connection + assert pending_le_connection - if status == hci.HCI_SUCCESS: - # Allocate (or reuse) a connection handle - peer_address = le_create_connection_command.peer_address - connection = self.central_connections.get(peer_address) - if connection is None: - connection_handle = self.allocate_connection_handle() - connection = Connection( - controller=self, - handle=connection_handle, - role=hci.Role.CENTRAL, - peer_address=peer_address, - link=self.link, - transport=PhysicalTransport.LE, - link_type=hci.HCI_Connection_Complete_Event.LinkType.ACL, - ) - self.central_connections[peer_address] = connection - logger.debug( - f'New CENTRAL connection handle: 0x{connection_handle:04X}' - ) + if self.le_connections.get(peer_address): + logger.error("Connection for %s already exists?", peer_address) + return + + self_address = ( + self.public_address + if pending_le_connection.own_address_type == hci.OwnAddressType.PUBLIC + else self.random_address + ) + + # Allocate (or reuse) a connection handle + peer_address = pending_le_connection.peer_address + connection_handle = self.allocate_connection_handle() + connection = Connection( + controller=self, + handle=connection_handle, + role=hci.Role.CENTRAL, + self_address=self_address, + peer_address=peer_address, + link=self.link, + transport=PhysicalTransport.LE, + link_type=hci.HCI_Connection_Complete_Event.LinkType.ACL, + ) + self.le_connections[peer_address] = connection + logger.debug(f'New CENTRAL connection handle: 0x{connection_handle:04X}') + + if isinstance( + pending_le_connection, hci.HCI_LE_Extended_Create_Connection_Command + ): + interval = pending_le_connection.connection_interval_mins[0] + latency = pending_le_connection.max_latencies[0] + timeout = pending_le_connection.supervision_timeouts[0] else: - connection = None + interval = pending_le_connection.connection_interval_min + latency = pending_le_connection.max_latency + timeout = pending_le_connection.supervision_timeout + self.send_advertising_pdu( + ll.ConnectInd( + initiator_address=self_address, + advertiser_address=peer_address, + interval=interval, + latency=latency, + timeout=timeout, + ) + ) # Say that the connection has completed self.send_hci_packet( # pylint: disable=line-too-long hci.HCI_LE_Connection_Complete_Event( - status=status, + status=hci.HCI_SUCCESS, connection_handle=connection.handle if connection else 0, role=hci.Role.CENTRAL, - peer_address_type=le_create_connection_command.peer_address_type, - peer_address=le_create_connection_command.peer_address, - connection_interval=le_create_connection_command.connection_interval_min, - peripheral_latency=le_create_connection_command.max_latency, - supervision_timeout=le_create_connection_command.supervision_timeout, + peer_address_type=peer_address.address_type, + peer_address=peer_address, + connection_interval=interval, + peripheral_latency=latency, + supervision_timeout=timeout, central_clock_accuracy=0, ) ) + self.pending_le_connection = None - def on_link_disconnection_complete( - self, disconnection_command: hci.HCI_Disconnect_Command, status: int - ) -> None: - ''' - Called when a disconnection has been completed - ''' - - # Send a disconnection complete event + def on_le_encrypted(self, connection: Connection) -> None: + # For now, just setup the encryption without asking the host self.send_hci_packet( - hci.HCI_Disconnection_Complete_Event( - status=status, - connection_handle=disconnection_command.connection_handle, - reason=disconnection_command.reason, + hci.HCI_Encryption_Change_Event( + status=0, connection_handle=connection.handle, encryption_enabled=1 ) ) - # Remove the connection - if connection := self.find_central_connection_by_handle( - disconnection_command.connection_handle - ): - logger.debug(f'CENTRAL Connection removed: {connection}') - del self.central_connections[connection.peer_address] - elif connection := self.find_peripheral_connection_by_handle( - disconnection_command.connection_handle - ): - logger.debug(f'PERIPHERAL Connection removed: {connection}') - del self.peripheral_connections[connection.peer_address] - - def on_link_encrypted( - self, peer_address: hci.Address, _rand: bytes, _ediv: int, _ltk: bytes - ) -> None: - # For now, just setup the encryption without asking the host - if connection := self.find_le_connection_by_address(peer_address): - self.send_hci_packet( - hci.HCI_Encryption_Change_Event( - status=0, connection_handle=connection.handle, encryption_enabled=1 - ) - ) - def on_link_acl_data( self, sender_address: hci.Address, transport: PhysicalTransport, data: bytes ) -> None: # Look for the connection to which this data belongs if transport == PhysicalTransport.LE: - connection = self.find_le_connection_by_address(sender_address) + connection = self.le_connections.get(sender_address) else: - connection = self.find_classic_connection_by_address(sender_address) + connection = self.classic_connections.get(sender_address) if connection is None: logger.warning(f'!!! no connection for {sender_address}') return @@ -555,43 +705,83 @@ def on_link_acl_data( acl_packet = hci.HCI_AclDataPacket(connection.handle, 2, 0, len(data), data) self.send_hci_packet(acl_packet) - def on_link_advertising_data( - self, sender_address: hci.Address, data: bytes - ) -> None: - # Ignore if we're not scanning - if self.le_scan_enable == 0: - return + def on_advertising_pdu(self, pdu: ll.AdvInd | ll.AdvExtInd) -> None: + if isinstance(pdu, ll.AdvExtInd): + direct_address = pdu.target_address + else: + direct_address = None - # Send a scan report - report = hci.HCI_LE_Advertising_Report_Event.Report( - event_type=hci.HCI_LE_Advertising_Report_Event.EventType.ADV_IND, - address_type=sender_address.address_type, - address=sender_address, - data=data, - rssi=-50, - ) - self.send_hci_packet(hci.HCI_LE_Advertising_Report_Event([report])) + if self.le_scan_enable: + # Send a scan report + if self.le_features & hci.LeFeatureMask.LE_EXTENDED_ADVERTISING: + ext_report = hci.HCI_LE_Extended_Advertising_Report_Event.Report( + event_type=hci.HCI_LE_Extended_Advertising_Report_Event.EventType.CONNECTABLE_ADVERTISING, + address_type=pdu.advertiser_address.address_type, + address=pdu.advertiser_address, + primary_phy=hci.Phy.LE_1M, + secondary_phy=hci.Phy.LE_1M, + advertising_sid=0, + tx_power=0, + rssi=-50, + periodic_advertising_interval=0, + direct_address_type=( + direct_address.address_type if direct_address else 0 + ), + direct_address=direct_address or hci.Address.ANY, + data=pdu.data, + ) + self.send_hci_packet( + hci.HCI_LE_Extended_Advertising_Report_Event([ext_report]) + ) + ext_report = hci.HCI_LE_Extended_Advertising_Report_Event.Report( + event_type=hci.HCI_LE_Extended_Advertising_Report_Event.EventType.SCAN_RESPONSE, + address_type=pdu.advertiser_address.address_type, + address=pdu.advertiser_address, + primary_phy=hci.Phy.LE_1M, + secondary_phy=hci.Phy.LE_1M, + advertising_sid=0, + tx_power=0, + rssi=-50, + periodic_advertising_interval=0, + direct_address_type=( + direct_address.address_type if direct_address else 0 + ), + direct_address=direct_address or hci.Address.ANY, + data=pdu.data, + ) + self.send_hci_packet( + hci.HCI_LE_Extended_Advertising_Report_Event([ext_report]) + ) + else: + report = hci.HCI_LE_Advertising_Report_Event.Report( + event_type=hci.HCI_LE_Advertising_Report_Event.EventType.ADV_IND, + address_type=pdu.advertiser_address.address_type, + address=pdu.advertiser_address, + data=pdu.data, + rssi=-50, + ) + self.send_hci_packet(hci.HCI_LE_Advertising_Report_Event([report])) + report = hci.HCI_LE_Advertising_Report_Event.Report( + event_type=hci.HCI_LE_Advertising_Report_Event.EventType.SCAN_RSP, + address_type=pdu.advertiser_address.address_type, + address=pdu.advertiser_address, + data=pdu.data, + rssi=-50, + ) + self.send_hci_packet(hci.HCI_LE_Advertising_Report_Event([report])) - # Simulate a scan response - report = hci.HCI_LE_Advertising_Report_Event.Report( - event_type=hci.HCI_LE_Advertising_Report_Event.EventType.SCAN_RSP, - address_type=sender_address.address_type, - address=sender_address, - data=data, - rssi=-50, - ) - self.send_hci_packet(hci.HCI_LE_Advertising_Report_Event([report])) + # Create connection. + if ( + pending_le_connection := self.pending_le_connection + ) and pending_le_connection.peer_address == pdu.advertiser_address: + self.create_le_connection(pdu.advertiser_address) - def on_link_cis_request( - self, central_address: hci.Address, cig_id: int, cis_id: int + def on_le_cis_request( + self, connection: Connection, cig_id: int, cis_id: int ) -> None: ''' Called when an incoming CIS request occurs from a central on the link ''' - - connection = self.peripheral_connections.get(central_address) - assert connection - pending_cis_link = CisLink( handle=self.allocate_connection_handle(), cis_id=cis_id, @@ -609,7 +799,7 @@ def on_link_cis_request( ) ) - def on_link_cis_established(self, cig_id: int, cis_id: int) -> None: + def on_le_cis_established(self, cig_id: int, cis_id: int) -> None: ''' Called when an incoming CIS established. ''' @@ -644,7 +834,7 @@ def on_link_cis_established(self, cig_id: int, cis_id: int) -> None: ) ) - def on_link_cis_disconnected(self, cig_id: int, cis_id: int) -> None: + def on_le_cis_disconnected(self, cig_id: int, cis_id: int) -> None: ''' Called when a CIS disconnected. ''' @@ -750,6 +940,7 @@ def on_classic_connection_request( controller=self, handle=0, role=hci.Role.PERIPHERAL, + self_address=self.public_address, peer_address=peer_address, link=self.link, transport=PhysicalTransport.BR_EDR, @@ -784,6 +975,7 @@ def on_classic_connection_complete( controller=self, handle=connection_handle, role=hci.Role.CENTRAL, + self_address=self.public_address, peer_address=peer_address, link=self.link, transport=PhysicalTransport.BR_EDR, @@ -934,33 +1126,12 @@ def on_classic_remote_name_response( ############################################################ # Advertising support ############################################################ - def on_advertising_timer_fired(self) -> None: - self.send_advertising_data() - self.advertising_timer_handle = asyncio.get_running_loop().call_later( - self.advertising_interval / 1000.0, self.on_advertising_timer_fired - ) - - def start_advertising(self) -> None: - # Stop any ongoing advertising before we start again - self.stop_advertising() - - # Advertise now - self.advertising_timer_handle = asyncio.get_running_loop().call_soon( - self.on_advertising_timer_fired - ) - - def stop_advertising(self) -> None: - if self.advertising_timer_handle is not None: - self.advertising_timer_handle.cancel() - self.advertising_timer_handle = None - - def send_advertising_data(self) -> None: - if self.link and self.advertising_data: - self.link.send_advertising_data(self.random_address, self.advertising_data) @property def is_advertising(self) -> bool: - return self.advertising_timer_handle is not None + return self.le_legacy_advertiser.enabled or any( + s.enabled for s in self.advertising_sets.values() + ) ############################################################ # HCI handlers @@ -981,7 +1152,7 @@ def on_hci_create_connection_command( logger.debug(f'Connection request to {command.bd_addr}') # Check that we don't already have a pending connection - if self.link.get_pending_connection(): + if self.pending_le_connection: self.send_hci_packet( hci.HCI_Command_Status_Event( status=hci.HCI_CONTROLLER_BUSY_ERROR, @@ -995,6 +1166,7 @@ def on_hci_create_connection_command( controller=self, handle=0, role=hci.Role.CENTRAL, + self_address=self.public_address, peer_address=command.bd_addr, link=self.link, transport=PhysicalTransport.BR_EDR, @@ -1035,29 +1207,19 @@ def on_hci_disconnect_command( # Notify the link of the disconnection handle = command.connection_handle - if connection := self.find_central_connection_by_handle(handle): - if self.link: - self.link.disconnect( - self.random_address, connection.peer_address, command - ) - else: - # Remove the connection - del self.central_connections[connection.peer_address] - elif connection := self.find_peripheral_connection_by_handle(handle): - if self.link: - self.link.disconnect( - self.random_address, connection.peer_address, command - ) - else: - # Remove the connection - del self.peripheral_connections[connection.peer_address] - elif connection := self.find_classic_connection_by_handle(handle): + if connection := self.find_connection_by_handle(handle): if self.link: - self.send_lmp_packet( - connection.peer_address, - lmp.LmpDetach(command.reason), - ) - self.on_classic_disconnected(connection.peer_address, command.reason) + if connection.transport == PhysicalTransport.BR_EDR: + self.send_lmp_packet( + connection.peer_address, + lmp.LmpDetach(command.reason), + ) + self.on_classic_disconnected( + connection.peer_address, command.reason + ) + else: + connection.send_ll_control_pdu(ll.TerminateInd(command.reason)) + self.on_le_disconnected(connection, command.reason) else: # Remove the connection del self.classic_connections[connection.peer_address] @@ -1088,12 +1250,12 @@ def on_hci_disconnect_command( self.central_cis_links.get(handle) or self.peripheral_cis_links.get(handle) ): if self.link and cis_link.acl_connection: - self.link.disconnect_cis( - initiator_controller=self, - peer_address=cis_link.acl_connection.peer_address, - cig_id=cis_link.cig_id, - cis_id=cis_link.cis_id, + cis_link.acl_connection.send_ll_control_pdu( + ll.CisTerminateInd( + cis_link.cig_id, cis_link.cis_id, command.reason + ), ) + self.on_le_cis_disconnected(cis_link.cig_id, cis_link.cis_id) # Spec requires handle to be kept after disconnection. return None @@ -1185,9 +1347,7 @@ def on_hci_enhanced_setup_synchronous_connection_command( return None if not ( - connection := self.find_classic_connection_by_handle( - command.connection_handle - ) + connection := self.find_connection_by_handle(command.connection_handle) ): self.send_hci_packet( hci.HCI_Command_Status_Event( @@ -1243,7 +1403,7 @@ def on_hci_enhanced_accept_synchronous_connection_request_command( if self.link is None: return None - if not (connection := self.find_classic_connection_by_address(command.bd_addr)): + if not (connection := self.classic_connections.get(command.bd_addr)): self.send_hci_packet( hci.HCI_Command_Status_Event( status=hci.HCI_UNKNOWN_CONNECTION_IDENTIFIER_ERROR, @@ -1742,7 +1902,7 @@ def on_hci_le_read_local_supported_features_command( See Bluetooth spec Vol 4, Part E - 7.8.3 LE Read Local Supported Features Command ''' - return bytes([hci.HCI_SUCCESS]) + self.le_features + return bytes([hci.HCI_SUCCESS]) + self.le_features.value.to_bytes(8, 'little') def on_hci_le_set_random_address_command( self, command: hci.HCI_LE_Set_Random_Address_Command @@ -1759,7 +1919,22 @@ def on_hci_le_set_advertising_parameters_command( ''' See Bluetooth spec Vol 4, Part E - 7.8.5 LE Set Advertising Parameters Command ''' - self.advertising_parameters = command + self.le_legacy_advertiser.advertising_interval_min = ( + command.advertising_interval_min + ) + self.le_legacy_advertiser.advertising_interval_max = ( + command.advertising_interval_max + ) + self.le_legacy_advertiser.advertising_type = command.advertising_type + self.le_legacy_advertiser.own_address_type = command.own_address_type + self.le_legacy_advertiser.peer_address_type = command.peer_address_type + self.le_legacy_advertiser.peer_address = command.peer_address + self.le_legacy_advertiser.advertising_channel_map = ( + command.advertising_channel_map + ) + self.le_legacy_advertiser.advertising_filter_policy = ( + command.advertising_filter_policy + ) return bytes([hci.HCI_SUCCESS]) def on_hci_le_read_advertising_physical_channel_tx_power_command( @@ -1777,7 +1952,8 @@ def on_hci_le_set_advertising_data_command( ''' See Bluetooth spec Vol 4, Part E - 7.8.7 LE Set Advertising Data Command ''' - self.advertising_data = command.advertising_data + self.le_legacy_advertiser.advertising_data = command.advertising_data + return bytes([hci.HCI_SUCCESS]) def on_hci_le_set_scan_response_data_command( @@ -1786,7 +1962,7 @@ def on_hci_le_set_scan_response_data_command( ''' See Bluetooth spec Vol 4, Part E - 7.8.8 LE Set Scan Response Data Command ''' - self.le_scan_response_data = command.scan_response_data + self.le_legacy_advertiser.scan_response_data = command.scan_response_data return bytes([hci.HCI_SUCCESS]) def on_hci_le_set_advertising_enable_command( @@ -1796,9 +1972,9 @@ def on_hci_le_set_advertising_enable_command( See Bluetooth spec Vol 4, Part E - 7.8.9 LE Set Advertising Enable Command ''' if command.advertising_enable: - self.start_advertising() + self.le_legacy_advertiser.start() else: - self.stop_advertising() + self.le_legacy_advertiser.stop() return bytes([hci.HCI_SUCCESS]) @@ -1841,7 +2017,7 @@ def on_hci_le_create_connection_command( logger.debug(f'Connection request to {command.peer_address}') # Check that we don't already have a pending connection - if self.link.get_pending_connection(): + if self.pending_le_connection: self.send_hci_packet( hci.HCI_Command_Status_Event( status=hci.HCI_COMMAND_DISALLOWED_ERROR, @@ -1851,8 +2027,7 @@ def on_hci_le_create_connection_command( ) return None - # Initiate the connection - self.link.connect(self.random_address, command) + self.pending_le_connection = command # Say that the connection is pending self.send_hci_packet( @@ -1872,6 +2047,37 @@ def on_hci_le_create_connection_cancel_command( ''' return bytes([hci.HCI_SUCCESS]) + def on_hci_le_extended_create_connection_command( + self, command: hci.HCI_LE_Extended_Create_Connection_Command + ) -> Optional[bytes]: + ''' + See Bluetooth spec Vol 4, Part E - 7.8.66 LE Extended Create Connection Command + ''' + if not self.link: + return None + + # Check pending + if self.pending_le_connection: + self.send_hci_packet( + hci.HCI_Command_Status_Event( + status=hci.HCI_COMMAND_DISALLOWED_ERROR, + num_hci_command_packets=1, + command_opcode=command.op_code, + ) + ) + return None + + self.pending_le_connection = command + + self.send_hci_packet( + hci.HCI_Command_Status_Event( + status=hci.HCI_COMMAND_STATUS_PENDING, + num_hci_command_packets=1, + command_opcode=command.op_code, + ) + ) + return None + def on_hci_le_read_filter_accept_list_size_command( self, _command: hci.HCI_LE_Read_Filter_Accept_List_Size_Command ) -> Optional[bytes]: @@ -1972,21 +2178,21 @@ def on_hci_le_enable_encryption_command( return None # Check the parameters - if not ( - connection := self.find_central_connection_by_handle( - command.connection_handle + if ( + not ( + connection := self.find_connection_by_handle(command.connection_handle) ) + or connection.transport != PhysicalTransport.LE ): logger.warning('connection not found') return bytes([hci.HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR]) - # Notify that the connection is now encrypted - self.link.on_connection_encrypted( - self.random_address, - connection.peer_address, - command.random_number, - command.encrypted_diversifier, - command.long_term_key, + connection.send_ll_control_pdu( + ll.EncReq( + rand=command.random_number, + ediv=command.encrypted_diversifier, + ltk=command.long_term_key, + ), ) self.send_hci_packet( @@ -1997,6 +2203,9 @@ def on_hci_le_enable_encryption_command( ) ) + # TODO: Handle authentication + self.on_le_encrypted(connection) + return None def on_hci_le_read_supported_states_command( @@ -2134,48 +2343,125 @@ def on_hci_le_set_default_phy_command( return bytes([hci.HCI_SUCCESS]) def on_hci_le_set_advertising_set_random_address_command( - self, _command: hci.HCI_LE_Set_Advertising_Set_Random_Address_Command + self, command: hci.HCI_LE_Set_Advertising_Set_Random_Address_Command ) -> Optional[bytes]: ''' See Bluetooth spec Vol 4, Part E - 7.8.52 LE Set Advertising Set Random hci.Address Command ''' + handle = command.advertising_handle + if handle not in self.advertising_sets: + self.advertising_sets[handle] = AdvertisingSet( + controller=self, handle=handle + ) + self.advertising_sets[handle].random_address = command.random_address return bytes([hci.HCI_SUCCESS]) def on_hci_le_set_extended_advertising_parameters_command( - self, _command: hci.HCI_LE_Set_Extended_Advertising_Parameters_Command + self, command: hci.HCI_LE_Set_Extended_Advertising_Parameters_Command ) -> Optional[bytes]: ''' See Bluetooth spec Vol 4, Part E - 7.8.53 LE Set Extended Advertising Parameters Command ''' + handle = command.advertising_handle + if handle not in self.advertising_sets: + self.advertising_sets[handle] = AdvertisingSet( + controller=self, handle=handle + ) + + self.advertising_sets[handle].parameters = command return bytes([hci.HCI_SUCCESS, 0]) def on_hci_le_set_extended_advertising_data_command( - self, _command: hci.HCI_LE_Set_Extended_Advertising_Data_Command + self, command: hci.HCI_LE_Set_Extended_Advertising_Data_Command ) -> Optional[bytes]: ''' See Bluetooth spec Vol 4, Part E - 7.8.54 LE Set Extended Advertising Data Command ''' + handle = command.advertising_handle + if not (adv_set := self.advertising_sets.get(handle)): + return bytes([hci.HCI_UNKNOWN_ADVERTISING_IDENTIFIER_ERROR]) + + if command.operation in ( + hci.HCI_LE_Set_Extended_Advertising_Data_Command.Operation.FIRST_FRAGMENT, + hci.HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA, + ): + adv_set.data = bytearray(command.advertising_data) + elif command.operation in ( + hci.HCI_LE_Set_Extended_Advertising_Data_Command.Operation.INTERMEDIATE_FRAGMENT, + hci.HCI_LE_Set_Extended_Advertising_Data_Command.Operation.LAST_FRAGMENT, + ): + adv_set.data.extend(command.advertising_data) + return bytes([hci.HCI_SUCCESS]) def on_hci_le_set_extended_scan_response_data_command( - self, _command: hci.HCI_LE_Set_Extended_Scan_Response_Data_Command + self, command: hci.HCI_LE_Set_Extended_Scan_Response_Data_Command ) -> Optional[bytes]: ''' See Bluetooth spec Vol 4, Part E - 7.8.55 LE Set Extended Scan Response Data Command ''' + handle = command.advertising_handle + if not (adv_set := self.advertising_sets.get(handle)): + return bytes([hci.HCI_UNKNOWN_ADVERTISING_IDENTIFIER_ERROR]) + + if command.operation in ( + hci.HCI_LE_Set_Extended_Advertising_Data_Command.Operation.FIRST_FRAGMENT, + hci.HCI_LE_Set_Extended_Advertising_Data_Command.Operation.COMPLETE_DATA, + ): + adv_set.scan_response_data = bytearray(command.scan_response_data) + elif command.operation in ( + hci.HCI_LE_Set_Extended_Advertising_Data_Command.Operation.INTERMEDIATE_FRAGMENT, + hci.HCI_LE_Set_Extended_Advertising_Data_Command.Operation.LAST_FRAGMENT, + ): + adv_set.scan_response_data.extend(command.scan_response_data) + return bytes([hci.HCI_SUCCESS]) def on_hci_le_set_extended_advertising_enable_command( - self, _command: hci.HCI_LE_Set_Extended_Advertising_Enable_Command + self, command: hci.HCI_LE_Set_Extended_Advertising_Enable_Command ) -> Optional[bytes]: ''' See Bluetooth spec Vol 4, Part E - 7.8.56 LE Set Extended Advertising Enable Command ''' + if command.enable: + for handle in command.advertising_handles: + if advertising_set := self.advertising_sets.get(handle): + advertising_set.start() + else: + if not command.advertising_handles: + for advertising_set in self.advertising_sets.values(): + advertising_set.stop() + else: + for handle in command.advertising_handles: + if advertising_set := self.advertising_sets.get(handle): + advertising_set.stop() + return bytes([hci.HCI_SUCCESS]) + + def on_hci_le_remove_advertising_set_command( + self, command: hci.HCI_LE_Remove_Advertising_Set_Command + ) -> Optional[bytes]: + ''' + See Bluetooth spec Vol 4, Part E - 7.8.59 LE Remove Advertising Set Command + ''' + handle = command.advertising_handle + if advertising_set := self.advertising_sets.pop(handle, None): + advertising_set.stop() + return bytes([hci.HCI_SUCCESS]) + + def on_hci_le_clear_advertising_sets_command( + self, _command: hci.HCI_LE_Clear_Advertising_Sets_Command + ) -> Optional[bytes]: + ''' + See Bluetooth spec Vol 4, Part E - 7.8.60 LE Clear Advertising Sets Command + ''' + for advertising_set in self.advertising_sets.values(): + advertising_set.stop() + self.advertising_sets.clear() return bytes([hci.HCI_SUCCESS]) def on_hci_le_read_maximum_advertising_data_length_command( @@ -2279,11 +2565,8 @@ def on_hci_le_create_cis_command( cis_link.acl_connection = connection - self.link.create_cis( - self, - peripheral_address=connection.peer_address, - cig_id=cis_link.cig_id, - cis_id=cis_link.cis_id, + connection.send_ll_control_pdu( + ll.CisReq(cig_id=cis_link.cig_id, cis_id=cis_link.cis_id) ) self.send_hci_packet( @@ -2328,11 +2611,8 @@ def on_hci_le_accept_cis_request_command( return bytes([hci.HCI_INVALID_HCI_COMMAND_PARAMETERS_ERROR]) assert pending_cis_link.acl_connection - self.link.accept_cis( - peripheral_controller=self, - central_address=pending_cis_link.acl_connection.peer_address, - cig_id=pending_cis_link.cig_id, - cis_id=pending_cis_link.cis_id, + pending_cis_link.acl_connection.send_ll_control_pdu( + ll.CisRsp(cig_id=pending_cis_link.cig_id, cis_id=pending_cis_link.cis_id), ) self.send_hci_packet( diff --git a/bumble/device.py b/bumble/device.py index 1f63ba2a..ff495760 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -265,12 +265,22 @@ def from_advertising_report( # ----------------------------------------------------------------------------- class AdvertisementDataAccumulator: + last_advertisement: Advertisement | None + last_data: bytes + passive: bool + def __init__(self, passive: bool = False): self.passive = passive self.last_advertisement = None self.last_data = b'' - def update(self, report): + def update( + self, + report: ( + hci.HCI_LE_Advertising_Report_Event.Report + | hci.HCI_LE_Extended_Advertising_Report_Event.Report + ), + ) -> Advertisement | None: advertisement = Advertisement.from_advertising_report(report) if advertisement is None: return None @@ -283,10 +293,12 @@ def update(self, report): and not self.last_advertisement.is_scan_response ): # This is the response to a scannable advertisement - result = Advertisement.from_advertising_report(report) - result.is_connectable = self.last_advertisement.is_connectable - result.is_scannable = True - result.data = AdvertisingData.from_bytes(self.last_data + report.data) + if result := Advertisement.from_advertising_report(report): + result.is_connectable = self.last_advertisement.is_connectable + result.is_scannable = True + result.data = AdvertisingData.from_bytes( + self.last_data + report.data + ) self.last_data = b'' else: if ( @@ -3333,7 +3345,13 @@ def is_scanning(self): return self.scanning @host_event_handler - def on_advertising_report(self, report): + def on_advertising_report( + self, + report: ( + hci.HCI_LE_Advertising_Report_Event.Report + | hci.HCI_LE_Extended_Advertising_Report_Event.Report + ), + ) -> None: if not (accumulator := self.advertisement_accumulators.get(report.address)): accumulator = AdvertisementDataAccumulator(passive=self.scanning_is_passive) self.advertisement_accumulators[report.address] = accumulator diff --git a/bumble/link.py b/bumble/link.py index 4c4e1688..eef5a21a 100644 --- a/bumble/link.py +++ b/bumble/link.py @@ -19,9 +19,12 @@ # Imports # ----------------------------------------------------------------------------- import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional -from bumble import controller, core, hci, lmp +from bumble import core, hci, ll, lmp + +if TYPE_CHECKING: + from bumble import controller # ----------------------------------------------------------------------------- # Logging @@ -29,11 +32,6 @@ logger = logging.getLogger(__name__) -# ----------------------------------------------------------------------------- -# Utils -# ----------------------------------------------------------------------------- - - # ----------------------------------------------------------------------------- # TODO: add more support for various LL exchanges # (see Vol 6, Part B - 2.4 DATA CHANNEL PDU) @@ -47,7 +45,6 @@ class LocalLink: def __init__(self): self.controllers = set() - self.pending_connection = None self.pending_classic_connection = None ############################################################ @@ -61,10 +58,11 @@ def add_controller(self, controller: controller.Controller): def remove_controller(self, controller: controller.Controller): self.controllers.remove(controller) - def find_controller(self, address: hci.Address) -> controller.Controller | None: + def find_le_controller(self, address: hci.Address) -> controller.Controller | None: for controller in self.controllers: - if controller.random_address == address: - return controller + for connection in controller.le_connections.values(): + if connection.self_address == address: + return controller return None def find_classic_controller( @@ -75,9 +73,6 @@ def find_classic_controller( return controller return None - def get_pending_connection(self): - return self.pending_connection - ############################################################ # LE handlers ############################################################ @@ -85,12 +80,6 @@ def get_pending_connection(self): def on_address_changed(self, controller): pass - def send_advertising_data(self, sender_address: hci.Address, data: bytes): - # Send the advertising data to all controllers, except the sender - for controller in self.controllers: - if controller.random_address != sender_address: - controller.on_link_advertising_data(sender_address, data) - def send_acl_data( self, sender_controller: controller.Controller, @@ -100,7 +89,7 @@ def send_acl_data( ): # Send the data to the first controller with a matching address if transport == core.PhysicalTransport.LE: - destination_controller = self.find_controller(destination_address) + destination_controller = self.find_le_controller(destination_address) source_address = sender_controller.random_address elif transport == core.PhysicalTransport.BR_EDR: destination_controller = self.find_classic_controller(destination_address) @@ -115,152 +104,30 @@ def send_acl_data( ) ) - def on_connection_complete(self) -> None: - # Check that we expect this call - if not self.pending_connection: - logger.warning('on_connection_complete with no pending connection') - return - - central_address, le_create_connection_command = self.pending_connection - self.pending_connection = None - - # Find the controller that initiated the connection - if not (central_controller := self.find_controller(central_address)): - logger.warning('!!! Initiating controller not found') - return - - # Connect to the first controller with a matching address - if peripheral_controller := self.find_controller( - le_create_connection_command.peer_address - ): - central_controller.on_link_peripheral_connection_complete( - le_create_connection_command, hci.HCI_SUCCESS - ) - peripheral_controller.on_link_central_connected(central_address) - return - - # No peripheral found - central_controller.on_link_peripheral_connection_complete( - le_create_connection_command, hci.HCI_CONNECTION_ACCEPT_TIMEOUT_ERROR - ) - - def connect( + def send_advertising_pdu( self, - central_address: hci.Address, - le_create_connection_command: hci.HCI_LE_Create_Connection_Command, + sender_controller: controller.Controller, + packet: ll.AdvertisingPdu, ): - logger.debug( - f'$$$ CONNECTION {central_address} -> ' - f'{le_create_connection_command.peer_address}' - ) - self.pending_connection = (central_address, le_create_connection_command) - asyncio.get_running_loop().call_soon(self.on_connection_complete) + loop = asyncio.get_running_loop() + for c in self.controllers: + if c != sender_controller: + loop.call_soon(c.on_ll_advertising_pdu, packet) - def on_disconnection_complete( + def send_ll_control_pdu( self, - initiating_address: hci.Address, - target_address: hci.Address, - disconnect_command: hci.HCI_Disconnect_Command, + sender_address: hci.Address, + receiver_address: hci.Address, + packet: ll.ControlPdu, ): - # Find the controller that initiated the disconnection - if not (initiating_controller := self.find_controller(initiating_address)): - logger.warning('!!! Initiating controller not found') - return - - # Disconnect from the first controller with a matching address - if target_controller := self.find_controller(target_address): - target_controller.on_link_disconnected( - initiating_address, disconnect_command.reason + if not (receiver_controller := self.find_le_controller(receiver_address)): + raise core.InvalidArgumentError( + f"Unable to find controller for address {receiver_address}" ) - - initiating_controller.on_link_disconnection_complete( - disconnect_command, hci.HCI_SUCCESS - ) - - def disconnect( - self, - initiating_address: hci.Address, - target_address: hci.Address, - disconnect_command: hci.HCI_Disconnect_Command, - ): - logger.debug( - f'$$$ DISCONNECTION {initiating_address} -> ' - f'{target_address}: reason = {disconnect_command.reason}' - ) asyncio.get_running_loop().call_soon( - lambda: self.on_disconnection_complete( - initiating_address, target_address, disconnect_command - ) + lambda: receiver_controller.on_ll_control_pdu(sender_address, packet) ) - def on_connection_encrypted( - self, - central_address: hci.Address, - peripheral_address: hci.Address, - rand: bytes, - ediv: int, - ltk: bytes, - ): - logger.debug(f'*** ENCRYPTION {central_address} -> {peripheral_address}') - - if central_controller := self.find_controller(central_address): - central_controller.on_link_encrypted(peripheral_address, rand, ediv, ltk) - - if peripheral_controller := self.find_controller(peripheral_address): - peripheral_controller.on_link_encrypted(central_address, rand, ediv, ltk) - - def create_cis( - self, - central_controller: controller.Controller, - peripheral_address: hci.Address, - cig_id: int, - cis_id: int, - ) -> None: - logger.debug( - f'$$$ CIS Request {central_controller.random_address} -> {peripheral_address}' - ) - if peripheral_controller := self.find_controller(peripheral_address): - asyncio.get_running_loop().call_soon( - peripheral_controller.on_link_cis_request, - central_controller.random_address, - cig_id, - cis_id, - ) - - def accept_cis( - self, - peripheral_controller: controller.Controller, - central_address: hci.Address, - cig_id: int, - cis_id: int, - ) -> None: - logger.debug( - f'$$$ CIS Accept {peripheral_controller.random_address} -> {central_address}' - ) - if central_controller := self.find_controller(central_address): - loop = asyncio.get_running_loop() - loop.call_soon(central_controller.on_link_cis_established, cig_id, cis_id) - loop.call_soon( - peripheral_controller.on_link_cis_established, cig_id, cis_id - ) - - def disconnect_cis( - self, - initiator_controller: controller.Controller, - peer_address: hci.Address, - cig_id: int, - cis_id: int, - ) -> None: - logger.debug( - f'$$$ CIS Disconnect {initiator_controller.random_address} -> {peer_address}' - ) - if peer_controller := self.find_controller(peer_address): - loop = asyncio.get_running_loop() - loop.call_soon( - initiator_controller.on_link_cis_disconnected, cig_id, cis_id - ) - loop.call_soon(peer_controller.on_link_cis_disconnected, cig_id, cis_id) - ############################################################ # Classic handlers ############################################################ diff --git a/bumble/ll.py b/bumble/ll.py new file mode 100644 index 00000000..08d40c7b --- /dev/null +++ b/bumble/ll.py @@ -0,0 +1,200 @@ +# Copyright 2021-2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ----------------------------------------------------------------------------- +# Imports +# ----------------------------------------------------------------------------- +from __future__ import annotations + +import dataclasses +from typing import ClassVar + +from bumble import hci + + +# ----------------------------------------------------------------------------- +# Advertising PDU +# ----------------------------------------------------------------------------- +class AdvertisingPdu: + """Base Advertising Physical Channel PDU class. + + See Core Spec 6.0, Volume 6, Part B, 2.3. Advertising physical channel PDU. + + Currently these messages don't really follow the LL spec, because LL protocol is + context-aware and we don't have real physical transport. + """ + + +@dataclasses.dataclass +class ConnectInd(AdvertisingPdu): + initiator_address: hci.Address + advertiser_address: hci.Address + interval: int + latency: int + timeout: int + + +@dataclasses.dataclass +class AdvInd(AdvertisingPdu): + advertiser_address: hci.Address + data: bytes + + +@dataclasses.dataclass +class AdvDirectInd(AdvertisingPdu): + advertiser_address: hci.Address + target_address: hci.Address + + +@dataclasses.dataclass +class AdvNonConnInd(AdvertisingPdu): + advertiser_address: hci.Address + data: bytes + + +@dataclasses.dataclass +class AdvExtInd(AdvertisingPdu): + advertiser_address: hci.Address + data: bytes + + target_address: hci.Address | None = None + adi: int | None = None + tx_power: int | None = None + + +# ----------------------------------------------------------------------------- +# LL Control PDU +# ----------------------------------------------------------------------------- +class ControlPdu: + """Base LL Control PDU Class. + + See Core Spec 6.0, Volume 6, Part B, 2.4.2. LL Control PDU. + + Currently these messages don't really follow the LL spec, because LL protocol is + context-aware and we don't have real physical transport. + """ + + class Opcode(hci.SpecableEnum): + LL_CONNECTION_UPDATE_IND = 0x00 + LL_CHANNEL_MAP_IND = 0x01 + LL_TERMINATE_IND = 0x02 + LL_ENC_REQ = 0x03 + LL_ENC_RSP = 0x04 + LL_START_ENC_REQ = 0x05 + LL_START_ENC_RSP = 0x06 + LL_UNKNOWN_RSP = 0x07 + LL_FEATURE_REQ = 0x08 + LL_FEATURE_RSP = 0x09 + LL_PAUSE_ENC_REQ = 0x0A + LL_PAUSE_ENC_RSP = 0x0B + LL_VERSION_IND = 0x0C + LL_REJECT_IND = 0x0D + LL_PERIPHERAL_FEATURE_REQ = 0x0E + LL_CONNECTION_PARAM_REQ = 0x0F + LL_CONNECTION_PARAM_RSP = 0x10 + LL_REJECT_EXT_IND = 0x11 + LL_PING_REQ = 0x12 + LL_PING_RSP = 0x13 + LL_LENGTH_REQ = 0x14 + LL_LENGTH_RSP = 0x15 + LL_PHY_REQ = 0x16 + LL_PHY_RSP = 0x17 + LL_PHY_UPDATE_IND = 0x18 + LL_MIN_USED_CHANNELS_IND = 0x19 + LL_CTE_REQ = 0x1A + LL_CTE_RSP = 0x1B + LL_PERIODIC_SYNC_IND = 0x1C + LL_CLOCK_ACCURACY_REQ = 0x1D + LL_CLOCK_ACCURACY_RSP = 0x1E + LL_CIS_REQ = 0x1F + LL_CIS_RSP = 0x20 + LL_CIS_IND = 0x21 + LL_CIS_TERMINATE_IND = 0x22 + LL_POWER_CONTROL_REQ = 0x23 + LL_POWER_CONTROL_RSP = 0x24 + LL_POWER_CHANGE_IND = 0x25 + LL_SUBRATE_REQ = 0x26 + LL_SUBRATE_IND = 0x27 + LL_CHANNEL_REPORTING_IND = 0x28 + LL_CHANNEL_STATUS_IND = 0x29 + LL_PERIODIC_SYNC_WR_IND = 0x2A + LL_FEATURE_EXT_REQ = 0x2B + LL_FEATURE_EXT_RSP = 0x2C + LL_CS_SEC_RSP = 0x2D + LL_CS_CAPABILITIES_REQ = 0x2E + LL_CS_CAPABILITIES_RSP = 0x2F + LL_CS_CONFIG_REQ = 0x30 + LL_CS_CONFIG_RSP = 0x31 + LL_CS_REQ = 0x32 + LL_CS_RSP = 0x33 + LL_CS_IND = 0x34 + LL_CS_TERMINATE_REQ = 0x35 + LL_CS_FAE_REQ = 0x36 + LL_CS_FAE_RSP = 0x37 + LL_CS_CHANNEL_MAP_IND = 0x38 + LL_CS_SEC_REQ = 0x39 + LL_CS_TERMINATE_RSP = 0x3A + LL_FRAME_SPACE_REQ = 0x3B + LL_FRAME_SPACE_RSP = 0x3C + + opcode: ClassVar[Opcode] + + +@dataclasses.dataclass +class TerminateInd(ControlPdu): + opcode = ControlPdu.Opcode.LL_TERMINATE_IND + + error_code: int + + +@dataclasses.dataclass +class EncReq(ControlPdu): + opcode = ControlPdu.Opcode.LL_ENC_REQ + + rand: bytes + ediv: int + ltk: bytes + + +@dataclasses.dataclass +class CisReq(ControlPdu): + opcode = ControlPdu.Opcode.LL_CIS_REQ + + cig_id: int + cis_id: int + + +@dataclasses.dataclass +class CisRsp(ControlPdu): + opcode = ControlPdu.Opcode.LL_CIS_REQ + + cig_id: int + cis_id: int + + +@dataclasses.dataclass +class CisInd(ControlPdu): + opcode = ControlPdu.Opcode.LL_CIS_REQ + + cig_id: int + cis_id: int + + +@dataclasses.dataclass +class CisTerminateInd(ControlPdu): + opcode = ControlPdu.Opcode.LL_CIS_TERMINATE_IND + + cig_id: int + cis_id: int + error_code: int diff --git a/tests/device_test.py b/tests/device_test.py index 3fd3d20a..768ba7d8 100644 --- a/tests/device_test.py +++ b/tests/device_test.py @@ -284,52 +284,51 @@ async def test_legacy_advertising(): @pytest.mark.asyncio async def test_legacy_advertising_disconnection(auto_restart): devices = TwoDevices() - device = devices[0] - devices.controllers[0].le_features = bytes.fromhex('ffffffffffffffff') - await device.power_on() - peer_address = Address('F0:F1:F2:F3:F4:F5') - await device.start_advertising(auto_restart=auto_restart) - device.on_le_connection( - 0x0001, - peer_address, - None, - None, - Role.PERIPHERAL, - 0, - 0, - 0, + for controller in devices.controllers: + controller.le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING + for dev in devices: + await dev.power_on() + await devices[0].start_advertising( + auto_restart=auto_restart, advertising_interval_min=1.0 ) + connecion = await devices[1].connect(devices[0].random_address) - device.on_advertising_set_termination( - HCI_SUCCESS, device.legacy_advertising_set.advertising_handle, 0x0001, 0 - ) + await connecion.disconnect() - device.on_disconnection(0x0001, 0) await async_barrier() await async_barrier() if auto_restart: - assert device.legacy_advertising_set + assert devices[0].legacy_advertising_set started = asyncio.Event() - if not device.is_advertising: - device.legacy_advertising_set.once('start', started.set) + if not devices[0].is_advertising: + devices[0].legacy_advertising_set.once('start', started.set) await asyncio.wait_for(started.wait(), _TIMEOUT) - assert device.is_advertising + assert devices[0].is_advertising else: - assert not device.is_advertising + assert not devices[0].is_advertising # ----------------------------------------------------------------------------- @pytest.mark.asyncio -async def test_extended_advertising(): - device = TwoDevices()[0] - await device.power_on() +async def test_advertising_and_scanning(): + devices = TwoDevices() + for dev in devices: + await dev.power_on() + + # Start scanning + advertisements = asyncio.Queue[device.Advertisement]() + devices[1].on(devices[1].EVENT_ADVERTISEMENT, advertisements.put_nowait) + await devices[1].start_scanning() # Start advertising - advertising_set = await device.create_advertising_set() - assert device.extended_advertising_sets + advertising_set = await devices[0].create_advertising_set(advertising_data=b'123') + assert devices[0].extended_advertising_sets assert advertising_set.enabled + advertisement = await asyncio.wait_for(advertisements.get(), _TIMEOUT) + assert advertisement.data_bytes == b'123' + # Stop advertising await advertising_set.stop() assert not advertising_set.enabled @@ -342,33 +341,33 @@ async def test_extended_advertising(): ) @pytest.mark.asyncio async def test_extended_advertising_connection(own_address_type): - device = TwoDevices()[0] - await device.power_on() - peer_address = Address('F0:F1:F2:F3:F4:F5') - advertising_set = await device.create_advertising_set( - advertising_parameters=AdvertisingParameters(own_address_type=own_address_type) - ) - device.on_le_connection( - 0x0001, - peer_address, - None, - None, - Role.PERIPHERAL, - 0, - 0, - 0, + devices = TwoDevices() + for dev in devices: + await dev.power_on() + advertising_set = await devices[0].create_advertising_set( + advertising_parameters=AdvertisingParameters( + own_address_type=own_address_type, primary_advertising_interval_min=1.0 + ) ) - device.on_advertising_set_termination( - HCI_SUCCESS, - advertising_set.advertising_handle, - 0x0001, - 0, + await asyncio.wait_for( + devices[1].connect(advertising_set.random_address or devices[0].public_address), + _TIMEOUT, ) + await async_barrier() + + # Advertising set should be terminated after connected. + assert not advertising_set.enabled if own_address_type == OwnAddressType.PUBLIC: - assert device.lookup_connection(0x0001).self_address == device.public_address + assert ( + devices[0].lookup_connection(0x0001).self_address + == devices[0].public_address + ) else: - assert device.lookup_connection(0x0001).self_address == device.random_address + assert ( + devices[0].lookup_connection(0x0001).self_address + == devices[0].random_address + ) await async_barrier() @@ -382,7 +381,7 @@ async def test_extended_advertising_connection(own_address_type): async def test_extended_advertising_connection_out_of_order(own_address_type): devices = TwoDevices() device = devices[0] - devices.controllers[0].le_features = bytes.fromhex('ffffffffffffffff') + devices.controllers[0].le_features |= hci.LeFeatureMask.LE_EXTENDED_ADVERTISING await device.power_on() advertising_set = await device.create_advertising_set( advertising_parameters=AdvertisingParameters(own_address_type=own_address_type) diff --git a/tests/gatt_test.py b/tests/gatt_test.py index 49c0f758..8b3a295d 100644 --- a/tests/gatt_test.py +++ b/tests/gatt_test.py @@ -69,7 +69,7 @@ from bumble.link import LocalLink from bumble.transport.common import AsyncPipeSink -from .test_utils import async_barrier +from .test_utils import Devices, TwoDevices, async_barrier # ----------------------------------------------------------------------------- @@ -160,7 +160,8 @@ def encode_value(self, value): def decode_value(self, value_bytes): return value_bytes[0] - [client, server] = LinkedDevices().devices[:2] + devices = await TwoDevices.create_with_connection() + [client, server] = devices characteristic = Characteristic( 'FDB159DB-036C-49E3-B3DB-6325AC750806', @@ -189,9 +190,7 @@ async def async_read(connection): ) server.add_service(service) - await client.power_on() - await server.power_on() - connection = await client.connect(server.random_address) + connection = devices.connections[0] peer = Peer(connection) await peer.discover_services() @@ -279,7 +278,8 @@ def on_change(value): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_attribute_getters(): - [client, server] = LinkedDevices().devices[:2] + devices = await TwoDevices.create_with_connection() + [client, server] = devices characteristic_uuid = UUID('FDB159DB-036C-49E3-B3DB-6325AC750806') characteristic = Characteristic( @@ -629,39 +629,11 @@ async def read_value(connection): m.assert_called_once_with(z, b) -# ----------------------------------------------------------------------------- -class LinkedDevices: - def __init__(self): - self.connections = [None, None, None] - - self.link = LocalLink() - self.controllers = [ - Controller('C1', link=self.link), - Controller('C2', link=self.link), - Controller('C3', link=self.link), - ] - self.devices = [ - Device( - address='F0:F1:F2:F3:F4:F5', - host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), - ), - Device( - address='F1:F2:F3:F4:F5:F6', - host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), - ), - Device( - address='F2:F3:F4:F5:F6:F7', - host=Host(self.controllers[2], AsyncPipeSink(self.controllers[2])), - ), - ] - - self.paired = [None, None, None] - - # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_read_write(): - [client, server] = LinkedDevices().devices[:2] + devices = await TwoDevices.create_with_connection() + [client, server] = devices characteristic1 = Characteristic( 'FDB159DB-036C-49E3-B3DB-6325AC750806', @@ -694,9 +666,7 @@ def on_characteristic2_write(connection, value): ) server.add_services([service1]) - await client.power_on() - await server.power_on() - connection = await client.connect(server.random_address) + connection = devices.connections[0] peer = Peer(connection) await peer.discover_services() @@ -740,7 +710,8 @@ def on_characteristic2_write(connection, value): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_read_write2(): - [client, server] = LinkedDevices().devices[:2] + devices = await TwoDevices.create_with_connection() + [client, server] = devices v = bytes([0x11, 0x22, 0x33, 0x44]) characteristic1 = Characteristic( @@ -753,9 +724,7 @@ async def test_read_write2(): service1 = Service('3A657F47-D34F-46B3-B1EC-698E29B6B829', [characteristic1]) server.add_services([service1]) - await client.power_on() - await server.power_on() - connection = await client.connect(server.random_address) + connection = devices.connections[0] peer = Peer(connection) await peer.discover_services() @@ -785,7 +754,8 @@ async def test_read_write2(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_subscribe_notify(): - [client, server] = LinkedDevices().devices[:2] + devices = await TwoDevices.create_with_connection() + [client, server] = devices characteristic1 = Characteristic( 'FDB159DB-036C-49E3-B3DB-6325AC750806', @@ -855,9 +825,7 @@ def on_characteristic_subscription( server.on('characteristic_subscription', on_characteristic_subscription) - await client.power_on() - await server.power_on() - connection = await client.connect(server.random_address) + connection = devices.connections[0] peer = Peer(connection) await peer.discover_services() @@ -1006,7 +974,8 @@ def on_c3_update_3(value): # for indicate # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_unsubscribe(): - [client, server] = LinkedDevices().devices[:2] + devices = await TwoDevices.create_with_connection() + [client, server] = devices characteristic1 = Characteristic( 'FDB159DB-036C-49E3-B3DB-6325AC750806', @@ -1032,9 +1001,7 @@ async def test_unsubscribe(): mock2 = Mock() characteristic2.on('subscription', mock2) - await client.power_on() - await server.power_on() - connection = await client.connect(server.random_address) + connection = devices.connections[0] peer = Peer(connection) await peer.discover_services() @@ -1094,7 +1061,8 @@ def callback(_): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_discover_all(): - [client, server] = LinkedDevices().devices[:2] + devices = await TwoDevices.create_with_connection() + [client, server] = devices characteristic1 = Characteristic( 'FDB159DB-036C-49E3-B3DB-6325AC750806', @@ -1120,9 +1088,7 @@ async def test_discover_all(): service2 = Service('1111', []) server.add_services([service1, service2]) - await client.power_on() - await server.power_on() - connection = await client.connect(server.random_address) + connection = devices.connections[0] peer = Peer(connection) await peer.discover_all() @@ -1146,7 +1112,10 @@ async def test_discover_all(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_mtu_exchange(): - [d1, d2, d3] = LinkedDevices().devices[:3] + devices = Devices(3) + for dev in devices: + await dev.power_on() + [d1, d2, d3] = devices d3.gatt_server.max_mtu = 100 @@ -1160,11 +1129,15 @@ def on_d3_connection(connection): await d2.power_on() await d3.power_on() + await d3.start_advertising(advertising_interval_min=1.0) d1_connection = await d1.connect(d3.random_address) + await async_barrier() assert len(d3_connections) == 1 assert d3_connections[0] is not None + await d3.start_advertising(advertising_interval_min=1.0) d2_connection = await d2.connect(d3.random_address) + await async_barrier() assert len(d3_connections) == 2 assert d3_connections[1] is not None @@ -1233,7 +1206,8 @@ def test_characteristic_property_from_string_assert(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_server_string(): - [_, server] = LinkedDevices().devices[:2] + devices = await TwoDevices.create_with_connection() + [_, server] = devices characteristic = Characteristic( 'FDB159DB-036C-49E3-B3DB-6325AC750806', @@ -1422,7 +1396,8 @@ def test_get_attribute_group(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_get_characteristics_by_uuid(): - [client, server] = LinkedDevices().devices[:2] + devices = await TwoDevices.create_with_connection() + [client, server] = devices characteristic1 = Characteristic( '1234', @@ -1447,9 +1422,7 @@ async def test_get_characteristics_by_uuid(): server.add_services([service1, service2]) - await client.power_on() - await server.power_on() - connection = await client.connect(server.random_address) + connection = devices.connections[0] peer = Peer(connection) await peer.discover_services() @@ -1472,7 +1445,8 @@ async def test_get_characteristics_by_uuid(): # ----------------------------------------------------------------------------- @pytest.mark.asyncio async def test_write_return_error(): - [client, server] = LinkedDevices().devices[:2] + devices = await TwoDevices.create_with_connection() + [client, server] = devices on_write = Mock(side_effect=ATT_Error(error_code=ErrorCode.VALUE_NOT_ALLOWED)) characteristic = Characteristic( @@ -1484,9 +1458,7 @@ async def test_write_return_error(): service = Service('ABCD', [characteristic]) server.add_service(service) - await client.power_on() - await server.power_on() - connection = await client.connect(server.random_address) + connection = devices.connections[0] async with Peer(connection) as peer: c = peer.get_characteristics_by_uuid(uuid=UUID('1234'))[0] diff --git a/tests/self_test.py b/tests/self_test.py index c431c59a..aa2e7a5f 100644 --- a/tests/self_test.py +++ b/tests/self_test.py @@ -35,7 +35,7 @@ OobLegacyContext, ) -from .test_utils import TwoDevices +from .test_utils import TwoDevices, async_barrier # ----------------------------------------------------------------------------- # Logging @@ -56,12 +56,14 @@ async def test_self_disconnection(): two_devices = TwoDevices() await two_devices.setup_connection() await two_devices.connections[0].disconnect() + await async_barrier() assert two_devices.connections[0] is None assert two_devices.connections[1] is None two_devices = TwoDevices() await two_devices.setup_connection() await two_devices.connections[1].disconnect() + await async_barrier() assert two_devices.connections[0] is None assert two_devices.connections[1] is None @@ -80,7 +82,8 @@ async def test_self_classic_connection(responder_role): two_devices.devices[1].classic_enabled = True # Start - await two_devices.setup_connection() + for dev in two_devices.devices: + await dev.power_on() # Connect the two devices await asyncio.gather( @@ -418,8 +421,9 @@ async def test_self_smp_over_classic(): two_devices.devices[1].classic_enabled = True # Connect the two devices - await two_devices.devices[0].power_on() - await two_devices.devices[1].power_on() + for dev in two_devices.devices: + await dev.power_on() + await asyncio.gather( two_devices.devices[0].connect( two_devices.devices[1].public_address, transport=PhysicalTransport.BR_EDR diff --git a/tests/test_utils.py b/tests/test_utils.py index 1051101a..b50abc26 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -16,6 +16,7 @@ # Imports # ----------------------------------------------------------------------------- import asyncio +import functools from typing import Optional from typing_extensions import Self @@ -30,39 +31,34 @@ # ----------------------------------------------------------------------------- -class TwoDevices: +class Devices: connections: list[Optional[Connection]] - def __init__(self) -> None: - self.connections = [None, None] + def __init__(self, num_devices: int) -> None: + self.connections = [None for _ in range(num_devices)] self.link = LocalLink() - addresses = ['F0:F1:F2:F3:F4:F5', 'F5:F4:F3:F2:F1:F0'] + addresses = [":".join([f"F{i}"] * 6) for i in range(num_devices)] self.controllers = [ - Controller('C1', link=self.link, public_address=addresses[0]), - Controller('C2', link=self.link, public_address=addresses[1]), + Controller(f'C{i+i}', link=self.link, public_address=addresses[i]) + for i in range(num_devices) ] self.devices = [ Device( - address=Address(addresses[0]), - host=Host(self.controllers[0], AsyncPipeSink(self.controllers[0])), - ), - Device( - address=Address(addresses[1]), - host=Host(self.controllers[1], AsyncPipeSink(self.controllers[1])), - ), + address=Address(addresses[i]), + host=Host(self.controllers[i], AsyncPipeSink(self.controllers[i])), + ) + for i in range(num_devices) ] - self.devices[0].on( - 'connection', lambda connection: self.on_connection(0, connection) - ) - self.devices[1].on( - 'connection', lambda connection: self.on_connection(1, connection) - ) + for i in range(num_devices): + self.devices[i].on( + self.devices[i].EVENT_CONNECTION, + functools.partial(self.on_connection, i), + ) self.paired = [ - asyncio.get_event_loop().create_future(), - asyncio.get_event_loop().create_future(), + asyncio.get_event_loop().create_future() for _ in range(num_devices) ] def on_connection(self, which, connection): @@ -77,19 +73,26 @@ def on_paired(self, which: int, keys: PairingKeys) -> None: async def setup_connection(self) -> None: # Start - await self.devices[0].power_on() - await self.devices[1].power_on() + for dev in self.devices: + await dev.power_on() - # Connect the two devices - await self.devices[0].connect(self.devices[1].random_address) - - # Check the post conditions - assert self.connections[0] is not None - assert self.connections[1] is not None + # Connect devices + for dev in self.devices[1:]: + connection_future = asyncio.get_running_loop().create_future() + dev.once(dev.EVENT_CONNECTION, connection_future.set_result) + await dev.start_advertising(advertising_interval_min=1.0) + await self.devices[0].connect(dev.random_address) + await connection_future def __getitem__(self, index: int) -> Device: return self.devices[index] + +# ----------------------------------------------------------------------------- +class TwoDevices(Devices): + def __init__(self) -> None: + super().__init__(2) + @classmethod async def create_with_connection(cls: type[Self]) -> Self: devices = cls()