diff --git a/CHANGELOG.md b/CHANGELOG.md index baa8322..03bae10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,11 @@ ## [Unreleased] +## Added + +- Support using alternative documents and parameters with + `SsmPortForwardingSession` + ([#22](https://github.com/ackama/aws_ec2_environment/pull/22)) + ## [0.1.0] - 2022-08-17 - Initial release diff --git a/README.md b/README.md index 097d327..2391d17 100644 --- a/README.md +++ b/README.md @@ -199,6 +199,44 @@ task :forward_port, %i[instance_id remote_port local_port] => :environment do |_ end ``` +You can also use specific documents, and pass in extra parameters, which can be +useful for using tunnels to access other private resources like database +instances: + +```ruby +require "aws_ec2_environment" + +desc "Dumps a copy of the postgres database using AWS and PG environment variables" +task :dump_pg_database, %i[instance_id dump_file] => :environment do |_, args| + instance_id = args.fetch(:instance_id) + dump_file = args.fetch(:dump_file) + + remote_host = ENV.fetch("PGHOST") + remote_port = ENV.fetch("PGPORT", 5432) + + session = AwsEc2Environment::SsmPortForwardingSession.new( + instance_id, + remote_port, + document: "AWS-StartPortForwardingSessionToRemoteHost", + extra_params: { "host" => [remote_host] } + ) + + at_exit { session.close } + + local_port = session.wait_for_local_port + + system([ + "pg_dump", + "--format=c", + "--no-owner", + "--no-privileges", + "--host=localhost", + "--port=#{local_port}", + "--file=#{dump_file}", + ].join(" ")) +end +``` + ### AWS Authentication and Permissions Since this gem interacts with AWS, it must be configured with credentials - see diff --git a/lib/aws_ec2_environment/ssm_port_forwarding_session.rb b/lib/aws_ec2_environment/ssm_port_forwarding_session.rb index 65cb1ab..02898f6 100644 --- a/lib/aws_ec2_environment/ssm_port_forwarding_session.rb +++ b/lib/aws_ec2_environment/ssm_port_forwarding_session.rb @@ -22,8 +22,9 @@ class SessionProcessError < Error; end # rubocop:disable Metrics/ParameterLists def initialize( instance_id, remote_port, + document: "AWS-StartPortForwardingSession", local_port: nil, logger: Logger.new($stdout), - timeout: 15, reason: nil + timeout: 15, reason: nil, extra_params: {} ) # rubocop:enable Metrics/ParameterLists @logger = logger @@ -32,7 +33,7 @@ def initialize( @local_port = nil @timeout = timeout - @reader, @writer, @pid = PTY.spawn(ssm_port_forward_cmd(local_port, reason)) + @reader, @writer, @pid = PTY.spawn(ssm_port_forward_cmd(local_port, reason, document, extra_params)) @cmd_output = "" @session_id = wait_for_session_id @@ -64,9 +65,8 @@ def wait_for_local_port private - def ssm_port_forward_cmd(local_port, reason) - document_name = "AWS-StartPortForwardingSession" - parameters = { "portNumber" => [remote_port.to_s] } + def ssm_port_forward_cmd(local_port, reason, document_name, extra_parameters) + parameters = extra_parameters.merge({ "portNumber" => [remote_port.to_s] }) parameters["localPortNumber"] = [local_port.to_s] unless local_port.nil? flags = [ ["--target", instance_id], diff --git a/sig/aws_ec2_environment/ssm_port_forwarding_session.rbs b/sig/aws_ec2_environment/ssm_port_forwarding_session.rbs index 127843d..8dd1310 100644 --- a/sig/aws_ec2_environment/ssm_port_forwarding_session.rbs +++ b/sig/aws_ec2_environment/ssm_port_forwarding_session.rbs @@ -16,10 +16,12 @@ class AwsEc2Environment def initialize: ( String instance_id, Integer remote_port, + ?document: String, ?local_port: Integer | nil, ?logger: Logger, ?timeout: Numeric, - ?reason: String | nil + ?reason: String | nil, + ?extra_params: Hash[String, untyped] ) -> void def close: () -> void @@ -38,7 +40,7 @@ class AwsEc2Environment @writer: IO @cmd_output: String - def ssm_port_forward_cmd: (Integer | nil local_port, String | nil reason) -> String + def ssm_port_forward_cmd: (Integer | nil local_port, String | nil reason, String document_name, Hash[String, untyped] extra_params) -> String # Checks the cmd process output until either the given +pattern+ matches or the +timeout+ is over. # diff --git a/spec/aws_ec2_environment/ssm_port_forwarding_session_spec.rb b/spec/aws_ec2_environment/ssm_port_forwarding_session_spec.rb index 20d937e..86260ba 100644 --- a/spec/aws_ec2_environment/ssm_port_forwarding_session_spec.rb +++ b/spec/aws_ec2_environment/ssm_port_forwarding_session_spec.rb @@ -175,6 +175,98 @@ def close; end ) end end + + context "when a specific document is provided" do + subject(:session) do + described_class.new( + "i-0d9c4bg3f26157a8e", + 22, + document: "AWS-StartPortForwardingSessionToRemoteHost", + logger: Logger.new(StringIO.new(log)), + # we can use a really low timeout to make the tests a lot faster, + # since we're not actually going to be writing asynchronously + timeout: 0.00001 + ) + end + + it "uses the document" do + expect { session }.not_to raise_error + + parameters = { "portNumber" => ["22"] } + parameters_escaped = Shellwords.escape(parameters.to_json) + + expect(PTY).to have_received(:spawn).with( + %w[ + aws ssm start-session + --target i-0d9c4bg3f26157a8e + --document-name AWS-StartPortForwardingSessionToRemoteHost + --parameters + ].join(" ") + " #{parameters_escaped}" + ) + end + end + + context "when extra parameters are provided" do + it "merges them" do + expect do + described_class.new( + "i-0d9c4bg3f26157a8e", + 22, + extra_params: { "host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"] }, + logger: Logger.new(StringIO.new(log)), + # we can use a really low timeout to make the tests a lot faster, + # since we're not actually going to be writing asynchronously + timeout: 0.00001 + ) + end.not_to raise_error + + parameters = { "host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"], "portNumber" => ["22"] } + parameters_escaped = Shellwords.escape(parameters.to_json) + + expect(PTY).to have_received(:spawn).with( + %w[ + aws ssm start-session + --target i-0d9c4bg3f26157a8e + --document-name AWS-StartPortForwardingSession + --parameters + ].join(" ") + " #{parameters_escaped}" + ) + end + + it "overrides them with specific parameters" do + expect do + described_class.new( + "i-0d9c4bg3f26157a8e", + 22, + extra_params: { + "host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"], + "localPortNumber" => [1234] + }, + local_port: 5432, + logger: Logger.new(StringIO.new(log)), + # we can use a really low timeout to make the tests a lot faster, + # since we're not actually going to be writing asynchronously + timeout: 0.00001 + ) + end.not_to raise_error + + parameters = { + "host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"], + "localPortNumber" => ["5432"], + "portNumber" => ["22"] + } + parameters_escaped = Shellwords.escape(parameters.to_json) + + expect(PTY).to have_received(:spawn).with( + %w[ + aws ssm start-session + --target i-0d9c4bg3f26157a8e + --document-name AWS-StartPortForwardingSession + --parameters + ].join(" ") + " #{parameters_escaped}" + ) + end + end end describe "#instance_id" do