diff --git a/splunk_handler/__init__.py b/splunk_handler/__init__.py index bdb9b14..ab916fc 100644 --- a/splunk_handler/__init__.py +++ b/splunk_handler/__init__.py @@ -51,7 +51,8 @@ def __init__(self, host, port, token, index, force_keep_ahead=False, hostname=None, protocol='https', proxies=None, queue_size=DEFAULT_QUEUE_SIZE, record_format=False, retry_backoff=2.0, retry_count=5, source=None, - sourcetype='text', timeout=60, url=None, verify=True): + sourcetype='text', timeout=60, url=None, verify=True, + additional_headers=None): """ Args: host (str): The Splunk host param @@ -74,6 +75,7 @@ def __init__(self, host, port, token, index, timeout (float): The time to wait for a response from Splunk url (str): Override of the url to send the event to verify (bool): Whether to perform ssl certificate validation + additional_headers (dict): The headers to add in Splunk request """ global instances @@ -105,6 +107,7 @@ def __init__(self, host, port, token, index, self.proxies = proxies self.record_format = record_format self.processing_payload = False + self.additional_headers = additional_headers if not url: self.url = '%s://%s:%s/services/collector/event' % (self.protocol, self.host, self.port) else: @@ -256,12 +259,16 @@ def _splunk_worker(self, payload=None): self.write_debug_log("Payload available for sending") self.write_debug_log("Destination URL is " + self.url) + headers = {'Authorization': "Splunk %s" % self.token} + if self.additional_headers is not None: + headers.update(**self.additional_headers) + try: self.write_debug_log("Sending payload: " + payload) r = self.session.post( self.url, data=payload, - headers={'Authorization': "Splunk %s" % self.token}, + headers=headers, verify=self.verify, timeout=self.timeout ) diff --git a/tests/test_splunk_handler.py b/tests/test_splunk_handler.py index 51232a5..e4ef13d 100644 --- a/tests/test_splunk_handler.py +++ b/tests/test_splunk_handler.py @@ -22,6 +22,7 @@ SPLUNK_DEBUG = False SPLUNK_RETRY_COUNT = 1 SPLUNK_RETRY_BACKOFF = 0.1 +SPLUNK_ADDITIONAL_HEADERS = {"test": "header"} RECEIVER_URL = 'https://%s:%s/services/collector/event' % (SPLUNK_HOST, SPLUNK_PORT) @@ -45,6 +46,7 @@ def setUp(self): debug=SPLUNK_DEBUG, retry_count=SPLUNK_RETRY_COUNT, retry_backoff=SPLUNK_RETRY_BACKOFF, + additional_headers=SPLUNK_ADDITIONAL_HEADERS, ) def tearDown(self): @@ -68,6 +70,7 @@ def test_init(self): self.assertEqual(self.splunk.debug, SPLUNK_DEBUG) self.assertEqual(self.splunk.retry_count, SPLUNK_RETRY_COUNT) self.assertEqual(self.splunk.retry_backoff, SPLUNK_RETRY_BACKOFF) + self.assertTrue(self.splunk.additional_headers, SPLUNK_ADDITIONAL_HEADERS) self.assertFalse(logging.getLogger('requests').propagate) self.assertFalse(logging.getLogger('splunk_handler').propagate) @@ -96,7 +99,7 @@ def test_splunk_worker(self): verify=SPLUNK_VERIFY, data=expected_output, timeout=SPLUNK_TIMEOUT, - headers={'Authorization': "Splunk %s" % SPLUNK_TOKEN}, + headers={'Authorization': "Splunk %s" % SPLUNK_TOKEN, **SPLUNK_ADDITIONAL_HEADERS}, ) def test_splunk_worker_override(self): @@ -123,7 +126,7 @@ def test_splunk_worker_override(self): self.mock_request.assert_called_once_with( RECEIVER_URL, data=expected_output, - headers={'Authorization': "Splunk %s" % SPLUNK_TOKEN}, + headers={'Authorization': "Splunk %s" % SPLUNK_TOKEN, **SPLUNK_ADDITIONAL_HEADERS}, verify=SPLUNK_VERIFY, timeout=SPLUNK_TIMEOUT ) @@ -184,7 +187,7 @@ def test_wait_until_empty_and_keep_ahead(self): self.mock_request.assert_has_calls([call( RECEIVER_URL, data=expected_output * 10, - headers={'Authorization': 'Splunk %s' % SPLUNK_TOKEN}, + headers={'Authorization': "Splunk %s" % SPLUNK_TOKEN, **SPLUNK_ADDITIONAL_HEADERS}, verify=SPLUNK_VERIFY, timeout=SPLUNK_TIMEOUT )] * 2, any_order=True)