Skip to content

Commit 186061b

Browse files
committed
feat: support passing extra parameters and using different documents
1 parent a3a7108 commit 186061b

File tree

4 files changed

+146
-7
lines changed

4 files changed

+146
-7
lines changed

README.md

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,51 @@ task :forward_port, %i[instance_id remote_port local_port] => :environment do |_
199199
end
200200
```
201201

202+
You can also use specific documents, and pass in extra parameters, which can be
203+
useful for using tunnels to access other private resources like database
204+
instances:
205+
206+
```ruby
207+
require "aws_ec2_environment"
208+
209+
desc "Dumps a copy of the postgres database using AWS and PG environment variables"
210+
task :dump_pg_database, %i[instance_id dump_file] => :environment do |_, args|
211+
# trap ctl+c to make things a bit nicer (otherwise we'll get an ugly stacktrace)
212+
# since we expect this to be used to terminate the command
213+
trap("SIGINT") { exit }
214+
215+
logger = Logger.new($stdout)
216+
217+
instance_id = args.fetch(:instance_id)
218+
dump_file = args.fetch(:dump_file)
219+
220+
remote_host = ENV.fetch("PGHOST")
221+
remote_port = ENV.fetch("PGPORT", 5432)
222+
223+
session = AwsEc2Environment::SsmPortForwardingSession.new(
224+
instance_id,
225+
remote_port,
226+
document: "AWS-StartPortForwardingSessionToRemoteHost",
227+
logger:,
228+
extra_params: { "host" => [remote_host] }
229+
)
230+
231+
at_exit { session.close }
232+
233+
local_port = session.wait_for_local_port
234+
235+
system([
236+
"pg_dump",
237+
"--format=c",
238+
"--no-owner",
239+
"--no-privileges",
240+
"--host=localhost",
241+
"--port=#{local_port}",
242+
"--file=#{dump_file}",
243+
].join(" "))
244+
end
245+
```
246+
202247
### AWS Authentication and Permissions
203248

204249
Since this gem interacts with AWS, it must be configured with credentials - see

lib/aws_ec2_environment/ssm_port_forwarding_session.rb

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ class SessionProcessError < Error; end
2222
# rubocop:disable Metrics/ParameterLists
2323
def initialize(
2424
instance_id, remote_port,
25+
document: "AWS-StartPortForwardingSession",
2526
local_port: nil, logger: Logger.new($stdout),
26-
timeout: 15, reason: nil
27+
timeout: 15, reason: nil, extra_params: {}
2728
)
2829
# rubocop:enable Metrics/ParameterLists
2930
@logger = logger
@@ -32,7 +33,7 @@ def initialize(
3233
@local_port = nil
3334
@timeout = timeout
3435

35-
@reader, @writer, @pid = PTY.spawn(ssm_port_forward_cmd(local_port, reason))
36+
@reader, @writer, @pid = PTY.spawn(ssm_port_forward_cmd(local_port, reason, document, extra_params))
3637

3738
@cmd_output = ""
3839
@session_id = wait_for_session_id
@@ -64,9 +65,8 @@ def wait_for_local_port
6465

6566
private
6667

67-
def ssm_port_forward_cmd(local_port, reason)
68-
document_name = "AWS-StartPortForwardingSession"
69-
parameters = { "portNumber" => [remote_port.to_s] }
68+
def ssm_port_forward_cmd(local_port, reason, document_name, extra_parameters)
69+
parameters = extra_parameters.merge({ "portNumber" => [remote_port.to_s] })
7070
parameters["localPortNumber"] = [local_port.to_s] unless local_port.nil?
7171
flags = [
7272
["--target", instance_id],

sig/aws_ec2_environment/ssm_port_forwarding_session.rbs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ class AwsEc2Environment
1616
def initialize: (
1717
String instance_id,
1818
Integer remote_port,
19+
?document: String,
1920
?local_port: Integer | nil,
2021
?logger: Logger,
2122
?timeout: Numeric,
22-
?reason: String | nil
23+
?reason: String | nil,
24+
?extra_params: Hash[String, untyped]
2325
) -> void
2426

2527
def close: () -> void
@@ -38,7 +40,7 @@ class AwsEc2Environment
3840
@writer: IO
3941
@cmd_output: String
4042

41-
def ssm_port_forward_cmd: (Integer | nil local_port, String | nil reason) -> String
43+
def ssm_port_forward_cmd: (Integer | nil local_port, String | nil reason, String document_name, Hash[String, untyped] extra_params) -> String
4244

4345
# Checks the cmd process output until either the given +pattern+ matches or the +timeout+ is over.
4446
#

spec/aws_ec2_environment/ssm_port_forwarding_session_spec.rb

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,98 @@ def close; end
175175
)
176176
end
177177
end
178+
179+
context "when a specific document is provided" do
180+
subject(:session) do
181+
described_class.new(
182+
"i-0d9c4bg3f26157a8e",
183+
22,
184+
document: "AWS-StartPortForwardingSessionToRemoteHost",
185+
logger: Logger.new(StringIO.new(log)),
186+
# we can use a really low timeout to make the tests a lot faster,
187+
# since we're not actually going to be writing asynchronously
188+
timeout: 0.00001
189+
)
190+
end
191+
192+
it "uses the document" do
193+
expect { session }.not_to raise_error
194+
195+
parameters = { "portNumber" => ["22"] }
196+
parameters_escaped = Shellwords.escape(parameters.to_json)
197+
198+
expect(PTY).to have_received(:spawn).with(
199+
%w[
200+
aws ssm start-session
201+
--target i-0d9c4bg3f26157a8e
202+
--document-name AWS-StartPortForwardingSessionToRemoteHost
203+
--parameters
204+
].join(" ") + " #{parameters_escaped}"
205+
)
206+
end
207+
end
208+
209+
context "when extra parameters are provided" do
210+
it "merges them" do
211+
expect do
212+
described_class.new(
213+
"i-0d9c4bg3f26157a8e",
214+
22,
215+
extra_params: { "host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"] },
216+
logger: Logger.new(StringIO.new(log)),
217+
# we can use a really low timeout to make the tests a lot faster,
218+
# since we're not actually going to be writing asynchronously
219+
timeout: 0.00001
220+
)
221+
end.not_to raise_error
222+
223+
parameters = { "host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"], "portNumber" => ["22"] }
224+
parameters_escaped = Shellwords.escape(parameters.to_json)
225+
226+
expect(PTY).to have_received(:spawn).with(
227+
%w[
228+
aws ssm start-session
229+
--target i-0d9c4bg3f26157a8e
230+
--document-name AWS-StartPortForwardingSession
231+
--parameters
232+
].join(" ") + " #{parameters_escaped}"
233+
)
234+
end
235+
236+
it "overrides them with specific parameters" do
237+
expect do
238+
described_class.new(
239+
"i-0d9c4bg3f26157a8e",
240+
22,
241+
extra_params: {
242+
"host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"],
243+
"localPortNumber" => [1234]
244+
},
245+
local_port: 5432,
246+
logger: Logger.new(StringIO.new(log)),
247+
# we can use a really low timeout to make the tests a lot faster,
248+
# since we're not actually going to be writing asynchronously
249+
timeout: 0.00001
250+
)
251+
end.not_to raise_error
252+
253+
parameters = {
254+
"host" => ["my-database.abc.ap-southeast-2.rds.amazonaws.com"],
255+
"localPortNumber" => ["5432"],
256+
"portNumber" => ["22"]
257+
}
258+
parameters_escaped = Shellwords.escape(parameters.to_json)
259+
260+
expect(PTY).to have_received(:spawn).with(
261+
%w[
262+
aws ssm start-session
263+
--target i-0d9c4bg3f26157a8e
264+
--document-name AWS-StartPortForwardingSession
265+
--parameters
266+
].join(" ") + " #{parameters_escaped}"
267+
)
268+
end
269+
end
178270
end
179271

180272
describe "#instance_id" do

0 commit comments

Comments
 (0)