diff --git a/README.md b/README.md index 82a68961..8889f94b 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,13 @@ Project will work on both python 2 and python 3 buy_order = my_trader.place_buy_order(stock_instrument, 1) sell_order = my_trader.place_sell_order(stock_instrument, 1) + # save the token so you don't need user/password again + # should probably encrypt this if saving to a file + token = my_trader.token + + # load from token rather than login + my_trader.token = token + ### Data returned * Quote data + Ask Price diff --git a/Robinhood/Robinhood.py b/Robinhood/Robinhood.py index 22308f6e..1720b4f3 100644 --- a/Robinhood/Robinhood.py +++ b/Robinhood/Robinhood.py @@ -2,6 +2,7 @@ import getpass import logging import warnings +from contextlib import suppress from enum import Enum import requests @@ -11,6 +12,7 @@ from six.moves import input from . import exceptions as RH_exception + class Robinhood: """wrapper class for fetching/parsing Robinhood endpoints""" endpoints = { @@ -52,7 +54,7 @@ class Robinhood: headers = None - auth_token = None + _auth_token = None logger = logging.getLogger('Robinhood') logger.addHandler(logging.NullHandler()) @@ -167,8 +169,7 @@ def login( raise RH_exception.TwoFactorRequired() #requires a second call to enable 2FA if 'token' in data.keys(): - self.auth_token = data['token'] - self.headers['Authorization'] = 'Token ' + self.auth_token + self.token = data['token'] return True return False @@ -187,10 +188,19 @@ def logout(self): warnings.warn('Failed to log out ' + repr(err_msg)) self.headers['Authorization'] = None - self.auth_token = None + self._auth_token = None return res + @property + def token(self): + return self._auth_token + + @token.setter + def token(self, token): + self._auth_token = token + self.headers['Authorization'] = 'Token ' + self._auth_token + ############################## #GET DATA ############################## @@ -198,7 +208,7 @@ def logout(self): def investment_profile(self): """fetch investment_profile""" res = self.session.get(self.endpoints['investment_profile']) - res.raise_for_status() #will throw without auth + res.raise_for_status() # will throw without auth data = res.json() return data @@ -208,7 +218,8 @@ def instruments(self, query=None, symbol=None, instrumentid=None): Args: query (str): search for ticker, e.g. by company name symbol (str): find instrument by it's symbol - instrumentid (str): instrumentid [uuid without the rest of URL] + instrumentid (str): instrumentid [uuid with or without the rest of + URL] Returns: (:obj:`dict`): JSON contents from `instruments` endpoint @@ -216,6 +227,9 @@ def instruments(self, query=None, symbol=None, instrumentid=None): """ res = None if instrumentid: + with suppress(IndexError): + # allow for url rather than making the user strip it + instrumentid = instrumentid.strip('/').split('/')[-1] res = self.session.get( self.endpoints['instrumentid'].format(instrumentid=instrumentid) ) @@ -231,14 +245,18 @@ def instruments(self, query=None, symbol=None, instrumentid=None): ) res.raise_for_status() res = res.json() - # if requesting all, return entire object so may paginate with ['next'] # Not sure variable returns types here is the best approach.. # API doesn't return pagination though when query is non-empty query=a ?? #if query is None and not (symbol or instrumentid): # return res # XXX perhaps should return an iterable to hide the pagination, e.g. res['next'], res['previous'] aspects - return res['results'] + try: + return res['results'] + except KeyError: + # res can either contain a list of dicts or a single dict + # return it as a list of one dict rather than variable return types + return [res] def instrument_splits(self, instrumentid=None): """fetch instruments splits endpoint @@ -946,7 +964,6 @@ def cancel_order(self,oid): res.raise_for_status() return res - def cancel_orders_all(self,instrument=None): """ convenience function to cancel all orders, optionally only for a given instrument