Skip to content
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,10 @@ internal class AlertingPlugin : PainlessExtension, ActionPlugin, ScriptPlugin, R
MULTI_TENANCY_ENABLED.get(settings),
if (providerType.isNotEmpty()) JobQueueAccountIdProvider.find(providerType, settings) else null,
REMOTE_METADATA_REGION.get(settings) ?: "",
AlertingSettings.JOB_QUEUE_NAME.get(settings) ?: ""
AlertingSettings.JOB_QUEUE_NAME.get(settings) ?: "",
AlertingSettings.TARGET_TYPE_TO_SERVICE_NAME.get(settings).let {
it.keySet().associateWith { key -> it.get(key) }
}
)

ExternalSchedulerService.initialize(settings)
Expand Down Expand Up @@ -491,7 +494,8 @@ internal class AlertingPlugin : PainlessExtension, ActionPlugin, ScriptPlugin, R
AlertingSettings.JOB_QUEUE_MESSAGE_GROUP_KEY_NAME,
AlertingSettings.EXTERNAL_SCHEDULER_ROLE_ARN,
AlertingSettings.JOB_QUEUE_ACCOUNT_ID,
AlertingSettings.JOB_QUEUE_ACCOUNT_PROVIDER_TYPE
AlertingSettings.JOB_QUEUE_ACCOUNT_PROVIDER_TYPE,
AlertingSettings.TARGET_TYPE_TO_SERVICE_NAME
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import org.opensearch.common.xcontent.LoggingDeprecationHandler
import org.opensearch.common.xcontent.XContentType
import org.opensearch.commons.alerting.model.Monitor
import org.opensearch.commons.alerting.model.ScheduleJobPayload
import org.opensearch.commons.alerting.model.Target
import org.opensearch.commons.alerting.util.AlertingException
import org.opensearch.commons.utils.scheduler.JobQueueAccountIdProvider
import org.opensearch.core.xcontent.NamedXContentRegistry
Expand All @@ -48,7 +49,8 @@ class MonitorJobPoller(
private val enabled: Boolean,
private val accountIdProvider: JobQueueAccountIdProvider?,
private val region: String,
private val queueName: String
private val queueName: String,
private val targetTypeToServiceName: Map<String, String>
) : AbstractLifecycleComponent() {

private val logger = LogManager.getLogger(MonitorJobPoller::class.java)
Expand All @@ -65,6 +67,7 @@ class MonitorJobPoller(
}
val provider = requireNotNull(accountIdProvider) { "accountIdProvider must be set before starting" }
val sqs = requireNotNull(sqsClient) { "sqsClient must be set before starting" }
require(region.isNotBlank()) { "region must be set before starting" }

logger.info("Starting MonitorJobPoller with $POLLER_THREAD_COUNT workers")
repeat(POLLER_THREAD_COUNT) { scope.launch { pollLoop(provider, sqs, region, queueName) } }
Expand Down Expand Up @@ -134,6 +137,10 @@ class MonitorJobPoller(
}

private suspend fun executeMonitor(monitor: Monitor, jobStartTime: Instant) {
// populate thread context for downstream request interception the moment
// Monitor config is in hand
populateThreadContext(monitor.target)

val request = ExecuteMonitorRequest(
dryrun = false,
requestEnd = TimeValue(jobStartTime.toEpochMilli()),
Expand Down Expand Up @@ -180,8 +187,66 @@ class MonitorJobPoller(
}
}

// populates thread context with KVs that downstream interception will
// need when intercepting search or PPL calls to external customer
// data source
internal fun populateThreadContext(target: Target?) {
if (target == null) {
throw AlertingException.wrap(
IllegalStateException("Monitor received by Job Poller did not contain target")
)
}

if (target.type.isBlank()) {
throw AlertingException.wrap(
IllegalStateException("Monitor target received by Job Poller did not contain target type")
)
}

if (target.endpoint.isBlank()) {
throw AlertingException.wrap(
IllegalStateException("Monitor target received by Job Poller did not contain endpoint")
)
}

val threadContext = client.threadPool().threadContext

// Request interception checks for this flag to know that this is
// a scheduled background monitor execution, meaning there will be
// no user credentials to make the search/ppl call to customer
// data source with, and it must use service credentials
threadContext.putHeader(IS_BACKGROUND_JOB_HEADER, "true")

threadContext.putHeader(SERVICE_NAME_HEADER, mapTargetTypeToServiceName(target.type))

// external customer data source endpoint, to run search/ppl against
threadContext.putHeader(OPENSEARCH_ENDPOINT_HEADER, target.endpoint)

// populated upstream in AlertingPlugin.kt with REMOTE_METADATA_REGION.get(settings)
threadContext.putHeader(REGION_HEADER, region)
}

private fun mapTargetTypeToServiceName(targetType: String): String {
Comment thread
eirsep marked this conversation as resolved.
if (!targetTypeToServiceName.containsKey(targetType)) {
throw AlertingException.wrap(
IllegalStateException(
"Received invalid target type in Job Poller: " + targetType +
", expected one of: " + targetTypeToServiceName.keys
)
)
}

return targetTypeToServiceName[targetType]!!
}

companion object {
const val POLLER_THREAD_COUNT = 10
const val POLL_INTERVAL_MS = 1000L

// thread context header keys for request interception
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I think we should follow a consistent naming convention for all these headers but needs to be discussed this across team

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These constants are actually existing header names from other code owners that we are reusing. The only new and original const here is IS_BACKGROUND_JOB_HEADER = "alerting-is-background-job"

const val IS_BACKGROUND_JOB_HEADER = "is-observability-bg-job"
const val SERVICE_NAME_HEADER = "aws-service-name"
const val OPENSEARCH_ENDPOINT_HEADER = "opensearch-url"
const val REGION_HEADER = "aws-region"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -404,5 +404,12 @@ class AlertingSettings {
"plugins.alerting.external_scheduler.job_queue_message_group_key_name",
Setting.Property.NodeScope, Setting.Property.Dynamic
)

/** Mappings from Monitor target type to opensearch service name, used in MonitorJobPoller
* to populate thread context with required Monitor target information */
val TARGET_TYPE_TO_SERVICE_NAME = Setting.groupSetting(
"plugins.alerting.monitor.target_type_to_service_name.",
Setting.Property.NodeScope, Setting.Property.Dynamic
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ package org.opensearch.alerting.service
import com.carrotsearch.randomizedtesting.ThreadFilter
import com.carrotsearch.randomizedtesting.annotations.ThreadLeakFilters
import org.mockito.Mockito.mock
import org.mockito.Mockito.`when`
import org.opensearch.common.settings.Settings
import org.opensearch.commons.alerting.model.Monitor
import org.opensearch.commons.alerting.model.SearchInput
import org.opensearch.commons.alerting.model.Target
import org.opensearch.commons.utils.scheduler.JobQueueAccountIdProvider
import org.opensearch.core.xcontent.NamedXContentRegistry
import org.opensearch.search.SearchModule
Expand Down Expand Up @@ -56,6 +58,13 @@ class MonitorJobPollerTests : OpenSearchTestCase() {
}
}

private fun mappingProvider(): Map<String, String> {
return mapOf(
"target_1" to "service_1",
"target_2" to "service_2"
)
}

private fun validMessageBody(): String {
val monitorConfig = "{\"type\":\"monitor\",\"name\":\"test\"," +
"\"monitor_type\":\"query_level_monitor\",\"enabled\":true," +
Expand All @@ -75,7 +84,8 @@ class MonitorJobPollerTests : OpenSearchTestCase() {
): MonitorJobPoller {
return MonitorJobPoller(
testXContentRegistry(), mockClient(), enabled,
testAccountIdProvider(), "us-west-2", "test-queue"
testAccountIdProvider(), "us-west-2", "test-queue",
mappingProvider()
).also { it.sqsClient = sqsClient }
}

Expand All @@ -101,7 +111,8 @@ class MonitorJobPollerTests : OpenSearchTestCase() {
val sqsClient = FakeSqsClient()
val poller = MonitorJobPoller(
testXContentRegistry(), mockClient(), true,
testAccountIdProvider(), "us-west-2", "test-queue"
testAccountIdProvider(), "us-west-2", "test-queue",
mappingProvider()
).also { it.sqsClient = sqsClient }
poller.start()
Thread.sleep(100)
Expand All @@ -119,7 +130,7 @@ class MonitorJobPollerTests : OpenSearchTestCase() {
)
val poller = MonitorJobPoller(
testXContentRegistry(), mockClient(), false,
null, "", ""
null, "", "", mappingProvider()
)
poller.start()
// Should NOT poll since disabled
Expand All @@ -131,7 +142,19 @@ class MonitorJobPollerTests : OpenSearchTestCase() {
fun `test start throws when provider not set`() {
val poller = MonitorJobPoller(
testXContentRegistry(), mockClient(), true,
null, "us-west-2", "test-queue"
null, "us-west-2", "test-queue", mappingProvider()
)
expectThrows(Exception::class.java) {
poller.start()
}
poller.close()
}

fun `test start throws when region not set`() {
val poller = MonitorJobPoller(
testXContentRegistry(), mockClient(), true,
testAccountIdProvider(), "", "test-queue",
mappingProvider()
)
expectThrows(Exception::class.java) {
poller.start()
Expand Down Expand Up @@ -169,7 +192,8 @@ class MonitorJobPollerTests : OpenSearchTestCase() {
}
val poller = MonitorJobPoller(
testXContentRegistry(), mockClient(), true,
errorProvider, "us-west-2", "test-queue"
errorProvider, "us-west-2", "test-queue",
mappingProvider()
).also { it.sqsClient = FakeSqsClient() }
poller.start()
assertTrue("Worker should have polled twice", latch.await(5, TimeUnit.SECONDS))
Expand All @@ -189,7 +213,8 @@ class MonitorJobPollerTests : OpenSearchTestCase() {
}
val poller = MonitorJobPoller(
testXContentRegistry(), mockClient(), true,
emptyProvider, "us-west-2", "test-queue"
emptyProvider, "us-west-2", "test-queue",
mappingProvider()
).also { it.sqsClient = FakeSqsClient() }
poller.start()
assertTrue("Worker should have polled multiple times", latch.await(5, TimeUnit.SECONDS))
Expand Down Expand Up @@ -335,4 +360,49 @@ class MonitorJobPollerTests : OpenSearchTestCase() {
}
poller.close()
}

fun `test thread context populated correctly based on target type`() {
val mockClient = mockClient()
val mockThreadPool = mock(org.opensearch.threadpool.ThreadPool::class.java)
val mockThreadContext = org.opensearch.common.util.concurrent.ThreadContext(Settings.EMPTY)

`when`(mockClient.threadPool()).thenReturn(mockThreadPool)
`when`(mockThreadPool.threadContext).thenReturn(mockThreadContext)

val poller = MonitorJobPoller(
testXContentRegistry(), mockClient, true,
testAccountIdProvider(), "us-east-1", "test-queue",
mappingProvider()
)

val mockTargetType = mappingProvider().entries.first().key
val target = Target(type = mockTargetType, endpoint = "https://test.aoss.amazonaws.com")

poller.populateThreadContext(target)

assertEquals("true", mockThreadContext.getHeader(MonitorJobPoller.IS_BACKGROUND_JOB_HEADER))
assertEquals(mappingProvider()[mockTargetType], mockThreadContext.getHeader(MonitorJobPoller.SERVICE_NAME_HEADER))
assertEquals("https://test.aoss.amazonaws.com", mockThreadContext.getHeader(MonitorJobPoller.OPENSEARCH_ENDPOINT_HEADER))
assertEquals("us-east-1", mockThreadContext.getHeader(MonitorJobPoller.REGION_HEADER))

poller.close()
}

fun `test thread context population rejects invalid target type`() {
val mockClient = mockClient()

val poller = MonitorJobPoller(
testXContentRegistry(), mockClient, true,
testAccountIdProvider(), "us-east-1", "test-queue",
mappingProvider()
)

val target = Target(type = "non_existent_type", endpoint = "https://test.aoss.amazonaws.com")

expectThrows(Exception::class.java) {
poller.populateThreadContext(target)
}

poller.close()
}
}
Loading