Skip to content

Commit 4da1940

Browse files
committed
Refactor code to prep for timeout
1 parent 78be880 commit 4da1940

File tree

1 file changed

+21
-4
lines changed

1 file changed

+21
-4
lines changed

hub/hub.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
from common.allocation import Allocation
1414
from common.block import Block
1515
from common.encryption_key import EncryptionKey
16-
from common.owner import Owner
16+
from common.logging import LOGGER
1717
from common.share import Share
1818
from common.share_api import APIGetShareResponse, APIPostShareRequest
19+
from common.owner import Owner
1920
from common.utils import bytes_to_str, key_id_str_to_uuid, str_to_bytes
2021
from .peer_client import PeerClient
2122

@@ -111,8 +112,7 @@ async def store_share_received_from_client(
111112
share_index=api_post_share_request.share_index,
112113
value=share_value,
113114
)
114-
# TODO: Check if the key UUID is already present, and if so, do something sensible
115-
self._shares[share.user_key_id] = share
115+
self.store_share(share)
116116
peer_client.add_dske_signing_key_header_to_response(headers_temp_response)
117117
peer_client.delete_fully_used_blocks()
118118

@@ -143,7 +143,7 @@ async def get_share_requested_by_client(
143143
share = self.get_share(key_id)
144144
share.check_master_sae(master_sae_id)
145145
share.check_slave_sae(slave_sae_id)
146-
del self._shares[key_id]
146+
self.delete_share(key_id)
147147
encryption_key = EncryptionKey.from_pool(peer_client.local_pool, share.size)
148148
encrypted_share_value = encryption_key.encrypt(share.value)
149149
response = APIGetShareResponse(
@@ -155,6 +155,14 @@ async def get_share_requested_by_client(
155155
peer_client.delete_fully_used_blocks()
156156
return response
157157

158+
def store_share(self, share: Share):
159+
"""
160+
Store a share.
161+
"""
162+
if share.user_key_id in self._shares:
163+
LOGGER.error(f"Overwriting existing share for key ID {share.user_key_id}")
164+
self._shares[share.user_key_id] = share
165+
158166
def get_share(self, key_id: UUID) -> Share:
159167
"""
160168
Get a share by key ID. Raise an exception if the share is not found.
@@ -165,6 +173,15 @@ def get_share(self, key_id: UUID) -> Share:
165173
raise exceptions.UnknownKeyIDError(key_id) from exc
166174
return share
167175

176+
def delete_share(self, key_id: UUID):
177+
"""
178+
Delete a share by key ID.
179+
"""
180+
try:
181+
del self._shares[key_id]
182+
except KeyError as exc:
183+
raise exceptions.UnknownKeyIDError(key_id) from exc
184+
168185
def initiate_stop(self):
169186
"""
170187
Initiate stopping the hub.

0 commit comments

Comments
 (0)