diff --git a/ec2instanceconnectcli/EC2InstanceConnectCLI.py b/ec2instanceconnectcli/EC2InstanceConnectCLI.py index 7d7d9f3..b669326 100644 --- a/ec2instanceconnectcli/EC2InstanceConnectCLI.py +++ b/ec2instanceconnectcli/EC2InstanceConnectCLI.py @@ -81,18 +81,18 @@ def handle_keys(self): key_publisher.push_public_key(session, bundle['instance_id'], bundle['username'], self.pub_key, bundle['zone']) self.logger.debug('Successfully pushed the public key to {0}'.format(bundle['instance_id'])) - def run_command(self, command=None): + def run_command(self, args=None): """ - Runs the given command in a sub-shell - :param command: Command to invoke - :type command: basestring + Runs the given command + :param args: Arguments to invoke + :type args: list of strings :return: Return code for remote command :rtype: int """ - if not command: + if not args: raise ValueError('Must provide a command') - invocation_proc = Popen(command, shell=True) + invocation_proc = Popen(args) while invocation_proc.poll() is None: #sub-process not terminated time.sleep(0.1) return invocation_proc.returncode diff --git a/ec2instanceconnectcli/EC2InstanceConnectCommand.py b/ec2instanceconnectcli/EC2InstanceConnectCommand.py index 32de406..b0a12c2 100644 --- a/ec2instanceconnectcli/EC2InstanceConnectCommand.py +++ b/ec2instanceconnectcli/EC2InstanceConnectCommand.py @@ -43,21 +43,20 @@ def get_command(self): Generates and returns the generated command """ # Start with protocol & identity file - command = '{0} -o "IdentitiesOnly=yes" -i {1}'.format(self.program, self.key_file) + command = [self.program, '-o', 'IdentitiesOnly=yes', '-i', self.key_file] # Next add command flags if present - if len(self.flags) > 0: - command = "{0} {1}".format(command, self.flags) + command.extend(self.flags) # Target - command = "{0} {1}".format(command, self._get_target(self.instance_bundles[0])) + command.append(self._get_target(self.instance_bundles[0])) #program specific command if len(self.program_command) > 0: - command = "{0} {1}".format(command, self.program_command) + command.append(self.program_command) if len(self.instance_bundles) > 1: - command = "{0} {1}".format(command, self._get_target(self.instance_bundles[1])) + command.append(self._get_target(self.instance_bundles[1])) self.logger.debug('Generated command: {0}'.format(command)) diff --git a/ec2instanceconnectcli/input_parser.py b/ec2instanceconnectcli/input_parser.py index 705d214..e53df0c 100644 --- a/ec2instanceconnectcli/input_parser.py +++ b/ec2instanceconnectcli/input_parser.py @@ -117,7 +117,7 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False): :return: tuple of flags and final comamnd or file list :rtype: tuple """ - flags = '' + flags = [] is_user = False is_flagged = False command_index = 0 @@ -133,7 +133,7 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False): used += 1 # This is either a flag or a flag value - flags = '{0} {1}'.format(flags, raw_command[command_index]) + flags.append(raw_command[command_index]) if raw_command[command_index][0] == '-': # Flag @@ -152,8 +152,6 @@ def _parse_command_flags(raw_command, instance_bundles, is_ssh=False): command_index += 1 - flags = flags.strip() - """ Target host and command or file list """ diff --git a/tests/test_EC2ConnectCLI.py b/tests/test_EC2ConnectCLI.py index 584a4e9..f59117e 100644 --- a/tests/test_EC2ConnectCLI.py +++ b/tests/test_EC2ConnectCLI.py @@ -31,7 +31,7 @@ def test_mssh_no_target(self, mock_push_key, mock_run): mock_file = 'identity' - flag = '-f flag' + flags = ['-f', 'flag'] command = 'command arg' logger = EC2InstanceConnectLogger() instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, @@ -41,12 +41,12 @@ def test_mssh_no_target(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flags, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() - - expected_command = 'ssh -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3} {4}'.format(mock_file, flag, self.default_user, - self.public_ip, command) + + expected_command = ['ssh', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags, + '{}@{}'.format(self.default_user, self.public_ip), command] # Check that we successfully get to the run self.assertTrue(mock_instance_data.called) @@ -62,7 +62,7 @@ def test_mssh_no_target_no_public_ip(self, mock_push_key, mock_run): mock_file = "identity" - flag = '-f flag' + flags = ['-f', 'flag'] command = 'command arg' logger = EC2InstanceConnectLogger() instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, @@ -72,12 +72,12 @@ def test_mssh_no_target_no_public_ip(self, mock_instance_data.return_value = self.private_instance_info mock_push_key.return_value = None - cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flags, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() - expected_command = 'ssh -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3} {4}'.format(mock_file, flag, self.default_user, - self.private_ip, command) + expected_command = ['ssh', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags, + '{}@{}'.format(self.default_user, self.private_ip), command] # Check that we successfully get to the run self.assertTrue(mock_instance_data.called) @@ -92,7 +92,7 @@ def test_mssh_with_target(self, mock_push_key, mock_run): mock_file = 'identity' - flag = '-f flag' + flags = ['-f', 'flag'] command = 'command arg' host = '0.0.0.0' logger = EC2InstanceConnectLogger() @@ -103,12 +103,12 @@ def test_mssh_with_target(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("ssh", instance_bundles, mock_file, flags, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() - expected_command = 'ssh -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3} {4}'.format(mock_file, flag, self.default_user, - host, command) + expected_command = ['ssh', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags, + '{}@{}'.format(self.default_user, host), command] # Check that we successfully get to the run # Since both target and availability_zone are provided, mock_instance_data should not be called self.assertFalse(mock_instance_data.called) @@ -123,7 +123,7 @@ def test_msftp(self, mock_push_key, mock_run): mock_file = 'identity' - flag = '-f flag' + flags = ['-f', 'flag'] command = 'file2 file3' logger = EC2InstanceConnectLogger() instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, @@ -133,10 +133,11 @@ def test_msftp(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - expected_command = 'sftp -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3}:{4} {5}'.format(mock_file, flag, self.default_user, - self.public_ip, 'file1', command) + expected_command = ['sftp', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags, + '{}@{}:{}'.format(self.default_user, self.public_ip, 'file1'), + command] - cli_command = EC2InstanceConnectCommand("sftp", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("sftp", instance_bundles, mock_file, flags, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() @@ -153,7 +154,7 @@ def test_mscp(self, mock_push_key, mock_run): mock_file = 'identity' - flag = '-f flag' + flags = ['-f', 'flag'] command = 'file2 file3' logger = EC2InstanceConnectLogger() instance_bundles = [{'username': self.default_user, 'instance_id': self.instance_id, @@ -166,12 +167,12 @@ def test_mscp(self, mock_instance_data.return_value = self.instance_info mock_push_key.return_value = None - expected_command = 'scp -o "IdentitiesOnly=yes" -i {0} {1} {2}@{3}:{4} {5} {6}@{7}:{8}'.format(mock_file, flag, self.default_user, - self.public_ip, 'file1', command, - self.default_user, - self.public_ip, 'file4') + expected_command = ['scp', '-o', 'IdentitiesOnly=yes', '-i', mock_file, *flags, + '{}@{}:{}'.format(self.default_user, self.public_ip, 'file1'), + command, + '{}@{}:{}'.format(self.default_user, self.public_ip, 'file4')] - cli_command = EC2InstanceConnectCommand("scp", instance_bundles, mock_file, flag, command, logger.get_logger()) + cli_command = EC2InstanceConnectCommand("scp", instance_bundles, mock_file, flags, command, logger.get_logger()) cli = EC2InstanceConnectCLI(instance_bundles, "", cli_command, logger.get_logger()) cli.invoke_command() @@ -183,5 +184,5 @@ def test_mscp(self, def test_status_code(self): #TODO: Refine test for checking run_command status code cli = EC2InstanceConnectCLI(None, None, None, None) - code = cli.run_command("echo ok; exit -1;") + code = cli.run_command(["sh", "-c", "echo ok; exit -1;"]) self.assertEqual(code, 255) diff --git a/tests/test_input_parser.py b/tests/test_input_parser.py index e04fde9..b7ca3ff 100644 --- a/tests/test_input_parser.py +++ b/tests/test_input_parser.py @@ -41,7 +41,7 @@ def test_basic_target(self): self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) - self.assertEqual(flags, '') + self.assertEqual(flags, []) self.assertEqual(command, '') def test_username(self): @@ -51,7 +51,7 @@ def test_username(self): self.assertEqual(bundles, [{'username': 'myuser', 'instance_id': self.instance_id, 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) - self.assertEqual(flags, '') + self.assertEqual(flags, []) self.assertEqual(command, '') def test_dns_name(self): @@ -63,7 +63,7 @@ def test_dns_name(self): self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, 'target': self.dns_name, 'zone': self.availability_zone, 'region': self.region, 'profile': self.profile}]) - self.assertEqual(flags, '') + self.assertEqual(flags, []) self.assertEqual(command, '') def test_flags(self): @@ -73,7 +73,7 @@ def test_flags(self): self.assertEqual(bundles, [{'username': 'login', 'instance_id': self.instance_id, 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) - self.assertEqual(flags, '-1 -l login') + self.assertEqual(flags, ['-1', '-l', 'login']) self.assertEqual(command, '') def test_command(self): @@ -83,7 +83,7 @@ def test_command(self): self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, 'target': None, 'zone': None, 'region': None, 'profile': self.profile}]) - self.assertEqual(flags, '') + self.assertEqual(flags, []) self.assertEqual(command, 'uname -a') def test_sftp(self): @@ -95,7 +95,7 @@ def test_sftp(self): self.assertEqual(bundles, [{'username': self.default_user, 'instance_id': self.instance_id, 'target': None, 'zone': None, 'region': None, 'profile': self.profile, 'file': 'first_file'}]) - self.assertEqual(flags, '') + self.assertEqual(flags, []) self.assertEqual(command, 'second_file') def test_invalid_username(self):