diff --git a/pom.xml b/pom.xml index 97d7975..5c46b9d 100644 --- a/pom.xml +++ b/pom.xml @@ -26,12 +26,12 @@ com.amazonaws amazon-kinesis-producer - 0.14.0 + 0.14.12 com.amazonaws aws-java-sdk - 1.11.327 + 1.12.198 javax.xml.bind diff --git a/src/main/java/com/warnermedia/kplserver/App.java b/src/main/java/com/warnermedia/kplserver/App.java index 44f6ca2..d552c1c 100644 --- a/src/main/java/com/warnermedia/kplserver/App.java +++ b/src/main/java/com/warnermedia/kplserver/App.java @@ -25,7 +25,7 @@ public static void main(String[] args) throws Exception { ServerSocket errSocket = new ServerSocket(port); errSocket.setSoTimeout(100); - KinesisEventPublisher kinesisEventPublisher = new KinesisEventPublisher(stream, getRegion(), getMetricsLevel(), errSocket); + KinesisEventPublisher kinesisEventPublisher = new KinesisEventPublisher(stream, getRegion(), getMetricsLevel(), getCrossAccountRole(), errSocket); // graceful shutdowns Runtime.getRuntime().addShutdownHook(new Thread() { @@ -88,4 +88,11 @@ static String getMetricsLevel() { return p; } + static String getCrossAccountRole() { + String p = System.getenv("CROSS_ACCOUNT_ROLE"); + if (p == null || p.equals("")) { + return ""; + } + return p; + } } diff --git a/src/main/java/com/warnermedia/kplserver/KinesisEventPublisher.java b/src/main/java/com/warnermedia/kplserver/KinesisEventPublisher.java index 077b008..be8ec5b 100644 --- a/src/main/java/com/warnermedia/kplserver/KinesisEventPublisher.java +++ b/src/main/java/com/warnermedia/kplserver/KinesisEventPublisher.java @@ -1,10 +1,20 @@ package com.warnermedia.kplserver; +import com.amazonaws.auth.AWSCredentialsProvider; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicSessionCredentials; +import com.amazonaws.auth.DefaultAWSCredentialsProviderChain; +import com.amazonaws.auth.profile.ProfileCredentialsProvider; import com.amazonaws.services.kinesis.producer.KinesisProducer; import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration; import com.amazonaws.services.kinesis.producer.UserRecord; import com.amazonaws.services.kinesis.producer.UserRecordFailedException; import com.amazonaws.services.kinesis.producer.UserRecordResult; +import com.amazonaws.services.securitytoken.AWSSecurityTokenService; +import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceAsyncClientBuilder; +import com.amazonaws.services.securitytoken.model.AssumeRoleRequest; +import com.amazonaws.services.securitytoken.model.AssumeRoleResult; +import com.amazonaws.services.securitytoken.model.Credentials; import com.google.common.util.concurrent.FutureCallback; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; @@ -36,14 +46,47 @@ public class KinesisEventPublisher { ServerSocket errSocket; Socket errClient; - public KinesisEventPublisher(String stream, String region, String metricsLevel, ServerSocket errSocket) { + public KinesisEventPublisher(String stream, String region, String metricsLevel, String crossAccountRole, ServerSocket errSocket) { this.stream = stream; kinesis = new KinesisProducer(new KinesisProducerConfiguration() .setRegion(region) - .setMetricsLevel(metricsLevel)); + .setMetricsLevel(metricsLevel) + .setCredentialsProvider(loadCredentials(crossAccountRole, region))); this.errSocket = errSocket; } + private static AWSCredentialsProvider loadCredentials(String crossAccountRole, String region) { + final AWSCredentialsProvider credentialsProvider; + + Boolean isCrossAccount = false; + if (!crossAccountRole.equals("")) { + isCrossAccount = true; + } + + if (isCrossAccount) { + AWSSecurityTokenService stsClient = AWSSecurityTokenServiceAsyncClientBuilder.standard() + .withRegion(region) + .build(); + + AssumeRoleRequest assumeRoleRequest = new AssumeRoleRequest().withDurationSeconds(3600) + .withRoleArn(crossAccountRole) + .withRoleSessionName("Kinesis_Session"); + + AssumeRoleResult assumeRoleResult = stsClient.assumeRole(assumeRoleRequest); + Credentials creds = assumeRoleResult.getCredentials(); + + credentialsProvider = new AWSStaticCredentialsProvider( + new BasicSessionCredentials(creds.getAccessKeyId(), + creds.getSecretAccessKey(), + creds.getSessionToken()) + ); + } else { + credentialsProvider = new DefaultAWSCredentialsProviderChain(); + } + + return credentialsProvider; + } + public void runOnce(String line) throws Exception { // add new line so that downstream systems have an easier time parsing String finalLine = line + "\n";