diff --git a/pybikes/bicicard.py b/pybikes/bicicard.py index ae0321e65..f19cd8955 100644 --- a/pybikes/bicicard.py +++ b/pybikes/bicicard.py @@ -5,13 +5,11 @@ import json from pybikes import BikeShareSystem, BikeShareStation, PyBikesScraper -from pybikes.contrib import TSTCache + AUTH_URL = '{endpoint}/api/certificado' STATIONS_URL = '{endpoint}/apiapp/SBancada/Estado/TodosSimple' -cache = TSTCache(delta=3600) - def stupidict(thing): """ makes all keys on a dict lowercase. Useful when different endpoints @@ -45,18 +43,23 @@ def auth_url(self): def stations_url(self): return STATIONS_URL.format(endpoint=self.endpoint) + def authorize(self, scraper): + cert = json.loads(scraper.request(self.auth_url, cache_for=3600)) + cert = stupidict(cert) + scraper.headers.update({ + 'Thumbprint': cert['thumbprint'], + }) + return scraper + + def update(self, scraper=None): - headers = { + scraper = scraper or PyBikesScraper() + scraper.headers.update({ 'User-Agent': 'Dalvik/2.1.0 (Linux; U; Android 11; ROCINANTE FIRE Build/9001', 'DeviceModel': 'ROCINANTE FIRE', - } - scraper = scraper or PyBikesScraper(cache) - cert = json.loads(scraper.request(self.auth_url, headers=headers)) - cert = stupidict(cert) - headers.update({ - 'Thumbprint': cert['thumbprint'], }) - data = scraper.request(self.stations_url, headers=headers, skip_cache=True) + self.authorize(scraper) + data = scraper.request(self.stations_url) info = map(stupidict, json.loads(data)) self.stations = list(map(BicicardStation, info)) diff --git a/pybikes/bicimad.py b/pybikes/bicimad.py index 8b46afdb5..0bf37e5fa 100644 --- a/pybikes/bicimad.py +++ b/pybikes/bicimad.py @@ -7,15 +7,12 @@ import json from pybikes import BikeShareSystem, BikeShareStation, PyBikesScraper -from pybikes.contrib import TSTCache COLORS = ['green', 'red', 'yellow', 'gray'] AUTH_URL = 'https://openapi.emtmadrid.es/v2/mobilitylabs/user/login/' FEED_URL = 'https://openapi.emtmadrid.es/v2/transport/bicimad/stations/' -cache = TSTCache(delta=3600) - 
class Bicimad(BikeShareSystem): authed = True @@ -30,31 +27,22 @@ def __init__(self, tag, meta, key): super(Bicimad, self).__init__(tag, meta) self.key = key - @staticmethod - def authorize(scraper, key): - request = scraper.request - + def authorize(self, scraper, key): headers = { 'passkey': key['passkey'], 'x-clientid': key['clientid'], } - accesstoken_content = scraper.request(AUTH_URL, headers=headers) + accesstoken_content = scraper.request(AUTH_URL, headers=headers, cache_for=3600) accesstoken = json.loads(accesstoken_content)['data'][0]['accessToken'] - def _request(*args, **kwargs): - headers = kwargs.get('headers', {}) - headers.update({'accesstoken': accesstoken}) - kwargs['headers'] = headers - return request(*args, **kwargs) - - scraper.request = _request + scraper.headers.update({'accesstoken': accesstoken}) def update(self, scraper=None): - scraper = scraper or PyBikesScraper(cache) + scraper = scraper or PyBikesScraper() - Bicimad.authorize(scraper, self.key) + self.authorize(scraper, self.key) - scraper_content = scraper.request(FEED_URL, skip_cache=True) + scraper_content = scraper.request(FEED_URL) data = json.loads(scraper_content) diff --git a/pybikes/contrib.py b/pybikes/contrib.py index 2a83dc7ec..6a6729418 100644 --- a/pybikes/contrib.py +++ b/pybikes/contrib.py @@ -1,4 +1,15 @@ +import re import time +import inspect + +try: + # Python 2 + from itertools import imap as map +except ImportError: + # Python 3 + pass + +from pybikes import BikeShareSystem, BikeShareStation class TSTCache(dict): @@ -23,10 +34,8 @@ def __setitem__(self, key, value): key = self.__transform_key__(key) if not self.__test_key__(key): return - self.store[key] = { - 'value': value, - 'ts': time.time() - } + + self.store[key] = self.__transform_value__(key, value) def __getitem__(self, key): key = self.__transform_key__(key) @@ -35,8 +44,11 @@ def __getitem__(self, key): if key not in self.store: raise KeyError('%s' % key) ts_value = self.store[key] - if time.time() - 
ts_value['ts'] > self.delta: + the_time = time.time() + + if the_time - ts_value['ts'] > self.__get_delta__(key, ts_value): raise KeyError('%s' % key) + return ts_value['value'] def __contains__(self, key): @@ -53,8 +65,132 @@ def __iter__(self): def __len__(self): return len(self.store) + def get(self, key, default=None): + return self[key] if key in self else default + + def items(self): + return ((k, self[k]) for k, v in self.store.items() if k in self) + + def keys(self): + return (k for k in self.store.keys() if k in self) + def __test_key__(self, key): return True def __transform_key__(self, key): return key + + def __transform_value__(self, key, value): + return { + 'value': value, + 'ts': time.time(), + } + + def __get_delta__(self, key, entry): + return self.delta + + def __repr__(self): + return self.store.__repr__() + + def flush(self): + for k, v in list(self.store.items()): + if k in self: + continue + del self.store[k] + + +class PBCache(TSTCache): + """ PBCache stands for PyBikes Cache + + It's the same as the TSTCache, but annotates entries with callstack + information based on being called from a bike share system + + Gets initialized with a list of defined deltas per regex rule. Said + delta will be applied to entries based on its annotation. 
+ """ + + def __init__(self, * args, ** kwargs): + self.deltas = kwargs.pop('deltas', []) + super(PBCache, self).__init__(* args, ** kwargs) + + def __get_annotation__(self, key): + """ introspect call stack to find a bike share system """ + + def get_frame(entry): + """ python 2 and 3 compatible frame getter """ + if isinstance(entry, tuple): + return entry[0] + else: + return entry.frame + + def get_function(finfo): + """ python 2 and 3 compatible function getter """ + if isinstance(finfo, tuple): + return finfo[3] + else: + return finfo.function + + valid_types = (BikeShareSystem, ) + stack = inspect.stack() + selfs = map(lambda f: (get_frame(f).f_locals.get('self'), f), stack) + bss = filter(lambda f: isinstance(f[0], valid_types), selfs) + + some_bikeshare, frame_info = next(iter(bss), (None, None)) + + # no bike share found on call stack, bail + if not some_bikeshare: + return None + + # create an annotation based on bike share found + # ie: 'gbfs::citi-bike-nyc::update::https://some-url' + annotation = '{cls}::{tag}::{method}::{key}'.format( + cls=some_bikeshare.__class__.__name__.lower(), + tag=some_bikeshare.tag, + method=get_function(frame_info), + key=key, + ) + + return annotation + + def __match_delta__(self, key): + annotation = self.__get_annotation__(key) + + if not annotation: + return None, None + + # get a delta value based on annotation + # list of deltas are a list of dicts like + # - it's a list because it keeps order + # - it's made of dicts because it's a safe json structure + # [ + # {'gbfs::.*::update': 100}, + # {'gbfs::some-tag::update::some-url': 200}, + # ] + + # iterate items on delta list + deltas = map(lambda e: e.items(), self.deltas) + # flatten iterator + deltas = (e for it in deltas for e in it) + apply_rules = filter(lambda r: re.match(r[0], annotation), deltas) + _, delta = next(iter(apply_rules), (None, self.delta)) + + return delta, annotation + + def __transform_value__(self, key, value): + delta, annotation = 
self.__match_delta__(key) + return { + 'value': value, + 'ts': time.time(), + 'delta': delta, + 'annotation': annotation, + } + + def __get_delta__(self, key, entry): + delta = entry.get('delta') + # guards against a delta = 0 triggering return of self.delta + return delta if delta is not None else self.delta + + def set_with_delta(self, key, value, delta): + entry = self.__transform_value__(key, value) + entry['delta'] = delta + self.store[key] = entry diff --git a/pybikes/deutschebahn.py b/pybikes/deutschebahn.py index f94026251..f7e894130 100644 --- a/pybikes/deutschebahn.py +++ b/pybikes/deutschebahn.py @@ -5,7 +5,6 @@ from pybikes import PyBikesScraper from pybikes.gbfs import Gbfs -from pybikes.contrib import TSTCache FEED_URL = 'https://apis.deutschebahn.com/db-api-marketplace/apis/shared-mobility-gbfs/2-2/de/{provider}/gbfs' @@ -52,12 +51,9 @@ class Callabike(DB): provider = 'CallABike' - # caches the feed for 60s - cache = TSTCache(delta=60) - def __init__(self, * args, ** kwargs): super(Callabike, self).__init__(* args, provider=Callabike.provider, ** kwargs) def update(self, scraper=None): - scraper = scraper or PyBikesScraper(self.cache) + scraper = scraper or PyBikesScraper() super(Callabike, self).update(scraper) diff --git a/pybikes/gbfs.py b/pybikes/gbfs.py index 1110a5c10..7d6e2bba5 100644 --- a/pybikes/gbfs.py +++ b/pybikes/gbfs.py @@ -112,6 +112,22 @@ def get_feeds(self, url, scraper, force_https): return {feed['name']: feed['url'] for feed in feeds} + # We use these dumb functions to mark requests for caching with a + # 'gbfs::tag::get_station_information::url' signature + def get_station_information(self, scraper, feeds): + return json.loads( + scraper.request(feeds['station_information']) + )['data']['stations'] + + def get_station_status(self, scraper, feeds): + return json.loads( + scraper.request(feeds['station_status']) + )['data']['stations'] + + def get_vehicle_types(self, scraper, feeds): + return json.loads( + 
scraper.request(feeds['vehicle_types']) + )['data']['vehicle_types'] def update(self, scraper=None): scraper = scraper or PyBikesScraper() @@ -122,21 +138,17 @@ def update(self, scraper=None): feeds = self.get_feeds(self.feed_url, scraper, self.force_https) # Station Information and Station Status data retrieval - station_information = json.loads( - scraper.request(feeds['station_information']) - )['data']['stations'] - station_status = json.loads( - scraper.request(feeds['station_status']) - )['data']['stations'] + station_information = self.get_station_information(scraper, feeds) + station_status = self.get_station_status(scraper, feeds) if 'vehicle_types' in feeds: - vehicle_info = json.loads(scraper.request(feeds['vehicle_types'])) + vehicle_info = self.get_vehicle_types(scraper, feeds) # map vehicle id to vehicle info AND extra info resolver # for direct access vehicles = { # TODO: ungrok this line v.get('vehicle_type_id', 'err'): (v, next(iter((r for q, r in self.vehicle_taxonomy if q(v))), lambda v: {})) - for v in vehicle_info['data'].get('vehicle_types', []) + for v in vehicle_info } else: vehicles = {} diff --git a/pybikes/nextbike.py b/pybikes/nextbike.py index b1c88f393..7dfc5e14e 100644 --- a/pybikes/nextbike.py +++ b/pybikes/nextbike.py @@ -8,16 +8,10 @@ from .base import BikeShareSystem, BikeShareStation from pybikes.utils import PyBikesScraper, filter_bounds -from pybikes.contrib import TSTCache -__all__ = ['Nextbike', 'NextbikeStation'] BASE_URL = 'https://{hostname}/maps/nextbike-live.xml?domains={domain}' # NOQA -# Since most networks share the same hostname, there's no need to keep hitting -# the endpoint on the same urls. 
This caches the feed for 60s -cache = TSTCache(delta=60) - class Nextbike(BikeShareSystem): sync = True @@ -38,7 +32,7 @@ def __init__(self, tag, meta, domain, city_uid, hostname='maps.nextbike.net', def update(self, scraper=None): if scraper is None: - scraper = PyBikesScraper(cache) + scraper = PyBikesScraper() domain_xml = etree.fromstring( scraper.request(self.url).encode('utf-8') ) diff --git a/pybikes/publibike.py b/pybikes/publibike.py index 339c98e64..bf9a5e803 100644 --- a/pybikes/publibike.py +++ b/pybikes/publibike.py @@ -6,13 +6,9 @@ import json from pybikes import BikeShareSystem, BikeShareStation, PyBikesScraper -from pybikes.contrib import TSTCache FEED_URL = 'https://api.publibike.ch/v1/public/partner/stations' -# caches the feed for 60s -cache = TSTCache(delta=60) - class Publibike(BikeShareSystem): sync = True @@ -29,9 +25,7 @@ def __init__(self, tag, meta, city_uid): self.uid = city_uid def update(self, scraper=None): - if scraper is None: - # use cached feed if possible - scraper = PyBikesScraper(cache) + scraper = scraper or PyBikesScraper() stations = json.loads( scraper.request(FEED_URL).encode('utf-8') diff --git a/pybikes/utils.py b/pybikes/utils.py index 68890aab3..d03afe7cc 100644 --- a/pybikes/utils.py +++ b/pybikes/utils.py @@ -3,7 +3,6 @@ # Distributed under the AGPL license, see LICENSE.txt import os -import re try: # Python 2 from itertools import imap as map @@ -16,6 +15,7 @@ from shapely.geometry import Polygon, Point, box from pybikes.base import BikeShareStation +from pybikes.contrib import PBCache def str2bool(v): @@ -45,19 +45,26 @@ class PyBikesScraper(object): requests_timeout = 300 retry = False retry_opts = {} + use_cache = False + cache_statuses = [200, 203, 204, 206, 300, 301, 404, 405, 410, 414, 501] def __init__(self, cachedict=None, headers=None): self.headers = headers if isinstance(headers, dict) else {} self.headers.setdefault('User-Agent', 'PyBikes') self.proxies = {} self.session = requests.session() - 
self.cachedict = cachedict + + # Implicit enable cache if we got a cache as argument + self.use_cache = cachedict is not None + # Always have a cache with delta 0 in hand for explicit caching + self.cachedict = cachedict if cachedict is not None else PBCache(delta=0) def setUserAgent(self, user_agent): self.headers['User-Agent'] = user_agent def request(self, url, method='GET', params=None, data=None, raw=False, - headers=None, default_encoding='UTF-8', skip_cache=False): + headers=None, default_encoding='UTF-8', skip_cache=False, + cache_for=None): if self.retry: retries = Retry(** self.retry_opts) @@ -66,10 +73,14 @@ def request(self, url, method='GET', params=None, data=None, raw=False, _headers = self.headers.copy() _headers.update(headers or {}) + response = None + must_cache = self.use_cache and not skip_cache or cache_for + # XXX proper encode arguments for proper call args -> response - if self.cachedict and url in self.cachedict and not skip_cache: - response = self.cachedict[url] - else: + if must_cache: + response = self.cachedict.get(url) + + if response is None: response = self.session.request( method=method, url=url, @@ -83,6 +94,13 @@ def request(self, url, method='GET', params=None, data=None, raw=False, timeout=self.requests_timeout, ) + if must_cache and response.status_code in self.cache_statuses: + # quack + if cache_for and hasattr(self.cachedict, 'set_with_delta'): + self.cachedict.set_with_delta(url, response, cache_for) + else: + self.cachedict[url] = response + data = response.text # Somehow requests defaults to ISO-8859-1 (when no encoding @@ -97,10 +115,8 @@ def request(self, url, method='GET', params=None, data=None, raw=False, if 'set-cookie' in response.headers: self.headers['Cookie'] = response.headers['set-cookie'] - self.last_request = response - if self.cachedict is not None: - self.cachedict[url] = response + self.last_request = response return data diff --git a/tests/test_contrib.py b/tests/test_contrib.py new file mode 100644 index 
000000000..98c4517b1 --- /dev/null +++ b/tests/test_contrib.py @@ -0,0 +1,262 @@ +import pytest + +try: + # python 3 + from unittest.mock import patch +except ImportError: + # python 2 + from mock import patch + + +from pybikes import BikeShareSystem +from pybikes.contrib import TSTCache, PBCache + + +# Acts as time module, with some extra wizz +class FakeTime: + + the_time = 0.0 + + def time(self): + return self.the_time + + def travel(self, delta): + self.the_time += delta + + +class TestTSTCache: + + @pytest.fixture() + def cache(self): + self._store = {} + return TSTCache(self._store, 60) + + @pytest.fixture() + def store(self, cache): + return self._store + + @pytest.fixture() + def time_machine(self): + it = FakeTime() + with patch('time.time', it.time): + yield it + + def test_set_get(self, cache, time_machine): + # time is 0.0 + cache['foo'] = 'bar' + assert cache['foo'] == 'bar' + + # time is 30.0 + time_machine.travel(30) + assert cache['foo'] == 'bar' + cache['bar'] = 'baz' + assert cache['bar'] == 'baz' + + # time is 70.0 + time_machine.travel(40) + assert 'foo' not in cache + with pytest.raises(KeyError): + cache['foo'] + + assert cache['bar'] == 'baz' + + # time is 100.0 + time_machine.travel(30) + assert 'bar' not in cache + with pytest.raises(KeyError): + cache['bar'] + + + +class FooBikeShare(BikeShareSystem): + domain = 'foo.com' + + def update(self, scraper): + scraper.request('https://%s/%s' % (self.domain, self.tag)) + + +class BarBikeShare(FooBikeShare): + domain = 'bar.com' + + def auth(self, scraper): + scraper.request('https://%s/auth' % (self.domain)) + + def update(self, scraper): + self.auth(scraper) + scraper.request('https://%s/%s' % (self.domain, self.tag)) + + +class BazBikeShare(FooBikeShare): + domain = 'baz.com' + + +class FuzzBikeShare(FooBikeShare): + domain = 'fuzz.com' + + +class TestPBCache(TestTSTCache): + @pytest.fixture() + def cache(self, deltas): + self._store = {} + return PBCache(self._store, 60, deltas=deltas) + + 
@pytest.fixture() + def deltas(self): + return [ + {'foobikeshare::foo-corp-haven::.*': 10}, + {'foobikeshare::foo-end-city::.*': 20}, + {'foobikeshare::.*': 30}, + {'barbikeshare::.*::auth::.*': 100}, + {'barbikeshare::bar-mad-isle::update::.*': 40}, + {'barbikeshare::bar-neo-troy::update::.*': 60}, + {'.*::.*::.*::https://baz.com': 120}, + ] + + @pytest.fixture() + def instances(self): + return [ + FooBikeShare('foo-corp-haven', meta={'name': 'Foo Corp Haven', 'system': 'foo'}), + FooBikeShare('foo-end-city', meta={'name': 'Foo End City', 'system': 'foo'}), + FooBikeShare('foo-devil-hold', meta={'name': 'Foo Devil Hold', 'system': 'foo'}), + + BarBikeShare('bar-mad-isle', meta={'name': 'Bar Mad Isle', 'system': 'bar'}), + BarBikeShare('bar-neo-troy', meta={'name': 'Bar Neo Troy', 'system': 'bar'}), + + BazBikeShare('baz-greed-city', meta={'name': 'Baz Greed City', 'system': 'baz'}), + BazBikeShare('baz-droid-town', meta={'name': 'Baz Droid Town', 'system': 'baz'}), + + FuzzBikeShare('fuzz-neo-titania', meta={'name': 'Fuzz Neo Titania', 'system': 'fuzz'}), + FuzzBikeShare('fuzz-chaos-trail', meta={'name': 'Fuzz Chaos Trail', 'system': 'fuzz'}), + ] + + @pytest.fixture() + def instance_map(self, instances): + return {i.tag: i for i in instances} + + @pytest.fixture() + def scraper(self, cache): + class Scraper: + hits = 0 + miss = 0 + + def request(self, url, * args, ** kwargs): + if url in cache: + self.hits += 1 + return cache[url] + + self.miss += 1 + + data = 'Some request data' + cache[url] = data + return data + + def reset(self): + self.hits = 0 + self.miss = 0 + + return Scraper() + + def test_deltas(self, cache, scraper, instances, time_machine): + # time is 0 + [i.update(scraper) for i in instances] + assert scraper.hits == 1 + assert scraper.miss == 10 + + scraper.reset() + [i.update(scraper) for i in instances] + assert scraper.hits == 11 + assert scraper.miss == 0 + + scraper.reset() + time_machine.travel(1000) + + [i.update(scraper) for i in 
instances] + assert scraper.hits == 1 + assert scraper.miss == 10 + + def test_tag_delta(self, cache, scraper, instances, time_machine): + # time is 0 + [i.update(scraper) for i in instances if 'foo' in i.tag] + assert scraper.miss == 3 + + time_machine.travel(15) + scraper.reset() + + [i.update(scraper) for i in instances if 'foo' in i.tag] + assert scraper.hits == 2 + assert scraper.miss == 1 + + time_machine.travel(15) + scraper.reset() + + [i.update(scraper) for i in instances if 'foo' in i.tag] + assert scraper.hits == 1 + assert scraper.miss == 2 + + def test_method_delta(self, cache, scraper, instances, time_machine): + # time is 0 + [i.update(scraper) for i in instances if 'bar' in i.tag] + assert scraper.miss == 3 + assert scraper.hits == 1 + + time_machine.travel(90) + scraper.reset() + + [i.update(scraper) for i in instances if 'bar' in i.tag] + assert scraper.miss == 2 + assert scraper.hits == 2 + + def test_url_delta(self, cache, scraper, instances, time_machine): + # time is 0 + [i.update(scraper) for i in instances if 'baz' in i.tag] + assert scraper.miss == 2 + assert scraper.hits == 0 + + time_machine.travel(100) + scraper.reset() + + [i.update(scraper) for i in instances if 'baz' in i.tag] + assert scraper.miss == 0 + assert scraper.hits == 2 + + time_machine.travel(100) + scraper.reset() + + [i.update(scraper) for i in instances if 'baz' in i.tag] + assert scraper.miss == 2 + assert scraper.hits == 0 + + def test_default_delta(self, cache, scraper, instances, time_machine): + # time is 0 + [i.update(scraper) for i in instances if 'fuzz' in i.tag] + assert scraper.miss == 2 + assert scraper.hits == 0 + + time_machine.travel(30) + scraper.reset() + + [i.update(scraper) for i in instances if 'fuzz' in i.tag] + assert scraper.miss == 0 + assert scraper.hits == 2 + + time_machine.travel(60) + scraper.reset() + + [i.update(scraper) for i in instances if 'fuzz' in i.tag] + assert scraper.miss == 2 + assert scraper.hits == 0 + + def test_with_delta(self, 
cache, scraper, time_machine): + cache.set_with_delta('foo', 'Foobar', 30) + assert cache['foo'] == 'Foobar' + + time_machine.travel(50) + + with pytest.raises(KeyError): + cache['foo'] + + def test_delta_zero(self, cache): + assert cache.__get_delta__('', {'delta': 0}) == 0 + + def test_default_delta_if_None(self, cache): + assert cache.__get_delta__('', {'delta': None}) == cache.delta diff --git a/tests/test_utils.py b/tests/test_utils.py index a1cb58f29..bdc1a925e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,17 @@ -from pybikes import BikeShareStation +try: + # python 3 + from unittest.mock import patch, Mock +except ImportError: + # python 2 + from mock import patch, Mock + +import pytest +import requests + +from pybikes import BikeShareStation, PyBikesScraper from pybikes.utils import filter_bounds + def test_filter_bounds(): """ Tests that filter_bounds utils function correctly filters stations out of a given number of bounds. Function must accept multiple lists @@ -76,3 +87,147 @@ def test_filter_bounds_with_key(): ) assert in_bounds == list(result) + + +class TestPyBikesScraper: + + class FakeSession(Mock): + response_data = { + 'headers': {}, + 'status_code': 200, + 'text': 'hi' + } + + def request(self, * args, ** kwargs): + r = Mock(requests.Request, * args, ** kwargs) + return Mock(requests.Response, request=r, ** self.response_data) + + @pytest.fixture() + def fake_session(self): + session = TestPyBikesScraper.FakeSession() + session.request = Mock(side_effect=session.request) + with patch('requests.sessions.Session.request', session.request): + yield session + + def test_default_useragent(self, fake_session): + scraper = PyBikesScraper() + scraper.request('https://citybik.es') + req = scraper.last_request.request + assert req.headers['User-Agent'] == 'PyBikes' + + def test_base_headers(self, fake_session): + headers = { + 'Hello-World': 42, + 'Foo': 'Bar', + } + scraper = PyBikesScraper(headers=headers) + 
scraper.request('https://citybik.es') + req = scraper.last_request.request + + assert req.headers == headers + + def test_req_headers(self, fake_session): + headers = { + 'Hello-World': 42, + } + + req_headers = { + 'Foo': 'Bar', + 'Hello-World': 45, + } + + scraper = PyBikesScraper(headers=headers) + scraper.request('https://citybik.es', headers=req_headers) + req = scraper.last_request.request + + assert req.headers == dict(req.headers, ** req_headers) + # checks that original headers are unaffected + assert headers != req_headers + + # next request uses base headers + scraper.request('https://citybik.es') + req = scraper.last_request.request + assert req.headers == headers + + def test_set_cookie(self, fake_session): + scraper = PyBikesScraper() + fake_session.response_data['headers']['set-cookie'] = 'Hello' + scraper.request('https://citybik.es') + assert scraper.headers['Cookie'] == 'Hello' + + def test_cache_disabled(self, fake_session): + scraper = PyBikesScraper() + scraper.request('https://citybik.es') + + def test_uses_cache_if_provided(self, fake_session): + cache = {} + scraper = PyBikesScraper(cache) + scraper.request('https://citybik.es') + assert 'https://citybik.es' in cache + assert fake_session.request.called + + fake_session.request.reset_mock() + + scraper.request('https://citybik.es') + assert not fake_session.request.called + + def test_skip_cache(self, fake_session): + cache = {} + scraper = PyBikesScraper(cache) + scraper.request('https://citybik.es') + assert 'https://citybik.es' in cache + + fake_session.request.reset_mock() + + scraper.request('https://citybik.es', skip_cache=True) + assert fake_session.request.called + + def test_disable_cache(self, fake_session): + cache = {} + scraper = PyBikesScraper(cache) + scraper.request('https://citybik.es') + assert 'https://citybik.es' in cache + + fake_session.request.reset_mock() + + scraper.use_cache = False + scraper.request('https://citybik.es') + assert fake_session.request.called + + 
fake_session.request.reset_mock() + + scraper.use_cache = True + scraper.request('https://citybik.es') + assert not fake_session.request.called + + def test_cache_statuses(self, fake_session): + cache = {} + scraper = PyBikesScraper(cache) + fake_session.response_data['status_code'] = 500 + scraper.request('https://citybik.es') + + assert 'https://citybik.es' not in cache + + fake_session.request.reset_mock() + fake_session.response_data['status_code'] = 200 + + scraper.request('https://citybik.es') + assert 'https://citybik.es' in cache + + scraper.request('https://citybik.es') + scraper.request('https://citybik.es') + scraper.request('https://citybik.es') + scraper.request('https://citybik.es') + + assert fake_session.request.call_count == 1 + + def test_forwards_cache_for(self, fake_session): + scraper = PyBikesScraper() + mock = Mock() + + with patch('pybikes.contrib.PBCache.set_with_delta', mock): + scraper.request('https://citybik.es', cache_for=3600) + + response = scraper.last_request + + mock.assert_called_with('https://citybik.es', response, 3600)