Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions arkiv_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
Sequence,
)
from typing import (
TYPE_CHECKING,
Any,
TypeAlias,
cast,
)

from eth_typing import ChecksumAddress, HexStr
from web3 import AsyncWeb3, WebSocketProvider
from web3 import AsyncHTTPProvider, AsyncWeb3, WebSocketProvider
from web3.contract import AsyncContract
from web3.exceptions import ProviderConnectionError, Web3RPCError, Web3ValueError
from web3.method import Method, default_root_munger
Expand Down Expand Up @@ -86,18 +88,23 @@
"Wei",
]

if TYPE_CHECKING:
HTTPClient: TypeAlias = AsyncWeb3[AsyncHTTPProvider]
WSClient: TypeAlias = AsyncWeb3[WebSocketProvider]
else:
HTTPClient: TypeAlias = AsyncWeb3
WSClient: TypeAlias = AsyncWeb3


logger = logging.getLogger(__name__)
"""@private"""


class ArkivHttpClient(AsyncWeb3):
class ArkivHttpClient(HTTPClient):
"""Subclass of AsyncWeb3 with added Arkiv methods."""

def __init__(self, rpc_url: str):
super().__init__(
AsyncWeb3.AsyncHTTPProvider(rpc_url, request_kwargs={"timeout": 60})
)
super().__init__(AsyncHTTPProvider(rpc_url, request_kwargs={"timeout": 60}))

self.eth.attach_methods(
{
Expand Down Expand Up @@ -215,7 +222,7 @@ async def query_entities(self, query: str) -> Sequence[QueryEntitiesResult]:

class ArkivROClient:
_http_client: ArkivHttpClient
_ws_client: AsyncWeb3
_ws_client: WSClient
_arkiv_contract: AsyncContract
_background_tasks: set[asyncio.Task[None]]

Expand All @@ -229,11 +236,11 @@ async def create_ro_client(rpc_url: str, ws_url: str) -> "ArkivROClient":
return ArkivROClient(rpc_url, await ArkivROClient._create_ws_client(ws_url))

@staticmethod
async def _create_ws_client(ws_url: str) -> "AsyncWeb3":
ws_client: AsyncWeb3 = await AsyncWeb3(WebSocketProvider(ws_url))
async def _create_ws_client(ws_url: str) -> "AsyncWeb3[WebSocketProvider]":
ws_client: WSClient = await AsyncWeb3(WebSocketProvider(ws_url))
return ws_client

def __init__(self, rpc_url: str, ws_client: AsyncWeb3) -> None:
def __init__(self, rpc_url: str, ws_client: WSClient) -> None:
"""Initialise the ArkivClient instance."""
self._http_client = ArkivHttpClient(rpc_url)
self._ws_client = ws_client
Expand All @@ -242,7 +249,7 @@ def __init__(self, rpc_url: str, ws_client: AsyncWeb3) -> None:
self._background_tasks = set()

def is_connected(
client: AsyncWeb3,
client: HTTPClient,
) -> Callable[[bool], Coroutine[Any, Any, bool]]:
async def inner(show_traceback: bool) -> bool:
try:
Expand Down Expand Up @@ -291,7 +298,7 @@ def http_client(self) -> ArkivHttpClient:
"""Get the underlying web3 http client."""
return self._http_client

def ws_client(self) -> AsyncWeb3:
def ws_client(self) -> WSClient:
"""Get the underlying web3 websocket client."""
return self._ws_client

Expand Down Expand Up @@ -582,7 +589,7 @@ async def create(rpc_url: str, ws_url: str, private_key: bytes) -> "ArkivClient"
"""
return await ArkivClient.create_rw_client(rpc_url, ws_url, private_key)

def __init__(self, rpc_url: str, ws_client: AsyncWeb3, private_key: bytes) -> None:
def __init__(self, rpc_url: str, ws_client: WSClient, private_key: bytes) -> None:
"""Initialise the ArkivClient instance."""
super().__init__(rpc_url, ws_client)

Expand Down
Loading