Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 73 additions & 62 deletions librespot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,12 @@ def __init__(self, session: Session):
self.__base_url = "https://{}".format(ApResolver.get_random_spclient())

def build_request(
self,
method: str,
suffix: str,
headers: typing.Union[None, CaseInsensitiveDict[str, str]],
body: typing.Union[None, bytes],
url: typing.Union[None, str],
self,
method: str,
suffix: str,
headers: typing.Union[None, CaseInsensitiveDict[str, str]],
body: typing.Union[None, bytes],
url: typing.Union[None, str],
) -> requests.PreparedRequest:
"""

Expand Down Expand Up @@ -122,11 +122,11 @@ def build_request(
return request.prepare()

def send(
self,
method: str,
suffix: str,
headers: typing.Union[None, CaseInsensitiveDict[str, str]],
body: typing.Union[None, bytes],
self,
method: str,
suffix: str,
headers: typing.Union[None, CaseInsensitiveDict[str, str]],
body: typing.Union[None, bytes],
) -> requests.Response:
"""

Expand All @@ -144,12 +144,12 @@ def send(
return response

def sendToUrl(
self,
method: str,
url: str,
suffix: str,
headers: typing.Union[None, CaseInsensitiveDict[str, str]],
body: typing.Union[None, bytes],
self,
method: str,
url: str,
suffix: str,
headers: typing.Union[None, CaseInsensitiveDict[str, str]],
body: typing.Union[None, bytes],
) -> requests.Response:
"""

Expand Down Expand Up @@ -194,10 +194,10 @@ def put_connect_state(self, connection_id: str,

def get_ext_metadata(self, extension_kind: ExtensionKind, uri: str):
headers = CaseInsensitiveDict({"content-type": "application/x-protobuf"})
req = EntityRequest(entity_uri=uri, query=[ExtensionQuery(extension_kind=extension_kind),])
req = EntityRequest(entity_uri=uri, query=[ExtensionQuery(extension_kind=extension_kind), ])

response = self.send("POST", "/extended-metadata/v0/extended-metadata",
headers, BatchedEntityRequest(entity_request=[req,]).SerializeToString())
headers, BatchedEntityRequest(entity_request=[req, ]).SerializeToString())
ApiClient.StatusCodeException.check_status(response)

body = response.content
Expand All @@ -208,7 +208,8 @@ def get_ext_metadata(self, extension_kind: ExtensionKind, uri: str):
proto.ParseFromString(body)
entityextd = proto.extended_metadata.pop().extension_data.pop()
if entityextd.header.status_code != 200:
raise ConnectionError("Extended Metadata request failed: Status code {}".format(entityextd.header.status_code))
raise ConnectionError(
"Extended Metadata request failed: Status code {}".format(entityextd.header.status_code))
mdb: bytes = entityextd.extension_data.value
return mdb

Expand Down Expand Up @@ -754,7 +755,7 @@ def __worker_callback(self, event_builder: EventBuilder):
event_builder, ex))

def send_event(self, event_or_builder: typing.Union[GenericEvent,
EventBuilder]):
EventBuilder]):
"""

:param event_or_builder: typing.Union[GenericEvent:
Expand Down Expand Up @@ -892,6 +893,7 @@ class Session(Closeable, MessageListener, SubListener):
__audio_key_manager: typing.Union[AudioKeyManager, None] = None
__auth_lock = threading.Condition()
__auth_lock_bool = False
__is_active = False
__cache_manager: typing.Union[CacheManager, None]
__cdn_manager: typing.Union[CdnManager, None]
__channel_manager: typing.Union[ChannelManager, None] = None
Expand Down Expand Up @@ -975,6 +977,7 @@ def authenticate(self,
self.__search = SearchManager(self)
self.__event_service = EventService(self)
self.__auth_lock_bool = False
self.__is_active = True
self.__auth_lock.notify_all()
self.dealer().connect()
self.logger.info("Authenticated as {}!".format(
Expand Down Expand Up @@ -1085,7 +1088,7 @@ def connect(self) -> None:
if not pkcs1_v1_5.verify(
sha1,
ap_response_message_proto.challenge.login_crypto_challenge.
diffie_hellman.gs_signature,
diffie_hellman.gs_signature,
):
raise RuntimeError("Failed signature check!")
# Solve challenge
Expand Down Expand Up @@ -1192,10 +1195,11 @@ def get_user_attribute(self, key: str, fallback: str = None) -> str:

def is_valid(self) -> bool:
""" """
if self.__closed:
if self.__closed or not self.__is_active:
return False
self.__wait_auth_lock()
return self.__ap_welcome is not None and self.connection is not None
# Do not wait for lock if we just want to check validity to avoid blocking
# self.__wait_auth_lock()
return self.__ap_welcome is not None and self.connection is not None and self.__receiver is not None and self.__receiver._Receiver__running

def mercury(self) -> MercuryClient:
""" """
Expand Down Expand Up @@ -1352,11 +1356,11 @@ def __authenticate_partial(self,
self.__stored_str = base64.b64encode(
json.dumps({
"username":
self.__ap_welcome.canonical_username,
self.__ap_welcome.canonical_username,
"credentials":
base64.b64encode(reusable).decode(),
base64.b64encode(reusable).decode(),
"type":
reusable_type,
reusable_type,
}).encode()).decode()
with open(self.__inner.conf.stored_credentials_file, "w") as f:
json.dump(
Expand Down Expand Up @@ -1629,7 +1633,7 @@ def stored_file(self,
pass
return self

def oauth(self, oauth_url_callback, success_page_content = None) -> Session.Builder:
def oauth(self, oauth_url_callback, success_page_content=None) -> Session.Builder:
"""
Login via OAuth

Expand All @@ -1638,7 +1642,8 @@ def oauth(self, oauth_url_callback, success_page_content = None) -> Session.Buil
"""
if os.path.isfile(self.conf.stored_credentials_file):
return self.stored_file(None)
self.login_credentials = OAuth(MercuryRequests.keymaster_client_id, "http://127.0.0.1:5588/login", oauth_url_callback).set_success_page_content(success_page_content).flow()
self.login_credentials = OAuth(MercuryRequests.keymaster_client_id, "http://127.0.0.1:5588/login",
oauth_url_callback).set_success_page_content(success_page_content).flow()
return self

def user_pass(self, username: str, password: str) -> Session.Builder:
Expand Down Expand Up @@ -1704,20 +1709,20 @@ class Configuration:
retry_on_chunk_error: bool

def __init__(
self,
# proxy_enabled: bool,
# proxy_type: Proxy.Type,
# proxy_address: str,
# proxy_port: int,
# proxy_auth: bool,
# proxy_username: str,
# proxy_password: str,
cache_enabled: bool,
cache_dir: str,
do_cache_clean_up: bool,
store_credentials: bool,
stored_credentials_file: str,
retry_on_chunk_error: bool,
self,
# proxy_enabled: bool,
# proxy_type: Proxy.Type,
# proxy_address: str,
# proxy_port: int,
# proxy_auth: bool,
# proxy_username: str,
# proxy_password: str,
cache_enabled: bool,
cache_dir: str,
do_cache_clean_up: bool,
store_credentials: bool,
stored_credentials_file: str,
retry_on_chunk_error: bool,
):
# self.proxyEnabled = proxy_enabled
# self.proxyType = proxy_type
Expand Down Expand Up @@ -2009,12 +2014,12 @@ class Inner:
preferred_locale: str

def __init__(
self,
device_type: Connect.DeviceType,
device_name: str,
preferred_locale: str,
conf: Session.Configuration,
device_id: str = None,
self,
device_type: Connect.DeviceType,
device_name: str,
preferred_locale: str,
conf: Session.Configuration,
device_id: str = None,
):
self.preferred_locale = preferred_locale
self.conf = conf
Expand Down Expand Up @@ -2061,7 +2066,12 @@ def run(self) -> None:
if self.__running:
self.__session.logger.fatal(
"Failed reading packet! {}".format(ex))
self.__session.reconnect()
try:
self.__session.reconnect()
except Exception as e:
self.__session.logger.fatal(f"Reconnection failed: {e}")
self.__session._Session__is_active = False
self.__running = False
break
if not self.__running:
break
Expand Down Expand Up @@ -2102,16 +2112,16 @@ def anonymous():
self.__session.logger.debug("Received 0x10: {}".format(
util.bytes_to_hex(packet.payload)))
elif cmd in [
Packet.Type.mercury_sub,
Packet.Type.mercury_unsub,
Packet.Type.mercury_event,
Packet.Type.mercury_req,
Packet.Type.mercury_sub,
Packet.Type.mercury_unsub,
Packet.Type.mercury_event,
Packet.Type.mercury_req,
]:
self.__session.mercury().dispatch(packet)
elif cmd in [Packet.Type.aes_key, Packet.Type.aes_key_error]:
self.__session.audio_key().dispatch(packet)
elif cmd in [
Packet.Type.channel_error, Packet.Type.stream_chunk_res
Packet.Type.channel_error, Packet.Type.stream_chunk_res
]:
self.__session.channel().dispatch(packet)
elif cmd == Packet.Type.product_info:
Expand Down Expand Up @@ -2324,7 +2334,7 @@ def get_token(self, *scopes) -> StoredToken:

def login5(self, scopes: typing.List[str]) -> typing.Union[StoredToken, None]:
"""Submit Login5 request for a fresh access token"""

if self.__session.ap_welcome():
login5_request = Login5.LoginRequest()
login5_request.client_info.client_id = MercuryRequests.keymaster_client_id
Expand All @@ -2341,20 +2351,21 @@ def login5(self, scopes: typing.List[str]) -> typing.Union[StoredToken, None]:
headers=CaseInsensitiveDict({
"Content-Type": "application/x-protobuf",
"Accept": "application/x-protobuf"
}))
}))

if response.status_code == 200:
login5_response = Login5.LoginResponse()
login5_response.ParseFromString(response.content)

if login5_response.HasField('ok'):
self.logger.info("Login5 authentication successful, got access token".format(login5_response.ok.access_token))
self.logger.info(
"Login5 authentication successful, got access token".format(login5_response.ok.access_token))
token = TokenProvider.StoredToken({
"expiresIn": login5_response.ok.access_token_expires_in, # approximately one hour
"expiresIn": login5_response.ok.access_token_expires_in, # approximately one hour
"accessToken": login5_response.ok.access_token,
"scope": scopes
})
return token
return token
else:
self.logger.warning("Login5 authentication failed: {}".format(login5_response.error))
else:
Expand All @@ -2379,7 +2390,7 @@ def expired(self) -> bool:
""" """
return self.timestamp + (self.expires_in - TokenProvider.
token_expire_threshold) * 1000 * 1000 < int(
time.time_ns() / 1000)
time.time_ns() / 1000)

def has_scope(self, scope: str) -> bool:
"""
Expand Down
Loading