diff --git a/.pylintrc b/.pylintrc index 0f891ef..c79bd1b 100644 --- a/.pylintrc +++ b/.pylintrc @@ -60,28 +60,34 @@ confidence= # --enable=similarities". If you want to run only the classes checker, but have # no Warning level messages displayed, use "--disable=all --enable=classes # --disable=W". -disable=bad-continuation, - bad-whitespace, - bare-except, - broad-except, - consider-using-in, - consider-using-ternary, - fixme, - global-statement, - invalid-name, - missing-docstring, - no-else-raise, - no-else-return, - no-self-use, - trailing-newlines, - too-many-instance-attributes, - unused-argument, - unused-variable, - using-constant-test, - no-else-continue, - no-else-break, - chained-comparison, - useless-object-inheritance +disable= + bare-except, + broad-except, + broad-exception-raised, + consider-using-in, + consider-using-ternary, + consider-using-dict-items, + fixme, + global-statement, + invalid-name, + missing-docstring, + no-else-raise, + no-else-return, + trailing-newlines, + too-many-instance-attributes, + too-many-branches, + too-many-statements, + unused-argument, + unused-variable, + using-constant-test, + no-else-continue, + no-else-break, + condition-evals-to-constant, + chained-comparison, + consider-using-f-string, + use-dict-literal, + duplicate-code, + useless-object-inheritance # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option @@ -401,13 +407,6 @@ max-line-length=160 # Maximum number of lines in a module. max-module-lines=10000 -# List of optional constructs for which whitespace checking is disabled. `dict- -# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. -# `trailing-comma` allows a space between comma and closing bracket: (a, ). -# `empty-line` allows space-only lines. -no-space-check=trailing-comma, - dict-separator - # Allow the body of a class to be on the same line as the declaration if body # contains single statement. single-line-class-stmt=no @@ -509,5 +508,6 @@ known-third-party=enchant # Exceptions that will emit a warning when being caught. Defaults to # "BaseException, Exception". -overgeneral-exceptions=BaseException, - Exception +overgeneral-exceptions=builtins.BaseException, + builtins.Exception + diff --git a/pricing/.gitignore b/pricing/.gitignore index 4b85167..4a7ba05 100644 --- a/pricing/.gitignore +++ b/pricing/.gitignore @@ -1,2 +1,3 @@ *.json -cache +cache* +diffs diff --git a/tox.ini b/tox.ini index 1a1e025..572c251 100644 --- a/tox.ini +++ b/tox.ini @@ -5,23 +5,26 @@ envlist = lint,py3 [package] name = vmtool main_deps = - -r./requirements.txt + -r./requirements.txt lint_deps = - pylint==2.5.3 - pyenchant==2.0.0 + pylint==2.16.1 + pyenchant==3.2.2 + pyflakes==3.0.1 + pytype==2023.2.9 test_deps = - pytest==6.0.1 - pytest-cov==2.10.0 + pytest==7.2.1 + coverage==7.1.0 test_dir = "{toxinidir}/tests" - [testenv] deps = {[package]test_deps} {[package]main_deps} commands = {envpython} --version - pytest --cov {posargs} + coverage run -m pytest {posargs} + coverage html -d cover/{envname} + coverage report [testenv:lint] basepython = python3 @@ -30,4 +33,5 @@ deps = {[package]main_deps} commands = pylint {[package]name} + pytype {[package]name} diff --git a/vmtool/aws.py b/vmtool/aws.py index 68a0518..8bd0bee 100644 --- a/vmtool/aws.py +++ b/vmtool/aws.py @@ -1,8 +1,6 @@ """AWS backend for vmtool. """ -import argparse -import binascii import datetime import errno import gzip @@ -17,31 +15,24 @@ import secrets import shlex import socket -import stat import subprocess import sys import tarfile import time -import uuid -from fnmatch import fnmatch - -import boto3.session import boto3.s3.transfer -import botocore.session +import boto3.session import botocore.config +import botocore.session +from vmtool.base import VmCmd, VmToolBase from vmtool.certs import load_cert_config -from vmtool.config import Config, NoOptionError -from vmtool.envconfig import load_env_config, find_gittop -from vmtool.scripting import EnvScript, UsageError -from vmtool.tarfilter import TarFilter -from vmtool.terra import tf_load_output_var, tf_load_all_vars -from vmtool.util import fmt_dur -from vmtool.util import printf, eprintf, time_printf, print_json, local_cmd, run_successfully -from vmtool.util import ssh_add_known_host, parse_console, rsh_quote, as_unicode -from vmtool.xglob import xglob - +from vmtool.scripting import UsageError +from vmtool.terra import tf_load_all_vars +from vmtool.util import ( + as_unicode, eprintf, fmt_dur, parse_console, + print_json, printf, ssh_add_known_host, time_printf, +) # /usr/share/doc/cloud-init/userdata.txt USERDATA = """\ @@ -79,57 +70,25 @@ """ -SSH_USER_CREATION = '''\ -if ! grep -q '^{user}:' /etc/passwd; then - echo "Adding user {user}" - adduser -q --gecos {user} --disabled-password {user} < /dev/null - install -d -o {user} -g {user} -m 700 ~{user}/.ssh - echo "{pubkey}" > ~{user}/.ssh/authorized_keys - chmod 600 ~{user}/.ssh/authorized_keys - chown {user}:{user} ~{user}/.ssh/authorized_keys - for grp in {auth_groups}; do - adduser -q {user} $grp - done -fi -''' - -# replace those with root specified by image -ROOT_DEV_NAMES = ('root', 'xvda', '/dev/sda1') - - def show_commits(old_id, new_id, dirs, cwd): - cmd = ['git', '--no-pager', 'shortlog', '--no-merges', old_id + '..' + new_id] + cmd = ["git", "--no-pager", "shortlog", "--no-merges", old_id + ".." + new_id] if dirs: - cmd.append('--') + cmd.append("--") cmd.extend(dirs) subprocess.call(cmd, cwd=cwd) -def mk_sshuser_script(user, auth_groups, pubkey): - return SSH_USER_CREATION.format(user=user, auth_groups=' '.join(auth_groups), pubkey=pubkey) - - -class VmCmd: - """Sub-command names used internally in vmtool. - """ - PREP: str = 'prep' - FAILOVER_PROMOTE_SECONDARY: str = 'failover_promote_secondary' - - TAKEOVER_PREPARE_PRIMARY: str = 'takeover_prepare_primary' - TAKEOVER_PREPARE_SECONDARY: str = 'takeover_prepare_secondary' - TAKEOVER_FINISH_PRIMARY: str = 'takeover_finish_primary' - TAKEOVER_FINISH_SECONDARY: str = 'takeover_finish_secondary' - - DROP_NODE_PREPARE: str = 'drop_node_prepare' - class VmState: - PRIMARY: str = 'primary' - SECONDARY: str = 'secondary' + PRIMARY: str = "primary" + SECONDARY: str = "secondary" -class VmTool(EnvScript): +class VmTool(VmToolBase): __doc__ = __doc__ + # replace those with root specified by image + ROOT_DEV_NAMES = ("root", "xvda", "/dev/sda1") + _boto_sessions = None _boto_clients = None @@ -138,201 +97,107 @@ class VmTool(EnvScript): _vm_map = None - role_name = None - env_name = None # name of current env - full_role = None - ssh_dir = None - - new_commit = None - old_commit = None availability_zone = None - log = logging.getLogger('vmtool') - def startup(self): - logging.getLogger('boto3').setLevel(logging.WARNING) - logging.getLogger('botocore').setLevel(logging.WARNING) + super().startup() + logging.getLogger("boto3").setLevel(logging.WARNING) + logging.getLogger("botocore").setLevel(logging.WARNING) + + def conf_func_primary_vm(self, arg, sect, kname): + """Lookup primary vm. - def reload(self): - """Reload config. + Usage: ${PRIMARY_VM ! ${other_role}} """ - self.git_dir = find_gittop() - - # ~/.vmtool - ssh_dir = os.path.expanduser('~/.vmtool') - if not os.path.isdir(ssh_dir): - os.mkdir(ssh_dir, stat.S_IRWXU) - - keys_dir = os.environ.get('VMTOOL_KEY_DIR', os.path.join(self.git_dir, 'keys')) - if not keys_dir or not os.path.isdir(keys_dir): - raise UsageError('Set vmtool config dir: VMTOOL_KEY_DIR') - - ca_log_dir = os.environ.get('VMTOOL_CA_LOG_DIR') - if not ca_log_dir or not os.path.isdir(ca_log_dir): - raise UsageError('Set vmtool config dir: VMTOOL_CA_LOG_DIR') - - env = os.environ.get('VMTOOL_ENV_NAME', '') - if self.options.env: - env = self.options.env - if not env: - raise UsageError('No envronment set: either set VMTOOL_ENV_NAME or give --env=ENV') - - env_name = env - self.full_role = env - if '.' in env: - env_name, self.role_name = env.split('.') - self.env_name = env_name - if self.options.role: - self.role_name = self.options.role - self.full_role = '%s.%s' % (self.env_name, self.role_name) - - self.ca_log_dir = ca_log_dir - self.keys_dir = keys_dir - self.ssh_dir = ssh_dir - - self.cf = load_env_config(self.full_role, { - 'FILE': self.conf_func_file, - 'KEY': self.conf_func_key, - 'TF': self.conf_func_tf, - 'TFAZ': self.conf_func_tfaz, - 'PRIMARY_VM': self.conf_func_primary_vm, - 'NETWORK': self.conf_func_network, - 'NETMASK': self.conf_func_netmask, - 'MEMBERS': self.conf_func_members, - }) - self.process_pkgs() - - self._region = self.cf.get('region') - self.ssh_known_hosts = os.path.join(self.ssh_dir, 'known_hosts') - self.is_live = self.cf.getint('is_live', 0) - - if self.options.az is not None: - self.availability_zone = self.options.az - else: - self.availability_zone = self.cf.getint('availability_zone', 0) + vm = self.get_primary_for_role(arg) + return vm["InstanceId"] - # fill vm_ordered_disk_names - disk_map = self.get_disk_map() - if disk_map: - api_order = [] - size_order = [] - for dev in disk_map: - size = disk_map[dev]["size"] - count = disk_map[dev]["count"] - if size and dev not in ROOT_DEV_NAMES: - for i in range(count): - name = f"{dev}.{i}" if count > 1 else dev - size_order.append( (size, i, name) ) - api_order.append(name) - size_order.sort() - self.cf.set("vm_disk_names_size_order", ", ".join([elem[2] for elem in size_order])) - self.cf.set("vm_disk_names_api_order", ", ".join(api_order)) - - _gpg_cache = None - def load_gpg_file(self, fn): - if self._gpg_cache is None: - self._gpg_cache = {} - if fn in self._gpg_cache: - return self._gpg_cache[fn] - if self.options.verbose: - printf("GPG: %s", fn) - # file data directly - if not os.path.isfile(fn): - raise UsageError("GPG file not found: %s" % fn) - data = self.popen(['gpg', '-q', '-d', '--batch', fn]) - res = as_unicode(data) - self._gpg_cache[fn] = res - return res + def new_ssh_key(self, vm_id): + """Fetch output, parse keys. + """ + time_printf("Waiting for image copy, boot and SSH host key generation") + client = self.get_ec2_client() + keys = None + time.sleep(30) + retry = 0 + for i in range(100): + time.sleep(30) + # load console buffer from EC2 + for retry in range(3): + try: + cres = client.get_console_output(InstanceId=vm_id) + break + except socket.error as ex: + if ex.errno != errno.ETIMEDOUT: + raise + out = cres.get("Output") + if not out: + continue + keys = parse_console(out, ["ssh-ed25519"]) + if keys is not None: + break + if not keys: + raise UsageError("Failed to get SSH keys") + + # set ssh key as tag + ssh_tags = [] + for n, kval in enumerate(keys): + ktype = kval[0] + kcert = kval[1] + tag = {"Key": ktype, "Value": kcert} + ssh_tags.append(tag) + client.create_tags(Resources=[vm_id], Tags=ssh_tags) + + for vm in self.ec2_iter_instances(InstanceIds=[vm_id]): + pub_dns = vm.get("PublicDnsName") + pub_ip = vm.get("PublicIpAddress") + + if pub_ip: + for tag in ssh_tags: + ssh_add_known_host(self.get_ssh_known_hosts_file(vm_id), pub_dns, pub_ip, + tag["Key"], tag["Value"], vm_id) + + priv_dns = vm.get("PrivateDnsName") or None + priv_ip = vm.get("PrivateIpAddress") + if priv_ip: + for tag in ssh_tags: + ssh_add_known_host(self.get_ssh_known_hosts_file(vm_id), priv_dns, priv_ip, + tag["Key"], tag["Value"], vm_id) - def load_gpg_config(self, fn, main_section): - realfn = os.path.join(self.keys_dir, fn) - if not os.path.isfile(realfn): - raise UsageError("GPG file not found: %s" % realfn) - data = self.load_gpg_file(realfn) - cf = Config(main_section, None) - cf.cf.read_string(data, source=realfn) - return cf - - def popen(self, cmd, input_data=None, **kwargs): - """Read command stdout, check for exit code. + def put_known_host_from_tags(self, vm_id): + """Get ssh keys from tags. """ - pipe = subprocess.PIPE - if input_data is not None: - p = subprocess.Popen(cmd, stdin=pipe, stdout=pipe, stderr=pipe, **kwargs) + vm = self.vm_lookup(vm_id) + iplist = [] + if self.cf.getboolean("ssh_internal_ip_works", False): + iplist.append(vm["PrivateIpAddress"]) else: - p = subprocess.Popen(cmd, stdout=pipe, stderr=pipe, **kwargs) - out, err = p.communicate(input_data) - if p.returncode != 0: - raise Exception("command failed: %r - %r" % (cmd, err.strip())) - return out - - def load_command_docs(self): - doc = self.__doc__.strip() - doc = '' - grc = re.compile(r'Group: *(\w+)') - cmds = [] - - for fn in sorted(dir(self)): - if fn.startswith('cmd_'): - fobj = getattr(self, fn) - docstr = (getattr(fobj, '__doc__', '') or '').strip() - mgrp = grc.search(docstr) - grpname = mgrp and mgrp.group(1) or '' - lines = docstr.split('\n') - fdoc = lines[0] - cmd = fn[4:].replace('_', '-') - cmds.append((grpname, cmd, fdoc)) - - for sect in self.cf.sections(): - if sect.startswith('cmd.') or sect.startswith('alias.'): - cmd = sect.split('.', 1)[1] - desc = '' - grpname = '' - if self.cf.cf.has_option(sect, 'desc'): - desc = self.cf.cf.get(sect, 'desc') - if self.cf.cf.has_option(sect, 'group'): - grpname = self.cf.cf.get(sect, 'group') - fdoc = desc.strip().split('\n')[0] - cmds.append((grpname, cmd, desc)) - - cmds.sort() - last_grp = None - sep = '' - for grpname, cmd, fdoc in cmds: - if grpname != last_grp: - doc += sep + '%s commands:\n' % (grpname or 'ungrouped') - last_grp = grpname - sep = '\n' - doc += ' %-30s - %s\n' % (cmd, fdoc) - return doc - - def cmd_help(self): - """Show help about commands. + for iface in vm["NetworkInterfaces"]: + assoc = iface.get("Association") + if assoc: + ip = assoc["PublicIp"] + if ip: + iplist.append(ip) + ip = vm.get("PublicIpAddress") + if ip and ip not in iplist: + iplist.append(ip) - Group: info - """ - doc = self.load_command_docs() - printf(doc) - - def init_argparse(self, parser=None): - if parser is None: - parser = argparse.ArgumentParser(prog='vmtool') - p = super(VmTool, self).init_argparse(parser) - #doc = self.__doc__.strip() - #p.set_usage(doc) - p.add_argument("--env", help="Set environment name (default comes from VMTOOL_ENV_NAME)") - p.add_argument("--role", help="Set role name (default: None)") - p.add_argument("--host", help="Use host instead detecting") - p.add_argument("--all", action="store_true", help="Make command work over all envs") - p.add_argument("--ssh-key", help="Use different SSH key") - p.add_argument("--all-role-vms", action="store_true", help="Run command on all vms for role") - p.add_argument("--all-role-fo-vms", action="store_true", help="Run command on all failover vms for role") - p.add_argument("--earlier-fo-vms", action="store_true", help="Run command on earlier failover vms for role") - p.add_argument("--latest-fo-vm", action="store_true", help="Run command on latest failover vm for rolw") - p.add_argument("--running", action="store_true", help="Show only running instances") - p.add_argument("--az", type=int, help="Set availability zone") - p.add_argument("--tmux", action="store_true", help="Wrap session in tmux") - return p + old_keys = [] + new_keys = [] + for tag in vm.get("Tags", []): + k = tag["Key"] + v = tag["Value"] + if k.startswith("ecdsa-"): + old_keys.append((k, v)) + elif k.startswith("ssh-"): + new_keys.append((k, v)) + + if new_keys: + old_keys = [] + dns = None + for k, v in old_keys + new_keys: + for ip in iplist: + ssh_add_known_host(self.get_ssh_known_hosts_file(vm_id), dns, ip, k, v, vm_id) def get_boto3_session(self, region=None): if not region: @@ -340,17 +205,17 @@ def get_boto3_session(self, region=None): if self._boto_sessions is None: self._boto_sessions = {} if self._boto_sessions.get(region) is None: - profile_name = self.cf.get('aws_profile_name', '') or None - key = self.cf.get('aws_access_key', '') or None - sec = self.cf.get('aws_secret_key', '') or None + profile_name = self.cf.get("aws_profile_name", "") or None + key = self.cf.get("aws_access_key", "") or None + sec = self.cf.get("aws_secret_key", "") or None self._boto_sessions[region] = boto3.session.Session( profile_name=profile_name, region_name=region, aws_access_key_id=key, aws_secret_access_key=sec) return self._boto_sessions[region] def get_boto3_client(self, svc, region=None): - if svc == 'pricing': - region = 'us-east-1' # provided only in 'us-east-1' and 'ap-south-1' + if svc == "pricing": + region = "us-east-1" # provided only in "us-east-1" and "ap-south-1" elif not region: region = self._region if self._boto_clients is None: @@ -359,35 +224,35 @@ def get_boto3_client(self, svc, region=None): scode = (region, svc) if scode not in self._boto_clients: session = self.get_boto3_session(region) - conf = botocore.config.Config(retries = {'mode': 'adaptive', 'max_attempts': 10}) + conf = botocore.config.Config(retries={"mode": "adaptive", "max_attempts": 10}) self._boto_clients[scode] = session.client(svc, config=conf) return self._boto_clients[scode] def get_elb(self, region=None): """Get cached ELB connection. """ - return self.get_boto3_client('elb', region) + return self.get_boto3_client("elb", region) def get_s3(self, region=None): """Get cached S3 connection. """ - return self.get_boto3_client('s3', region) + return self.get_boto3_client("s3", region) def get_ddb(self, region=None): """Get cached DynamoDB connection. """ - return self.get_boto3_client('dynamodb', region) + return self.get_boto3_client("dynamodb", region) def get_route53(self): """Get cached ELB connection. """ - return self.get_boto3_client('route53') + return self.get_boto3_client("route53") def get_ec2_client(self, region=None): - return self.get_boto3_client('ec2', region) + return self.get_boto3_client("ec2", region) def get_pricing_client(self, region=None): - return self.get_boto3_client('pricing', region) + return self.get_boto3_client("pricing", region) def pager(self, client, method, rname): """Create pager function for looping over long results. @@ -402,9 +267,9 @@ def pager(**kwargs): def ec2_iter_instances(self, region=None, **kwargs): client = self.get_ec2_client(region) - pager = self.pager(client, 'describe_instances', 'Reservations') + pager = self.pager(client, "describe_instances", "Reservations") for rv in pager(**kwargs): - for vm in rv['Instances']: + for vm in rv["Instances"]: yield vm def ec2_iter(self, func, result, region=None, **kwargs): @@ -422,44 +287,44 @@ def s3_iter(self, func, result, region=None, **kwargs): def pricing_iter_services(self, **kwargs): """Pricing.Client.describe_services""" client = self.get_pricing_client() - pager = self.pager(client, 'describe_services', 'Services') + pager = self.pager(client, "describe_services", "Services") for rec in pager(**kwargs): yield rec def pricing_iter_products(self, **kwargs): """Pricing.Client.get_products""" client = self.get_pricing_client() - pager = self.pager(client, 'get_products', 'PriceList') + pager = self.pager(client, "get_products", "PriceList") for rec in pager(**kwargs): yield json.loads(rec) def pricing_iter_attribute_values(self, **kwargs): """Pricing.Client.get_attribute_values""" client = self.get_pricing_client() - pager = self.pager(client, 'get_attribute_values', 'AttributeValues') + pager = self.pager(client, "get_attribute_values", "AttributeValues") for rec in pager(**kwargs): yield rec def get_region_desc(self, region): if self._endpoints is None: - self._endpoints = botocore.session.get_session().get_data('endpoints') - if self._endpoints['version'] != 3: - raise Exception("unsupported endpoints version: %d" % self._endpoints['version']) - for part in self._endpoints['partitions']: - if part['partition'] == 'aws': # aws, aws-us-gov, aws-cn - desc = part['regions'][region]['description'] - desc = desc.replace('Europe', 'EU') # botocore vs. us-east-1/pricing bug + self._endpoints = botocore.session.get_session().get_data("endpoints") + if self._endpoints["version"] != 3: + raise Exception("unsupported endpoints version: %d" % self._endpoints["version"]) + for part in self._endpoints["partitions"]: + if part["partition"] == "aws": # aws, aws-us-gov, aws-cn + desc = part["regions"][region]["description"] + desc = desc.replace("Europe", "EU") # botocore vs. us-east-1/pricing bug return desc raise Exception("did not find 'aws' partition") VOL_TYPE_DESC = { - 'standard': 'Magnetic', - 'gp2': 'General Purpose', - 'gp3': 'General Purpose', - 'io1': 'Provisioned IOPS', - 'io2': 'Provisioned IOPS', - 'st1': 'Throughput Optimized HDD', - 'sc1': 'Cold HDD', + "standard": "Magnetic", + "gp2": "General Purpose", + "gp3": "General Purpose", + "io1": "Provisioned IOPS", + "io2": "Provisioned IOPS", + "st1": "Throughput Optimized HDD", + "sc1": "Cold HDD", } VOL_TYPES = tuple(VOL_TYPE_DESC) VOL_ENC_TYPES = tuple("enc-" + x for x in VOL_TYPES) @@ -468,24 +333,24 @@ def get_volume_desc(self, vol_type): return self.VOL_TYPE_DESC[vol_type] STORAGE_FILTER = { - 'STANDARD': {'productFamily': 'Storage', 'volumeType': 'Standard'}, - 'STANDARD_IA': {'productFamily': 'Storage', 'volumeType': 'Standard - Infrequent Access'}, - 'ONEZONE_IA': {'productFamily': 'Storage', 'volumeType': 'One Zone - Infrequent Access'}, - 'GLACIER': {'productFamily': 'Storage', 'volumeType': 'Amazon Glacier'}, + "STANDARD": {"productFamily": "Storage", "volumeType": "Standard"}, + "STANDARD_IA": {"productFamily": "Storage", "volumeType": "Standard - Infrequent Access"}, + "ONEZONE_IA": {"productFamily": "Storage", "volumeType": "One Zone - Infrequent Access"}, + "GLACIER": {"productFamily": "Storage", "volumeType": "Amazon Glacier"}, # deprecated - 'REDUCED_REDUNDANCY': {'productFamily': 'Storage', 'volumeType': 'Reduced Redundancy'}, + "REDUCED_REDUNDANCY": {"productFamily": "Storage", "volumeType": "Reduced Redundancy"}, # buggy pricing data - 'DEEP_ARCHIVE': {'volumeType': 'Glacier Deep Archive'}, - 'INTELLIGENT_TIERING': {'storageClass': 'Intelligent-Tiering'}, + "DEEP_ARCHIVE": {"volumeType": "Glacier Deep Archive"}, + "INTELLIGENT_TIERING": {"storageClass": "Intelligent-Tiering"}, } def get_storage_filter(self, storage_class): """Return filter for pricing query. """ - #storageClass: ['Archive', 'General Purpose', 'Infrequent Access', 'Intelligent-Tiering', 'Non-Critical Data', 'Staging', 'Tags'] - #volumeType: ['Amazon Glacier', 'Glacier Deep Archive', 'Intelligent-Tiering Frequent Access', - # 'Intelligent-Tiering Infrequent Access', 'Intelligent-Tiering', 'One Zone - Infrequent Access', - # 'Reduced Redundancy', 'Standard - Infrequent Access', 'Standard', 'Tags'] + #storageClass: ["Archive", "General Purpose", "Infrequent Access", "Intelligent-Tiering", "Non-Critical Data", "Staging", "Tags"] + #volumeType: ["Amazon Glacier", "Glacier Deep Archive", "Intelligent-Tiering Frequent Access", + # "Intelligent-Tiering Infrequent Access", "Intelligent-Tiering", "One Zone - Infrequent Access", + # "Reduced Redundancy", "Standard - Infrequent Access", "Standard", "Tags"] return self.STORAGE_FILTER[storage_class] def get_cached_pricing(self, **kwargs): @@ -499,26 +364,30 @@ def get_cached_pricing(self, **kwargs): """ filters = [] for k, v in kwargs.items(): - filters.append({'Type': 'TERM_MATCH', 'Field': k, 'Value': v}) + filters.append({"Type": "TERM_MATCH", "Field": k, "Value": v}) cache_key = json.dumps(kwargs, sort_keys=True) if cache_key not in self._pricing_cache: res = [] - for rec in self.pricing_iter_products(FormatVersion='aws_v1', ServiceCode=kwargs.get('ServiceCode'), Filters=filters): + for rec in self.pricing_iter_products( + FormatVersion="aws_v1", + ServiceCode=kwargs.get("ServiceCode"), + Filters=filters + ): res.append(rec) if len(res) != 1: raise UsageError("Broken pricing filter: expect 1 row, got %d, cache_key: %s" % ( - len(res), cache_key) - ) + len(res), cache_key + )) self._pricing_cache[cache_key] = res[0] return self._pricing_cache[cache_key] def get_offer_price(self, offer, unit): - prices = list(offer['priceDimensions'].values()) + prices = list(offer["priceDimensions"].values()) if len(prices) != 1: - raise Exception('prices: expected one value, got %d' % len(prices)) - if prices[0]['unit'] != unit: - raise Exception('prices: expected %s, got %s' % (unit, prices[0]['unit'])) - return float(prices[0]['pricePerUnit']['USD']) + raise Exception("prices: expected one value, got %d" % len(prices)) + if prices[0]["unit"] != unit: + raise Exception("prices: expected %s, got %s" % (unit, prices[0]["unit"])) + return float(prices[0]["pricePerUnit"]["USD"]) def get_vm_pricing(self, region, vmtype): """Return simplified price object for vm cost. @@ -526,78 +395,78 @@ def get_vm_pricing(self, region, vmtype): def loadOnDemand(vmdata): """Return hourly price for ondemand instances.""" - offers = list(vmdata['terms']['OnDemand'].values()) + offers = list(vmdata["terms"]["OnDemand"].values()) if len(offers) != 1: - raise Exception('OnDemand.offers: expected one value, got %d' % len(offers)) - return self.get_offer_price(offers[0], 'Hrs') + raise Exception("OnDemand.offers: expected one value, got %d" % len(offers)) + return self.get_offer_price(offers[0], "Hrs") def loadReserved(vmdata): """Return hourly price for reserved (no-upfront/standard/1yr) instances.""" got = [] - for offer in vmdata['terms']['Reserved'].values(): - atts = offer['termAttributes'] - opt = atts['PurchaseOption'] # No Upfront, All Upfront, Partial Upfront - cls = atts['OfferingClass'] # standard, convertible - lse = atts['LeaseContractLength'] # 1yr, 3yr - if (opt, cls, lse) == ('No Upfront', 'standard', '1yr'): - got.append(self.get_offer_price(offer, 'Hrs')) + for offer in vmdata["terms"]["Reserved"].values(): + atts = offer["termAttributes"] + opt = atts["PurchaseOption"] # No Upfront, All Upfront, Partial Upfront + cls = atts["OfferingClass"] # standard, convertible + lse = atts["LeaseContractLength"] # 1yr, 3yr + if (opt, cls, lse) == ("No Upfront", "standard", "1yr"): + got.append(self.get_offer_price(offer, "Hrs")) if len(got) != 1: - raise Exception('expected one value, got %d' % len(got)) + raise Exception("expected one value, got %d" % len(got)) return got[0] vmdata = self.get_cached_pricing( - ServiceCode='AmazonEC2', - locationType='AWS Region', + ServiceCode="AmazonEC2", + locationType="AWS Region", location=self.get_region_desc(region), - productFamily='Compute Instance', - preInstalledSw='NA', # NA, SQL Ent, SQL Std, SQL Web - operatingSystem='Linux', # NA, Linux, RHEL, SUSE, Windows - tenancy='Shared', # NA, Dedicated, Host, Reserved, Shared - capacitystatus='Used', # NA, Used, AllocatedCapacityReservation, AllocatedHost, UnusedCapacityReservation + productFamily="Compute Instance", + preInstalledSw="NA", # NA, SQL Ent, SQL Std, SQL Web + operatingSystem="Linux", # NA, Linux, RHEL, SUSE, Windows + tenancy="Shared", # NA, Dedicated, Host, Reserved, Shared + capacitystatus="Used", # NA, Used, AllocatedCapacityReservation, AllocatedHost, UnusedCapacityReservation instanceType=vmtype, ) return { - 'onDemandHourly': loadOnDemand(vmdata), - 'reservedHourly': loadReserved(vmdata), + "onDemandHourly": loadOnDemand(vmdata), + "reservedHourly": loadReserved(vmdata), } def get_volume_pricing(self, region, vol_type): """Return numeric price for volume cost. """ p = self.get_cached_pricing( - ServiceCode='AmazonEC2', - locationType='AWS Region', + ServiceCode="AmazonEC2", + locationType="AWS Region", location=self.get_region_desc(region), - productFamily='Storage', + productFamily="Storage", volumeType=self.get_volume_desc(vol_type)) - offers = list(p['terms']['OnDemand'].values()) + offers = list(p["terms"]["OnDemand"].values()) if len(offers) != 1: - raise Exception('expected one value, got %d' % len(offers)) - return self.get_offer_price(offers[0], 'GB-Mo') + raise Exception("expected one value, got %d" % len(offers)) + return self.get_offer_price(offers[0], "GB-Mo") def get_s3_pricing(self, region, storage_class, size): """Return numeric price for volume cost. """ p = self.get_cached_pricing( - ServiceCode='AmazonS3', - locationType='AWS Region', + ServiceCode="AmazonS3", + locationType="AWS Region", location=self.get_region_desc(region), **self.get_storage_filter(storage_class)) - offers = list(p['terms']['OnDemand'].values()) + offers = list(p["terms"]["OnDemand"].values()) if len(offers) != 1: - raise Exception('expected one value, got %d' % len(offers)) + raise Exception("expected one value, got %d" % len(offers)) # S3 prices are in segments total = 0 - for pdim in offers[0]['priceDimensions'].values(): - if pdim['unit'] != 'GB-Mo': - raise Exception('expected GB-Mo, got %s' % pdim['unit']) - beginRange = int(pdim['beginRange']) - if pdim['endRange'] != 'Inf': - endRange = int(pdim['endRange']) + for pdim in offers[0]["priceDimensions"].values(): + if pdim["unit"] != "GB-Mo": + raise Exception("expected GB-Mo, got %s" % pdim["unit"]) + beginRange = int(pdim["beginRange"]) + if pdim["endRange"] != "Inf": + endRange = int(pdim["endRange"]) else: endRange = size @@ -607,39 +476,39 @@ def get_s3_pricing(self, region, storage_class, size): curblk = endRange - beginRange else: curblk = size - beginRange - total += curblk * float(pdim['pricePerUnit']['USD']) + total += curblk * float(pdim["pricePerUnit"]["USD"]) return total def cmd_debug_pricing(self): - svcNames = [svc['ServiceCode'] for svc in self.pricing_iter_services()] - print('ServiceCode=%s' % svcNames) + svcNames = [svc["ServiceCode"] for svc in self.pricing_iter_services()] + print("ServiceCode=%s" % svcNames) - svclist = list(self.pricing_iter_services(ServiceCode='AmazonEC2')) - print('%s=%s' % (svclist[0]['ServiceCode'], json.dumps(svclist[0], indent=2))) + svclist = list(self.pricing_iter_services(ServiceCode="AmazonEC2")) + print("%s=%s" % (svclist[0]["ServiceCode"], json.dumps(svclist[0], indent=2))) for svc in svclist: - code = svc['ServiceCode'] - for att in sorted(svc['AttributeNames']): + code = svc["ServiceCode"] + for att in sorted(svc["AttributeNames"]): vlist = [] cnt = 0 for val in self.pricing_iter_attribute_values(ServiceCode=code, AttributeName=att): if cnt > 100: - vlist.append('...') + vlist.append("...") break - vlist.append(val['Value']) + vlist.append(val["Value"]) cnt += 1 - print('%s.%s = %r' % (code, att, vlist)) + print("%s.%s = %r" % (code, att, vlist)) def route53_iter_rrsets(self, **kwargs): client = self.get_route53() - pager = self.pager(client, 'list_resource_record_sets', 'ResourceRecordSets') + pager = self.pager(client, "list_resource_record_sets", "ResourceRecordSets") return pager(**kwargs) def sgroups_lookup(self, sgs_list): # manual lookup for sgs sg_ids = [] for sg in sgs_list: - if sg.startswith('sg-'): + if sg.startswith("sg-"): sg_ids.append(sg) else: raise UsageError("deprecated non-id sg: %r" % sg) @@ -651,10 +520,10 @@ def show_vm_list(self, vm_list, adrmap=None, dnsmap=None): use_colors = sys.stdout.isatty() if use_colors: - if sys.platform.startswith('win'): + if sys.platform.startswith("win"): use_colors = False - vm_list = sorted(vm_list, key=lambda vm: vm['LaunchTime']) + vm_list = sorted(vm_list, key=lambda vm: vm["LaunchTime"]) extra_verbose = self.options.verbose and self.options.verbose > 1 vol_map = {} @@ -663,34 +532,34 @@ def show_vm_list(self, vm_list, adrmap=None, dnsmap=None): for vm in vm_list: if not self.options.all: - if not self._check_tags(vm.get('Tags')): + if not self._check_tags(vm.get("Tags")): continue eip = "" name = "" extra_lines = [] - if vm.get('InstanceId') in adrmap: - eip += " EIP=%s" % adrmap[vm['InstanceId']] - if vm.get('PrivateIpAddress') in dnsmap: - eip += " IDNS=" + dnsmap[vm['PrivateIpAddress']] - if vm.get('PublicIpAddress') in dnsmap: - eip += " PDNS=" + dnsmap[vm['PublicIpAddress']] - if len(vm['NetworkInterfaces']) > 1: - for iface in vm['NetworkInterfaces']: - att = iface['Attachment'] - sep = '' - eni = 'net#%s - %s - IP=' % (att['DeviceIndex'], att['Status']) - for adr in iface['PrivateIpAddresses']: - eni += sep + adr['PrivateIpAddress'] - sep = ',' - if adr.get('Association'): - eni += ' (%s)' % (adr['Association']['PublicIp']) - if iface['Attachment']['DeleteOnTermination']: - eni += ' del=yes' + if vm.get("InstanceId") in adrmap: + eip += " EIP=%s" % adrmap[vm["InstanceId"]] + if vm.get("PrivateIpAddress") in dnsmap: + eip += " IDNS=" + dnsmap[vm["PrivateIpAddress"]] + if vm.get("PublicIpAddress") in dnsmap: + eip += " PDNS=" + dnsmap[vm["PublicIpAddress"]] + if len(vm["NetworkInterfaces"]) > 1: + for iface in vm["NetworkInterfaces"]: + att = iface["Attachment"] + sep = "" + eni = "net#%s - %s - IP=" % (att["DeviceIndex"], att["Status"]) + for adr in iface["PrivateIpAddresses"]: + eni += sep + adr["PrivateIpAddress"] + sep = "," + if adr.get("Association"): + eni += " (%s)" % (adr["Association"]["PublicIp"]) + if iface["Attachment"]["DeleteOnTermination"]: + eni += " del=yes" else: - eni += ' del=no' - eni += ' ENI=' + iface['NetworkInterfaceId'] - if iface.get('Description'): - eni += ' desc=' + iface.get('Description') + eni += " del=no" + eni += " ENI=" + iface["NetworkInterfaceId"] + if iface.get("Description"): + eni += " desc=" + iface.get("Description") extra_lines.append(eni) # add colors @@ -698,272 +567,114 @@ def show_vm_list(self, vm_list, adrmap=None, dnsmap=None): c2 = "" if use_colors: if eip: - if vm['State']['Name'] == 'running': + if vm["State"]["Name"] == "running": c1 = "\033[32m" # green else: c1 = "\033[35m" # light purple - elif vm['State']['Name'] == 'running': + elif vm["State"]["Name"] == "running": c1 = "\033[31m" # red # close color if c1: c2 = "\033[0m" - vm_env = '-' - vm_role = '' - for tag in vm.get('Tags', []): - if tag['Key'] == 'Env': - vm_env = tag['Value'] - elif tag['Key'] == 'Role': - vm_role = tag['Value'] + vm_env = "-" + vm_role = "" + for tag in vm.get("Tags", []): + if tag["Key"] == "Env": + vm_env = tag["Value"] + elif tag["Key"] == "Role": + vm_role = tag["Value"] name += " Env=" + vm_env if vm_role: - name += '.' + vm_role + name += "." + vm_role - name += ' type=%s' % vm['InstanceType'] + name += " type=%s" % vm["InstanceType"] tags = "" - for tagname in ['Date', 'Commit', 'PYI', 'DBI', 'JSI', 'SYS']: - for tag in vm.get('Tags', []): - if tag['Key'] == tagname: - tags += ' %s=%s' % (tagname, tag['Value']) + for tagname in ["Date", "Commit", "PYI", "DBI", "JSI", "SYS"]: + for tag in vm.get("Tags", []): + if tag["Key"] == tagname: + tags += " %s=%s" % (tagname, tag["Value"]) - az = vm['Placement']['AvailabilityZone'] + az = vm["Placement"]["AvailabilityZone"] if az[-2].isdigit() and az[-1].islower(): - name += ' AZ=%d' % (ord(az[-1]) - ord('a')) + name += " AZ=%d" % (ord(az[-1]) - ord("a")) else: - name += ' AZ=%s' % az + name += " AZ=%s" % az - int_ip = '' - if vm.get('PrivateIpAddress'): - int_ip = ' ip=%s' % vm['PrivateIpAddress'] + int_ip = "" + if vm.get("PrivateIpAddress"): + int_ip = " ip=%s" % vm["PrivateIpAddress"] # one-line output - printf("%s [%s%s%s]%s%s%s%s", vm['InstanceId'], c1, vm['State']['Name'], c2, name, tags, int_ip, eip) + printf("%s [%s%s%s]%s%s%s%s", vm["InstanceId"], c1, vm["State"]["Name"], c2, name, tags, int_ip, eip) if self.options.verbose and extra_lines: for xln in extra_lines: - printf(' %s', xln) - printf('') + printf(" %s", xln) + printf("") if not extra_verbose: continue # verbose output - printf(' LaunchTime: %s', vm['LaunchTime']) - if vm.get('RootDeviceName'): - printf(' RootDevice: %s - %s', vm['RootDeviceType'], vm['RootDeviceName']) - if vm.get('IamInstanceProfile'): - printf(' IamInstanceProfile: %s', vm['IamInstanceProfile']['Arn']) - printf(' Zone=%s', vm['Placement']['AvailabilityZone']) - if vm.get('PublicIpAddress'): - printf(" PublicIpAddress: %s / %s", vm['PublicIpAddress'], (vm.get('PublicDnsName') or '-')) - if vm.get('PrivateIpAddress'): - printf(" PrivateIpAddress: %s / %s", vm['PrivateIpAddress'], (vm.get('PrivateDnsName') or '-')) - printf(" Groups: %s", ', '.join([g['GroupName'] for g in vm['SecurityGroups']])) - for iface in vm.get('NetworkInterfaces', []): - printf(' NetworkInterface id=%s', iface.get('NetworkInterfaceId')) - printf(' Association=%s', iface.get('Association')) - printf(' PrivateIpAddresses=%s', iface.get('PrivateIpAddresses')) - for bdev in vm.get('BlockDeviceMappings', []): - ebs = bdev.get('Ebs') + printf(" LaunchTime: %s", vm["LaunchTime"]) + if vm.get("RootDeviceName"): + printf(" RootDevice: %s - %s", vm["RootDeviceType"], vm["RootDeviceName"]) + if vm.get("IamInstanceProfile"): + printf(" IamInstanceProfile: %s", vm["IamInstanceProfile"]["Arn"]) + printf(" Zone=%s", vm["Placement"]["AvailabilityZone"]) + if vm.get("PublicIpAddress"): + printf(" PublicIpAddress: %s / %s", vm["PublicIpAddress"], (vm.get("PublicDnsName") or "-")) + if vm.get("PrivateIpAddress"): + printf(" PrivateIpAddress: %s / %s", vm["PrivateIpAddress"], (vm.get("PrivateDnsName") or "-")) + printf(" Groups: %s", ", ".join([g["GroupName"] for g in vm["SecurityGroups"]])) + for iface in vm.get("NetworkInterfaces", []): + printf(" NetworkInterface id=%s", iface.get("NetworkInterfaceId")) + printf(" Association=%s", iface.get("Association")) + printf(" PrivateIpAddresses=%s", iface.get("PrivateIpAddresses")) + for bdev in vm.get("BlockDeviceMappings", []): + ebs = bdev.get("Ebs") if ebs: - vol = vol_map[ebs['VolumeId']] - printf(' BlockDeviceMapping name=%s size=%d type=%s vol=%s', - bdev.get('DeviceName'), - vol['Size'], - vol['VolumeType'], - ebs['VolumeId']) + vol = vol_map[ebs["VolumeId"]] + printf(" BlockDeviceMapping name=%s size=%d type=%s vol=%s", + bdev.get("DeviceName"), + vol["Size"], + vol["VolumeType"], + ebs["VolumeId"]) #print_json(vol) else: - printf(' BlockDeviceMapping name=%s', bdev.get('DeviceName')) + printf(" BlockDeviceMapping name=%s", bdev.get("DeviceName")) print_json(bdev) printf(" Tags:") - for tag in sorted(vm.get('Tags', []), key=lambda tag: tag['Key']): - printf(' %s=%s', tag['Key'], tag['Value']) - for k in ('State', 'StateReason', 'StateTransitionReason'): + for tag in sorted(vm.get("Tags", []), key=lambda tag: tag["Key"]): + printf(" %s=%s", tag["Key"], tag["Value"]) + for k in ("State", "StateReason", "StateTransitionReason"): if vm.get(k): - printf(' %s: %s', k, vm[k]) + printf(" %s: %s", k, vm[k]) if self.options.verbose > 2: print_json(vm) - printf('') - + printf("") def get_volume_map(self, vm_list): vmap = {} vols = set() for vm in vm_list: if not self.options.all: - if not self._check_tags(vm.get('Tags')): + if not self._check_tags(vm.get("Tags")): continue - for bdev in vm.get('BlockDeviceMappings'): - ebs = bdev.get('Ebs') + for bdev in vm.get("BlockDeviceMappings"): + ebs = bdev.get("Ebs") if ebs: - vols.add(ebs['VolumeId']) + vols.add(ebs["VolumeId"]) #printf("get_volume_map: %r", vols) - for vol in self.ec2_iter('describe_volumes', 'Volumes', VolumeIds=list(vols)): - vmap[vol['VolumeId']] = vol + for vol in self.ec2_iter("describe_volumes", "Volumes", VolumeIds=list(vols)): + vmap[vol["VolumeId"]] = vol return vmap - def get_ssh_kfile(self): - # load encrypted key - if self.options.ssh_key: - gpg_fn = self.options.ssh_key - else: - gpg_fn = self.cf.get('ssh_privkey_file') - gpg_fn = os.path.join(self.keys_dir, gpg_fn) - kdata = self.load_gpg_file(gpg_fn).strip() - - raw_fn = os.path.basename(gpg_fn).replace('.gpg', '') - - fn = os.path.join(self.ssh_dir, raw_fn) - - # check existing key - if os.path.isfile(fn): - curdata = open(fn, 'r').read().strip() - if curdata == kdata: - return fn - os.remove(fn) - - printf("Extracting keyfile %s to %s", gpg_fn, fn) - fd = os.open(fn, os.O_CREAT | os.O_WRONLY, stat.S_IRUSR | stat.S_IWUSR) - with os.fdopen(fd, "w") as f: - f.write(kdata + "\n") - return fn - - def get_ssh_known_hosts_file(self, vm_id): - return self.ssh_known_hosts + '_' + vm_id - - def ssh_cmdline(self, vm_id, use_admin=False): - if self.cf.getboolean('ssh_admin_user_disabled', False): - ssh_user = self.cf.get('user') - elif use_admin: - ssh_user = self.cf.get('ssh_admin_user') - else: - ssh_user = self.cf.get('user') - - ssh_debug = '-q' - if self.options.verbose: - ssh_debug = '-v' - - ssh_options = shlex.split(self.cf.get('ssh_options', '')) - - return ['ssh', ssh_debug, '-i', self.get_ssh_kfile(), '-l', ssh_user, - '-o', 'UserKnownHostsFile=' + self.get_ssh_known_hosts_file(vm_id)] + ssh_options - - def vm_exec_tmux(self, vm_id, cmdline, use_admin=False, title=None): - if self.options.tmux: - tmux_command = shlex.split(self.cf.get('tmux_command')) - if title: - tmux_command = [a.replace('{title}', title) for a in tmux_command] - cmdline = tmux_command + cmdline - self.vm_exec(vm_id, cmdline, use_admin=use_admin) - - def vm_exec(self, vm_id, cmdline, stdin=None, get_output=False, check_error=True, use_admin=False): - logging.debug("EXEC@%s: %s", vm_id, cmdline) - self.put_known_host_from_tags(vm_id) - - # only image default user works? - if not self.cf.getboolean('ssh_user_access_works', False): - use_admin = True - - ssh = self.ssh_cmdline(vm_id, use_admin=use_admin) - - if not stdin and not get_output and sys.stdout.isatty(): # pylint:disable=no-member - ssh.append('-t') - - if self.options.host: - # use host directly, dangerous - hostname = self.options.host - elif self.cf.getboolean('ssh_internal_ip_works', False): - vm = self.vm_lookup(vm_id) - hostname = vm.get('PrivateIpAddress') - else: - # FIXME: vm with ENI - vm = self.vm_lookup(vm_id) - #hostname = vm.get('PublicDnsName') - hostname = vm.get('PublicIpAddress') - last_idx = 600 * 1024 * 1024 * 1024 - if len(vm['NetworkInterfaces']) > 1: - for iface in vm['NetworkInterfaces']: - #print_json(iface) - idx = iface['Attachment']['DeviceIndex'] - if 1 or idx < last_idx: - assoc = iface.get('Association') - if assoc: - hostname = assoc['PublicIp'] - last_idx = idx - break - eprintf("SSH to %s", hostname) - if not hostname: - logging.error("Public DNS nor ip not yet available for node %r", vm_id) - #print_json(vm) - sys.exit(1) - - ssh.append(hostname) - if isinstance(cmdline, str): - ssh += [cmdline] - else: - logging.debug('EXEC: rsh_quote=%r', cmdline) - ssh += rsh_quote(cmdline) - out = None - kwargs = {} - if stdin is not None: - kwargs['stdin'] = subprocess.PIPE - if get_output: - kwargs['stdout'] = subprocess.PIPE - logging.debug('EXEC: cmd=%r', ssh) - logging.debug('EXEC: kwargs=%r', kwargs) - if kwargs: - p = subprocess.Popen(ssh, **kwargs) - out, err = p.communicate(stdin) - ret = p.returncode - else: - ret = subprocess.call(ssh) - if ret != 0: - if check_error: - raise UsageError("Errorcode: %r" % ret) - return None - return out - - def vm_rsync(self, *args, use_admin=False): - primary_id = None - nargs = [] - ids = [] - for a in args: - t = a.split(':', 1) - if len(t) == 1: - nargs.append(a) - continue - if t[0]: - vm_id = t[0] - elif primary_id: - vm_id = primary_id - else: - vm_id = primary_id = self.get_primary_vms()[0] - vm = self.vm_lookup(vm_id) - self.put_known_host_from_tags(vm_id) - vm = self.vm_lookup(vm_id) - if self.cf.getboolean('ssh_internal_ip_works', False): - hostname = vm.get('PrivateIpAddress') - else: - hostname = vm.get('PublicIpAddress') - a = "%s:%s" % (hostname, t[1]) - nargs.append(a) - ids.append(vm_id) - - ssh_list = self.ssh_cmdline(vm_id, use_admin=use_admin) - ssh_cmd = ' '.join(rsh_quote(ssh_list)) - - cmd = ['rsync', '-rtz', '-e', ssh_cmd] - if self.options.verbose: - cmd.append('-P') - cmd += nargs - self.log.debug("rsync: %r", cmd) - run_successfully(cmd) - def vm_lookup(self, vm_id, ignore_env=False, cache=True): if self._vm_map is None: self._vm_map = {} @@ -971,102 +682,15 @@ def vm_lookup(self, vm_id, ignore_env=False, cache=True): return self._vm_map[vm_id] for vm in self.ec2_iter_instances(InstanceIds=[vm_id]): - if vm['State']['Name'] != 'running': - raise UsageError("VM not running: %s / %r" % (vm_id, vm['State'])) + if vm["State"]["Name"] != "running": + raise UsageError("VM not running: %s / %r" % (vm_id, vm["State"])) if not ignore_env: - if not self._check_tags(vm.get('Tags')): + if not self._check_tags(vm.get("Tags")): continue self._vm_map[vm_id] = vm return vm raise UsageError("VM not found: %s" % vm_id) - def new_ssh_key(self, vm_id): - """Fetch output, parse keys. - """ - time_printf("Waiting for image copy, boot and SSH host key generation") - client = self.get_ec2_client() - keys = None - time.sleep(30) - retry = 0 - for i in range(100): - time.sleep(30) - # load console buffer from EC2 - for retry in range(3): - try: - cres = client.get_console_output(InstanceId=vm_id) - break - except socket.error as ex: - if ex.errno != errno.ETIMEDOUT: - raise - out = cres.get('Output') - if not out: - continue - keys = parse_console(out, ['ssh-ed25519']) - if keys is not None: - break - if not keys: - raise UsageError("Failed to get SSH keys") - - # set ssh key as tag - ssh_tags = [] - for n, kval in enumerate(keys): - ktype = kval[0] - kcert = kval[1] - tag = {'Key': ktype, 'Value': kcert} - ssh_tags.append(tag) - client.create_tags(Resources=[vm_id], Tags=ssh_tags) - - for vm in self.ec2_iter_instances(InstanceIds=[vm_id]): - pub_dns = vm.get('PublicDnsName') - pub_ip = vm.get('PublicIpAddress') - - if pub_ip: - for tag in ssh_tags: - ssh_add_known_host(self.get_ssh_known_hosts_file(vm_id), pub_dns, pub_ip, - tag['Key'], tag['Value'], vm_id) - - priv_dns = vm.get('PrivateDnsName') or None - priv_ip = vm.get('PrivateIpAddress') - if priv_ip: - for tag in ssh_tags: - ssh_add_known_host(self.get_ssh_known_hosts_file(vm_id), priv_dns, priv_ip, - tag['Key'], tag['Value'], vm_id) - - def put_known_host_from_tags(self, vm_id): - """Get ssh keys from tags. - """ - vm = self.vm_lookup(vm_id) - iplist = [] - if self.cf.getboolean('ssh_internal_ip_works', False): - iplist.append(vm['PrivateIpAddress']) - else: - for iface in vm['NetworkInterfaces']: - assoc = iface.get('Association') - if assoc: - ip = assoc['PublicIp'] - if ip: - iplist.append(ip) - ip = vm.get('PublicIpAddress') - if ip and ip not in iplist: - iplist.append(ip) - - old_keys = [] - new_keys = [] - for tag in vm.get('Tags', []): - k = tag['Key'] - v = tag['Value'] - if k.startswith('ecdsa-'): - old_keys.append( (k, v) ) - elif k.startswith('ssh-'): - new_keys.append( (k, v) ) - - if new_keys: - old_keys = [] - dns = None - for k, v in old_keys + new_keys: - for ip in iplist: - ssh_add_known_host(self.get_ssh_known_hosts_file(vm_id), dns, ip, k, v, vm_id) - def get_env_filters(self): """Return default filters based on command-line swithces. """ @@ -1078,16 +702,15 @@ def make_env_filters(self, role_name=None, running=True, allenvs=False): filters = [] if not allenvs: - filters.append({'Name': 'tag:Env', 'Values': [self.env_name]}) + filters.append({"Name": "tag:Env", "Values": [self.env_name]}) if role_name or self.role_name: - filters.append({'Name': 'tag:Role', 'Values': [role_name or self.role_name]}) + filters.append({"Name": "tag:Role", "Values": [role_name or self.role_name]}) if running: - filters.append({'Name': 'instance-state-name', 'Values': ['running']}) + filters.append({"Name": "instance-state-name", "Values": ["running"]}) return filters - def get_running_vms(self, role_name=None): vmlist = [] @@ -1096,76 +719,76 @@ def get_running_vms(self, role_name=None): filters = self.make_env_filters(role_name=role_name, running=True) for vm in self.ec2_iter_instances(Filters=filters): - if not self._check_tags(vm.get('Tags'), force_role=True, role_name=role_name): + if not self._check_tags(vm.get("Tags"), force_role=True, role_name=role_name): continue - if vm['State']['Name'] == 'running': + if vm["State"]["Name"] == "running": vmlist.append(vm) return vmlist def get_dead_primary(self): ec2 = self.get_ec2_client() - eip = self.cf.get('domain_eip', '') + eip = self.cf.get("domain_eip", "") main_vms = [] if eip: ipfilter = { - 'Name': 'public-ip', - 'Values': [eip] + "Name": "public-ip", + "Values": [eip] } res = ec2.describe_addresses(Filters=[ipfilter]) - for addr in res['Addresses']: - if not addr.get('InstanceId'): + for addr in res["Addresses"]: + if not addr.get("InstanceId"): continue - if addr['PublicIp'] == eip: - main_vms.append(addr['InstanceId']) + if addr["PublicIp"] == eip: + main_vms.append(addr["InstanceId"]) break if main_vms: for vm in self.ec2_iter_instances(Filters=self.get_env_filters(), InstanceIds=main_vms): - if not self._check_tags(vm.get('Tags'), True): + if not self._check_tags(vm.get("Tags"), True): continue - if vm['State']['Name'] != 'running': - eprintf("Dead Primary VM for %s is %s", self.full_role, ','.join(main_vms)) + if vm["State"]["Name"] != "running": + eprintf("Dead Primary VM for %s is %s", self.full_role, ",".join(main_vms)) return main_vms else: - raise UsageError('Primary VM still running') + raise UsageError("Primary VM still running") raise UsageError("Primary VM not found based on EIP") dnsmap = self.get_dns_map() for vm in self.ec2_iter_instances(Filters=self.get_env_filters()): - if not self._check_tags(vm.get('Tags'), True): + if not self._check_tags(vm.get("Tags"), True): continue - if vm.get('PrivateIpAddress') in dnsmap: + if vm.get("PrivateIpAddress") in dnsmap: pass - elif vm.get('PublicIpAddress') in dnsmap: + elif vm.get("PublicIpAddress") in dnsmap: pass else: continue - if vm['State']['Name'] == 'running': - raise UsageError('Primary VM still running') - main_vms.append(vm['InstanceId']) + if vm["State"]["Name"] == "running": + raise UsageError("Primary VM still running") + main_vms.append(vm["InstanceId"]) if not main_vms: raise UsageError("Dead Primary VM not found") - eprintf("Dead Primary VM for %s is %s", self.full_role, ','.join(main_vms)) + eprintf("Dead Primary VM for %s is %s", self.full_role, ",".join(main_vms)) return main_vms def get_primary_for_role(self, role_name, instance_id=None): filters = self.make_env_filters(role_name=role_name, running=True) dns_map = self.get_dns_map(True) for vm in self.ec2_iter_instances(Filters=filters): - if not self._check_tags(vm.get('Tags'), role_name=role_name, force_role=True): + if not self._check_tags(vm.get("Tags"), role_name=role_name, force_role=True): continue - if vm['State']['Name'] != 'running': + if vm["State"]["Name"] != "running": continue # ignore IP checks if instance_id is manually provided if instance_id is not None: - if vm['InstanceId'] == instance_id: + if vm["InstanceId"] == instance_id: return vm - elif vm.get('PrivateIpAddress') in dns_map: + elif vm.get("PrivateIpAddress") in dns_map: return vm - #elif vm.get('PublicIpAddress') in dns_map: + #elif vm.get("PublicIpAddress") in dns_map: # return vm raise UsageError("Primary VM not found: %s" % role_name) @@ -1178,44 +801,44 @@ def get_primary_vms(self): main_vms = self._get_primary_vms() if main_vms: - eprintf("Primary VM for %s is %s", self.full_role, ','.join(main_vms)) + eprintf("Primary VM for %s is %s", self.full_role, ",".join(main_vms)) return main_vms raise UsageError("Primary VM not found") def _get_primary_vms(self): ec2 = self.get_ec2_client() - eip = self.cf.get('domain_eip', '') + eip = self.cf.get("domain_eip", "") main_vms = [] if eip: ipfilter = { - 'Name': 'public-ip', - 'Values': [eip] + "Name": "public-ip", + "Values": [eip] } res = ec2.describe_addresses(Filters=[ipfilter]) - for addr in res['Addresses']: - if not addr.get('InstanceId'): + for addr in res["Addresses"]: + if not addr.get("InstanceId"): continue - if addr['PublicIp'] == eip: - main_vms.append(addr['InstanceId']) + if addr["PublicIp"] == eip: + main_vms.append(addr["InstanceId"]) break return main_vms - internal_hostname = self.cf.get('internal_hostname') + internal_hostname = self.cf.get("internal_hostname") dnsmap = self.get_dns_map() for vm in self.ec2_iter_instances(Filters=self.get_env_filters()): - if not self._check_tags(vm.get('Tags'), True): + if not self._check_tags(vm.get("Tags"), True): continue - if vm['State']['Name'] != 'running': + if vm["State"]["Name"] != "running": continue - if vm.get('PrivateIpAddress') in dnsmap: + if vm.get("PrivateIpAddress") in dnsmap: if internal_hostname: - dns_name = dnsmap[vm['PrivateIpAddress']].rstrip(".") + dns_name = dnsmap[vm["PrivateIpAddress"]].rstrip(".") if dns_name != internal_hostname: continue - main_vms.append(vm['InstanceId']) - elif vm.get('PublicIpAddress') in dnsmap: - main_vms.append(vm['InstanceId']) + main_vms.append(vm["InstanceId"]) + elif vm.get("PublicIpAddress") in dnsmap: + main_vms.append(vm["InstanceId"]) return main_vms def get_all_role_vms(self): @@ -1226,20 +849,20 @@ def get_all_role_vms(self): all_vms = [] for vm in self.ec2_iter_instances(Filters=self.get_env_filters()): - if not self._check_tags(vm.get('Tags'), True): + if not self._check_tags(vm.get("Tags"), True): continue - if vm['State']['Name'] != 'running': + if vm["State"]["Name"] != "running": continue # prepend primary vms - if vm['InstanceId'] in main_vms: - all_vms.insert(0, vm['InstanceId']) + if vm["InstanceId"] in main_vms: + all_vms.insert(0, vm["InstanceId"]) else: - all_vms.append(vm['InstanceId']) + all_vms.append(vm["InstanceId"]) if not all_vms: eprintf("No running VMs for %s", self.full_role) else: - eprintf("Running VMs for %s: %s", self.full_role, ' '.join(all_vms)) + eprintf("Running VMs for %s: %s", self.full_role, " ".join(all_vms)) return all_vms def get_all_role_fo_vms(self): @@ -1250,30 +873,30 @@ def get_all_role_fo_vms(self): all_vms = [] for vm in self.ec2_iter_instances(Filters=self.get_env_filters()): - if not self._check_tags(vm.get('Tags'), True): + if not self._check_tags(vm.get("Tags"), True): continue - if vm['State']['Name'] != 'running': + if vm["State"]["Name"] != "running": continue # skip primary vms - if vm['InstanceId'] in main_vms: + if vm["InstanceId"] in main_vms: pass else: all_vms.append(vm) - all_vms = [it['InstanceId'] for it in sorted(all_vms, key=lambda it: it['LaunchTime'])] + all_vms = [it["InstanceId"] for it in sorted(all_vms, key=lambda it: it["LaunchTime"])] if not all_vms: eprintf("No running failover VMs for %s", self.full_role) elif self.options.earlier_fo_vms: - if len(all_vms) == 1: + if len(all_vms) == 1: all_vms = [] else: all_vms = all_vms[:-1] - eprintf("No running earlier failover VMs for %s: %s", self.full_role, ' '.join(all_vms)) + eprintf("No running earlier failover VMs for %s: %s", self.full_role, " ".join(all_vms)) elif self.options.latest_fo_vm: all_vms = all_vms[-1:] - eprintf("No running latest failover VM for %s: %s", self.full_role, ' '.join(all_vms)) + eprintf("No running latest failover VM for %s: %s", self.full_role, " ".join(all_vms)) else: - eprintf("Running failover VMs for %s: %s", self.full_role, ' '.join(all_vms)) + eprintf("Running failover VMs for %s: %s", self.full_role, " ".join(all_vms)) return all_vms @@ -1285,13 +908,13 @@ def _check_tags(self, taglist, force_role=False, role_name=None): gotenv = gotrole = False for tag in taglist: - if tag['Key'] == 'Env': + if tag["Key"] == "Env": gotenv = True - if tag['Value'] != self.env_name: + if tag["Value"] != self.env_name: return False - if tag['Key'] == 'Role': + if tag["Key"] == "Role": gotrole = True - if role_name and tag['Value'] != role_name: + if role_name and tag["Value"] != role_name: return False if not gotenv: return False @@ -1306,7 +929,7 @@ def get_vm_args(self, args, allow_multi=False): returns: (vm-id, args) """ - if args and args[0][:2] == 'i-': + if args and args[0][:2] == "i-": vm_list = [args[0]] args = args[1:] else: @@ -1328,16 +951,16 @@ def cmd_show_vms(self, *cmdargs): adrmap = {} res = client.describe_addresses() - for adr in res['Addresses']: - if adr.get('InstanceId'): - adrmap[adr['InstanceId']] = adr['PublicIp'] + for adr in res["Addresses"]: + if adr.get("InstanceId"): + adrmap[adr["InstanceId"]] = adr["PublicIp"] dnsmap = self.get_dns_map(True) args = {} - args['Filters'] = self.get_env_filters() + args["Filters"] = self.get_env_filters() if cmdargs: - args['InstanceIds'] = cmdargs + args["InstanceIds"] = cmdargs vm_list = [] for vm in self.ec2_iter_instances(**args): @@ -1368,14 +991,17 @@ def cmd_show_reserved(self, *cmdargs): """ client = self.get_ec2_client() response = client.describe_reserved_instances() - wres = response['ReservedInstances'] + wres = response["ReservedInstances"] for rvm in wres: - tstart = rvm['Start'].isoformat()[:10] - tend = rvm['End'].isoformat()[:10] - plist = ','.join(['{Amount}/{Frequency}'.format(**p) for p in rvm['RecurringCharges']]) + tstart = rvm["Start"].isoformat()[:10] + tend = rvm["End"].isoformat()[:10] + plist = ",".join(["{Amount}/{Frequency}".format(**p) for p in rvm["RecurringCharges"]]) printf("{ReservedInstancesId} type={InstanceType} count={InstanceCount} state={State}".format(**rvm)) - printf(" offering: class={OfferingClass} payment=[{OfferingType}] os=[{ProductDescription}] scope={Scope}".format(**rvm)) + printf( + " offering: class={OfferingClass} payment=[{OfferingType}] os=[{ProductDescription}] scope={Scope}" + .format(**rvm) + ) printf(" Price: fixed={FixedPrice} usage={UsagePrice} recur=".format(**rvm) + plist) printf(" Dur: start=%s end=%s", tstart, tend) @@ -1387,39 +1013,39 @@ def show_vmcost(self, region, vmtype, nActive, nReserved, names): if nActive > nReserved: odCount = nActive - nReserved price = self.get_vm_pricing(region, vmtype) - rawMonth = int(nActive * price['onDemandHourly'] * 24 * 30) - odMonth = int(odCount * price['onDemandHourly'] * 24 * 30) - rMonth = int(nReserved * price['reservedHourly'] * 24 * 30) - odPrice = '($%d/m)' % odMonth - rPrice = '($%d/m)' % rMonth - - odStr = '' - resStr = '' + rawMonth = int(nActive * price["onDemandHourly"] * 24 * 30) + odMonth = int(odCount * price["onDemandHourly"] * 24 * 30) + rMonth = int(nReserved * price["reservedHourly"] * 24 * 30) + odPrice = "($%d/m)" % odMonth + rPrice = "($%d/m)" % rMonth + + odStr = "" + resStr = "" if odCount: - odStr = 'ondemand: %2d %-9s' % (odCount, odPrice) + odStr = "ondemand: %2d %-9s" % (odCount, odPrice) if nReserved: - resStr = 'reserved: %d %s' % (nReserved, rPrice) - nfirst = '' + resStr = "reserved: %d %s" % (nReserved, rPrice) + nfirst = "" if names: - nfirst = '[%s]' % ', '.join(names[:nstep]) + nfirst = "[%s]" % ", ".join(names[:nstep]) names = names[nstep:] printf(" %-12s: running: %2d %-23s %-23s%s", vmtype, nActive, odStr, resStr, nfirst) while names: - printf("%76s[%s]", ' ', ', '.join(names[:nstep])) + printf("%76s[%s]", " ", ", ".join(names[:nstep])) names = names[nstep:] return rawMonth, odMonth + rMonth def load_vmenv(self, vm): env = None role = None - for tag in vm.get('Tags', []): - if tag['Key'] == 'Env': - env = tag['Value'] - elif tag['Key'] == 'Role': - role = tag['Value'] + for tag in vm.get("Tags", []): + if tag["Key"] == "Env": + env = tag["Value"] + elif tag["Key"] == "Role": + role = tag["Value"] if env: if role: - return env + '.' + role + return env + "." + role return env return None @@ -1428,7 +1054,7 @@ def cmd_show_vmcost(self): Group: pricing """ - all_regions = self.cf.getlist('all_regions') + all_regions = self.cf.getlist("all_regions") rawTotal = 0 total = 0 for region in all_regions: @@ -1438,17 +1064,17 @@ def cmd_show_vmcost(self): client = self.get_ec2_client(region) # scan reserved instances - for rvm in client.describe_reserved_instances()['ReservedInstances']: - if rvm['State'] == 'active': - vm_type = rvm['InstanceType'] + for rvm in client.describe_reserved_instances()["ReservedInstances"]: + if rvm["State"] == "active": + vm_type = rvm["InstanceType"] if vm_type not in rmap: rmap[vm_type] = 0 - rmap[vm_type] += rvm['InstanceCount'] + rmap[vm_type] += rvm["InstanceCount"] # scan running instances - flist = [{'Name': 'instance-state-name', 'Values': ['running']}] + flist = [{"Name": "instance-state-name", "Values": ["running"]}] for vm in self.ec2_iter_instances(region=region, Filters=flist): - vm_type = vm['InstanceType'] + vm_type = vm["InstanceType"] if vm_type not in tmap: tmap[vm_type] = 0 tmap[vm_type] += 1 @@ -1461,7 +1087,7 @@ def cmd_show_vmcost(self): if not tmap and not rmap: continue - printf('-- %s --', region) + printf("-- %s --", region) for vm_type in sorted(tmap): names = list(sorted(envmap[vm_type])) rawSum, curSum = self.show_vmcost(region, vm_type, tmap[vm_type], rmap.get(vm_type, 0), names) @@ -1472,7 +1098,7 @@ def cmd_show_vmcost(self): rawSum, curSum = self.show_vmcost(region, vm_type, 0, rmap[vm_type], []) rawTotal += rawSum total += curSum - printf('total: $%d/m reserved bonus: $%d/m', total, rawTotal - total) + printf("total: $%d/m reserved bonus: $%d/m", total, rawTotal - total) def cmd_show_ebscost(self): """Show disk cost. @@ -1481,36 +1107,34 @@ def cmd_show_ebscost(self): """ def addVol(info, vol): - vtype = vol['VolumeType'] + vtype = vol["VolumeType"] if vtype not in info: info[vtype] = 0 - info[vtype] += vol['Size'] + info[vtype] += vol["Size"] def show(name, info, region): parts = [] for t in sorted(info): - s = '%s=%d' % (t, info[t]) - if not t.startswith('vm-'): + s = "%s=%d" % (t, info[t]) + if not t.startswith("vm-"): p = self.get_volume_pricing(region, t) * info[t] - if p < 1: - p = 1 - s += ' ($%d/m)' % int(p) + s += " ($%d/m)" % int(max(p, 1)) parts.append(s) if not parts: - parts = ['-'] - printf('%-20s %s', name+':', ', '.join(parts)) + parts = ["-"] + printf("%-20s %s", name + ":", ", ".join(parts)) - all_regions = self.cf.getlist('all_regions') + all_regions = self.cf.getlist("all_regions") for region in all_regions: - printf('-- %s --', region) + printf("-- %s --", region) envmap = {} vol_map = {} totals = {} gotVol = set() - for vol in self.ec2_iter('describe_volumes', 'Volumes', region=region): - vol_map[vol['VolumeId']] = vol + for vol in self.ec2_iter("describe_volumes", "Volumes", region=region): + vol_map[vol["VolumeId"]] = vol for vm in self.ec2_iter_instances(region=region, Filters=[]): rname = self.load_vmenv(vm) @@ -1518,27 +1142,27 @@ def show(name, info, region): envmap[rname] = {} rinfo = envmap[rname] - sname = 'vm-' + vm['State']['Name'] + sname = "vm-" + vm["State"]["Name"] if sname not in rinfo: rinfo[sname] = 0 rinfo[sname] += 1 - for bdev in vm.get('BlockDeviceMappings', []): - ebs = bdev.get('Ebs') + for bdev in vm.get("BlockDeviceMappings", []): + ebs = bdev.get("Ebs") if ebs: - gotVol.add(ebs['VolumeId']) - vol = vol_map.get(ebs['VolumeId']) + gotVol.add(ebs["VolumeId"]) + vol = vol_map.get(ebs["VolumeId"]) if vol: addVol(totals, vol) addVol(rinfo, vol) else: - printf('Missing vol: %s, instance: %s', ebs['VolumeId'], vm['InstanceId']) + printf("Missing vol: %s, instance: %s", ebs["VolumeId"], vm["InstanceId"]) if totals or vol_map: for rname in sorted(envmap): info = envmap[rname] show(rname, info, region) - show('* total', totals, region) + show("* total", totals, region) for vol_id in vol_map: if vol_id not in gotVol: @@ -1551,29 +1175,29 @@ def cmd_show_s3cost(self): """ def show(name, info, region): - line = ['%-30s' % name] + line = ["%-30s" % name] for k, v in info.items(): - gbs = int(v / (1024*1024*1024)) + gbs = int(v / (1024 * 1024 * 1024)) total = self.get_s3_pricing(region, k, gbs) - line.append('%s=%d ($%d/m)' % (k, int(gbs), total)) - print(' '.join(line)) + line.append("%s=%d ($%d/m)" % (k, int(gbs), total)) + print(" ".join(line)) - all_regions = self.cf.getlist('all_regions') + all_regions = self.cf.getlist("all_regions") for region in all_regions: - printf('-- %s --', region) + printf("-- %s --", region) totals = {} - for bucket in self.get_s3(region).list_buckets()['Buckets']: - bucket_name = bucket['Name'] + for bucket in self.get_s3(region).list_buckets()["Buckets"]: + bucket_name = bucket["Name"] bucket_info = {} - for obj in self.s3_iter('list_object_versions', 'Versions', region=region, Bucket=bucket_name): + for obj in self.s3_iter("list_object_versions", "Versions", region=region, Bucket=bucket_name): # Size, StorageClass, IsLatest, LastModified - sclass = obj['StorageClass'] - size = obj['Size'] # round to block? + sclass = obj["StorageClass"] + size = obj["Size"] # round to block? bucket_info[sclass] = bucket_info.get(sclass, 0) + size for k, v in bucket_info.items(): totals[k] = totals.get(k, 0) + v show(bucket_name, bucket_info, region) - show('* total *', totals, region) + show("* total *", totals, region) def cmd_show_untagged(self): """Show VMs without tags. @@ -1584,16 +1208,16 @@ def cmd_show_untagged(self): adrmap = {} res = client.describe_addresses() - for adr in res['Addresses']: - if adr.get('InstanceId'): - adrmap[adr['InstanceId']] = adr['PublicIp'] + for adr in res["Addresses"]: + if adr.get("InstanceId"): + adrmap[adr["InstanceId"]] = adr["PublicIp"] dnsmap = self.get_dns_map(True) args = {} vm_list = [] for vm in self.ec2_iter_instances(**args): - if not vm.get('Tags'): + if not vm.get("Tags"): vm_list.append(vm) self.options.all = True @@ -1606,10 +1230,10 @@ def cmd_show_lbs(self): """ client = self.get_elb() res = client.describe_load_balancers() - for lb in res['LoadBalancerDescriptions']: - printf("Name: %s", lb['DNSName']) - printf(" SrcSecGroup: %r", lb['SourceSecurityGroup']['GroupName']) - printf(" ExtraSecGroups: %r", lb['SecurityGroups']) + for lb in res["LoadBalancerDescriptions"]: + printf("Name: %s", lb["DNSName"]) + printf(" SrcSecGroup: %r", lb["SourceSecurityGroup"]["GroupName"]) + printf(" ExtraSecGroups: %r", lb["SecurityGroups"]) def cmd_show_sgs(self): """Show security groups. @@ -1620,12 +1244,12 @@ def cmd_show_sgs(self): res = client.describe_security_groups() # item, owner_id, region, rules, rules_egress, tags, vpc_id - for sg in res['SecurityGroups']: - printf("%s - %s - %s", sg['GroupId'], sg['GroupName'], sg['Description']) - printf(" RulesIn: %r", len(sg['IpPermissions'])) - printf(" RulesOut: %r", len(sg['IpPermissionsEgress'])) - if sg.get('Tags'): - printf(" Tags: %r", sg['Tags']) + for sg in res["SecurityGroups"]: + printf("%s - %s - %s", sg["GroupId"], sg["GroupName"], sg["Description"]) + printf(" RulesIn: %r", len(sg["IpPermissions"])) + printf(" RulesOut: %r", len(sg["IpPermissionsEgress"])) + if sg.get("Tags"): + printf(" Tags: %r", sg["Tags"]) def cmd_show_buckets(self): """Show S3 buckets. @@ -1634,15 +1258,15 @@ def cmd_show_buckets(self): """ s3 = self.get_s3() res = s3.list_buckets() - for b in res['Buckets']: - printf("%s", b['Name']) + for b in res["Buckets"]: + printf("%s", b["Name"]) def cmd_show_files(self, *blist): """Show files in a S3 bucket. Group: s3 """ - cur_bucket = self.cf.get('files_bucket') + cur_bucket = self.cf.get("files_bucket") if not blist: blist = [cur_bucket] @@ -1650,20 +1274,20 @@ def cmd_show_files(self, *blist): eprintf("---- %s ----", bname) for kx in self.s3_iter_objects(bname): if self.options.verbose: - self.s3_show_obj_head(bname, kx['Key'], kx) + self.s3_show_obj_head(bname, kx["Key"], kx) else: - printf("%s", kx['Key']) + printf("%s", kx["Key"]) def s3_get_obj_head(self, bucket, key): return self.get_s3().head_object(Bucket=bucket, Key=key) def s3_show_obj_head(self, bucket, key, res): printf("%s", key) - for a in ('ContentLength', 'ContentType', 'ContentEncoding', 'ContentDisposition', - 'ContentLanguage', 'Metadata', 'CacheControl', - 'ETag', 'LastModified', 'StorageClass', 'ReplicationStatus', - 'ServerSideEncryption', 'PartsCount', - 'SSECustomerKeyMD5', 'SSEKMSKeyId', 'SSECustomerAlgorithm'): + for a in ("ContentLength", "ContentType", "ContentEncoding", "ContentDisposition", + "ContentLanguage", "Metadata", "CacheControl", + "ETag", "LastModified", "StorageClass", "ReplicationStatus", + "ServerSideEncryption", "PartsCount", + "SSECustomerKeyMD5", "SSEKMSKeyId", "SSECustomerAlgorithm"): v = res.get(a) if v: printf(" %s: %r", a, v) @@ -1673,32 +1297,32 @@ def s3_show_obj_info(self, bucket, key, info): for k in info: v = info.get(k) if isinstance(v, datetime.datetime): - v = v.isoformat(' ') - if k != 'Key' and v: + v = v.isoformat(" ") + if k != "Key" and v: printf(" %s: %r", k, v) def s3_iter_objects(self, bucket, prefix=None): s3client = self.get_s3() - pg_list_objects = s3client.get_paginator('list_objects') + pg_list_objects = s3client.get_paginator("list_objects") - args = {'Bucket': bucket} + args = {"Bucket": bucket} if prefix: - args['Prefix'] = prefix + args["Prefix"] = prefix for pres in pg_list_objects.paginate(**args): - for obj in pres.get('Contents') or []: + for obj in pres.get("Contents") or []: yield obj def s3_iter_object_versions(self, bucket, prefix=None): s3client = self.get_s3() - pg_list_object_versions = s3client.get_paginator('list_object_versions') + pg_list_object_versions = s3client.get_paginator("list_object_versions") - args = {'Bucket': bucket} + args = {"Bucket": bucket} if prefix: - args['Prefix'] = prefix + args["Prefix"] = prefix for pres in pg_list_object_versions.paginate(**args): - for obj in pres.get('Versions') or []: + for obj in pres.get("Versions") or []: yield obj def cmd_show_backups(self, *slot_list): @@ -1706,10 +1330,10 @@ def cmd_show_backups(self, *slot_list): Group: backup """ - slot_filter = '' + slot_filter = "" - bucket_name = self.cf.get('backup_aws_bucket') - pfx = self.cf.get('backup_prefix') + bucket_name = self.cf.get("backup_aws_bucket") + pfx = self.cf.get("backup_prefix") if slot_list: slot_filter = slot_list[0] pfx += slot_filter @@ -1718,28 +1342,28 @@ def cmd_show_backups(self, *slot_list): eprintf("---- %s ----", bucket_name) slots = {} - backup_domain = pfx.split('/')[0] + backup_domain = pfx.split("/")[0] for kx in self.s3_iter_objects(bucket_name, pfx): - parts = kx['Key'].split('/') + parts = kx["Key"].split("/") if parts[0] != backup_domain: continue - size = kx['Size'] - if parts[2] == 'base': - slot = '/'.join(parts[1:4]) + size = kx["Size"] + if parts[2] == "base": + slot = "/".join(parts[1:4]) else: - slot = '/'.join(parts[1:3]) + slot = "/".join(parts[1:3]) if slot not in slots: slots[slot] = 0 slots[slot] += size if not summary_output: - #head = self.s3_get_obj_head(bucket_name, kx['Key']) - #self.s3_show_obj_head(bucket_name, kx['Key'], head) - self.s3_show_obj_info(bucket_name, kx['Key'], kx) + #head = self.s3_get_obj_head(bucket_name, kx["Key"]) + #self.s3_show_obj_head(bucket_name, kx["Key"], head) + self.s3_show_obj_info(bucket_name, kx["Key"], kx) if summary_output: for slot in sorted(slots): - print("%s: %d GB" % (slot, int(slots[slot] / (1024*1024*1024)))) + print("%s: %d GB" % (slot, int(slots[slot] / (1024 * 1024 * 1024)))) def cmd_get_backup(self, *slot_list): """Download backup files from S3. @@ -1795,7 +1419,7 @@ def progcb(cur_read, total=total_size, kname=kname): return amount = cur - last[1] dur = now - last[2] - sys.stdout.write('\r%-30s %.1f%% of %d [%.1f kb/s] ' % (kname, perc, total, amount / (dur * 1024.0))) + sys.stdout.write("\r%-30s %.1f%% of %d [%.1f kb/s] " % (kname, perc, total, amount / (dur * 1024.0))) sys.stdout.flush() last[1] = cur last[2] = now @@ -1803,7 +1427,7 @@ def progcb(cur_read, total=total_size, kname=kname): last[0], last[1], last[2] = 0, 0, time.time() s3.download_file(Bucket=bucket_name, Key=kname, Filename=fn, Callback=progcb, Config=tx_config) - sys.stdout.write('\n') + sys.stdout.write("\n") def cmd_clean_backups(self): """Clean backup slots in S3. @@ -1815,41 +1439,40 @@ def cmd_clean_backups(self): # keep daily days = 6 * 30 dt_pos = datetime.datetime.utcnow() - datetime.timedelta(days=days) - min_slot = dt_pos.strftime('%Y/%m/%d') + min_slot = dt_pos.strftime("%Y/%m/%d") - bucket_name = self.cf.get('backup_aws_bucket') - pfx = self.cf.get('backup_prefix') - rc_test = re.compile(r'^\d\d\d\d/\d\d/\d\d$') + bucket_name = self.cf.get("backup_aws_bucket") + pfx = self.cf.get("backup_prefix") + rc_test = re.compile(r"^\d\d\d\d/\d\d/\d\d$") printf("---- %s ----", bucket_name) - slots = {} del_list = [] keep_set = set() - backup_domain = pfx.split('/')[0] + backup_domain = pfx.split("/", 1)[0] for kx in self.s3_iter_object_versions(bucket_name, pfx): - parts = kx['Key'].split(':')[0].split('/') + parts = kx["Key"].split(":")[0].split("/") if parts[0] != backup_domain: continue - slot = '/'.join(parts[1:]) + slot = "/".join(parts[1:]) if not rc_test.match(slot): - raise Exception('Unexpected slot format: %r' % slot) + raise Exception("Unexpected slot format: %r" % slot) if slot >= min_slot: keep_set.add(slot) continue - ref = {'Key': kx['Key']} - if kx.get('VersionId'): - ref['VersionId'] = kx['VersionId'] + ref = {"Key": kx["Key"]} + if kx.get("VersionId"): + ref["VersionId"] = kx["VersionId"] del_list.append(ref) if len(del_list) >= 500: printf("Deleting files: %d", len(del_list)) - s3client.delete_objects(Bucket=bucket_name, Delete={'Objects': del_list, 'Quiet': True}) + s3client.delete_objects(Bucket=bucket_name, Delete={"Objects": del_list, "Quiet": True}) del_list = [] if del_list: printf("Deleting files: %d", len(del_list)) - s3client.delete_objects(Bucket=bucket_name, Delete={'Objects': del_list, 'Quiet': True}) + s3client.delete_objects(Bucket=bucket_name, Delete={"Objects": del_list, "Quiet": True}) printf("Kept %d slots for %s", len(keep_set), backup_domain) @@ -1858,58 +1481,53 @@ def cmd_ls_backups(self): Group: backup """ - s3client = self.get_s3() - - bucket_name = self.cf.get('backup_aws_bucket') - pfx = self.cf.get('backup_prefix') + bucket_name = self.cf.get("backup_aws_bucket") + pfx = self.cf.get("backup_prefix") smap = { - 'STANDARD': 'S', - 'STANDARD_IA': 'I', - 'ONEZONE_IA': 'Z', - 'GLACIER': 'G', - 'REDUCED_REDUNDANCY': 'R', + "STANDARD": "S", + "STANDARD_IA": "I", + "ONEZONE_IA": "Z", + "GLACIER": "G", + "REDUCED_REDUNDANCY": "R", } printf("---- %s ----", bucket_name) - n = 0 for kx in self.s3_iter_object_versions(bucket_name, pfx): #print_json(kx) - lmod = kx['LastModified'] - size = kx['Size'] - age = kx['IsLatest'] and '!' or '~' - scls = smap.get(kx['StorageClass'], kx['StorageClass']) - ver = kx['VersionId'] - printf("%s %s", kx['Key'], scls + age) + #lmod = kx["LastModified"] + #size = kx["Size"] + #ver = kx["VersionId"] + age = kx["IsLatest"] and "!" or "~" + scls = smap.get(kx["StorageClass"], kx["StorageClass"]) + printf("%s %s", kx["Key"], scls + age) def cmd_ls_files(self): """Show backup slots in S3. Group: backup """ - s3client = self.get_s3() - - bucket_name = self.cf.get('files_bucket') - pfx = '' + bucket_name = self.cf.get("files_bucket") + pfx = "" smap = { - 'STANDARD': 'S', - 'STANDARD_IA': 'I', - 'ONEZONE_IA': 'Z', - 'GLACIER': 'G', - 'REDUCED_REDUNDANCY': 'R', + "STANDARD": "S", + "STANDARD_IA": "I", + "ONEZONE_IA": "Z", + "GLACIER": "G", + "REDUCED_REDUNDANCY": "R", } eprintf("---- %s ----", bucket_name) for kx in self.s3_iter_object_versions(bucket_name, pfx): #print_json(kx) - mtime = kx['LastModified'].isoformat()[:10] - size = kx['Size'] - age = kx['IsLatest'] and '!' or '~' - scls = smap.get(kx['StorageClass'], kx['StorageClass']) + mtime = kx["LastModified"].isoformat()[:10] + size = kx["Size"] + age = kx["IsLatest"] and "!" or "~" + scls = smap.get(kx["StorageClass"], kx["StorageClass"]) tag = scls + age - ver = kx['VersionId'] - name = kx['Key'] + #ver = kx["VersionId"] + name = kx["Key"] printf("mtime=%s tag=%s size=%d key=%s", mtime, tag, size, name) def cmd_show_ips(self): @@ -1919,11 +1537,11 @@ def cmd_show_ips(self): """ client = self.get_ec2_client() res = client.describe_addresses() - for a in res['Addresses']: - #tags = ['%s: %s' for k,v in a.tags.items()] - #st = ', '.join(tags) + for a in res["Addresses"]: + #tags = ["%s: %s" for k,v in a.tags.items()] + #st = ", ".join(tags) #st = repr(dir(a)) - printf("%s - vm=%s domain=%s", a.get('PublicIp'), a.get('InstanceId', '-'), a.get('Domain', '-')) + printf("%s - vm=%s domain=%s", a.get("PublicIp"), a.get("InstanceId", "-"), a.get("Domain", "-")) def cmd_show_ebs(self): """Show EBS volumes. @@ -1932,13 +1550,13 @@ def cmd_show_ebs(self): """ client = self.get_ec2_client() res = client.describe_volumes() - for v in res['Volumes']: - a = v.get('Attachments') - vm_id = '-' + for v in res["Volumes"]: + a = v.get("Attachments") + vm_id = "-" if a: - vm_id = a[0].get('InstanceId') - t = v.get('CreateTime').strftime('%Y-%m-%d') - print("%s@%s size=%dG stat=%s created=%s" % (v['VolumeId'], vm_id, v['Size'], v['State'], t)) + vm_id = a[0].get("InstanceId") + t = v.get("CreateTime").strftime("%Y-%m-%d") + print("%s@%s size=%dG stat=%s created=%s" % (v["VolumeId"], vm_id, v["Size"], v["State"], t)) def cmd_show_tables(self): """Show DynamoDB tables. @@ -1946,7 +1564,7 @@ def cmd_show_tables(self): Group: dynamodb """ ddb = self.get_ddb() - for t in ddb.list_tables()['TableNames']: + for t in ddb.list_tables()["TableNames"]: print(t) def cmd_describe_table(self, tblname): @@ -1955,7 +1573,7 @@ def cmd_describe_table(self, tblname): Group: dynamodb """ ddb = self.get_ddb() - desc = ddb.describe_table(TableName=tblname)['Table'] + desc = ddb.describe_table(TableName=tblname)["Table"] print_json(desc) def cmd_get_item(self, tbl_name, item_key): @@ -1964,19 +1582,14 @@ def cmd_get_item(self, tbl_name, item_key): Group: dynamodb """ ddb = self.get_ddb() - res = ddb.get_item(TableName=tbl_name, Key={'hash_key': {'S': item_key}}) + res = ddb.get_item(TableName=tbl_name, Key={"hash_key": {"S": item_key}}) print_json(res) - def get_stamp(self): - commit_id = local_cmd(['git', 'rev-parse', 'HEAD']) - commit_id = commit_id[:7] # same length as git log --abbrev-commit - return commit_id - def load_tags(self, obj): tags = {} - if obj and obj.get('Tags'): - for tag in obj.get('Tags'): - tags[tag['Key']] = tag['Value'] + if obj and obj.get("Tags"): + for tag in obj.get("Tags"): + tags[tag["Key"]] = tag["Value"] return tags def set_stamp(self, vm_id, name, commit_id, *dirs): @@ -1987,10 +1600,10 @@ def set_stamp(self, vm_id, name, commit_id, *dirs): vm = self.vm_lookup(vm_id) old_tags = self.load_tags(vm) - tags = [{'Key': name, 'Value': commit_id}] + tags = [{"Key": name, "Value": commit_id}] client.create_tags(Resources=[vm_id], Tags=tags) - old_id = old_tags.get('Commit', '?') + old_id = old_tags.get("Commit", "?") old_id = old_tags.get(name, old_id) if commit_id == old_id: printf("%s: %s - no new commits", name, vm_id) @@ -2000,7 +1613,7 @@ def set_stamp(self, vm_id, name, commit_id, *dirs): def gen_user_data(self): rnd = secrets.token_urlsafe(20) - mimedata = USERDATA.replace('RND', rnd) + mimedata = USERDATA.replace("RND", rnd) if "AUTHORIZED_USER_CREATION" in mimedata: mimedata = mimedata.replace( "AUTHORIZED_USER_CREATION", self.make_user_creation() @@ -2016,68 +1629,16 @@ def cmd_create(self): self.vm_create_finish(ids) return ids - def get_disk_map(self): - """Parse disk_map option. - """ - disk_map = self.cf.getdict("disk_map", {}) - if not disk_map: - disk_map = {"root": "size=12"} - - res_map = {} - for dev in disk_map: - val = disk_map[dev] - local = {} - for opt in val.split(":"): - if "=" in opt: - k, v = opt.split("=", 1) - k = k.strip() - v = v.strip() - else: - k = v = opt.strip() - if not k: - continue - if k.startswith("ephemeral"): - k = "ephemeral" - if k in ("size", "count", "iops", "throughput"): - v = int(v) - local[k] = v - if "count" not in local: - local["count"] = 1 - if "size" not in local: - raise UsageError("Each element in disk_map needs size") - res_map[dev] = local - - # sanity check if requested - disk_require_order = self.cf.getlist("disk_require_order", []) - if disk_require_order: - # order from disk_map - got_order = sorted([ - (res_map[name]["size"], res_map[name]["count"], name) - for name in res_map - ]) - names_order = [f"{name}:{count}" for size, count, name in got_order] - - # order from disk_require_order - counted_order = [ - key if ":" in key else key + ":1" - for key in disk_require_order - ] - - if names_order != counted_order: - raise UsageError("Order mismatch:\n require=%r\n got=%r" % (counted_order, names_order)) - - return res_map - def get_next_raw_device(self, base_dev, used): prefix = base_dev[:-1] last = ord(base_dev[-1]) - while chr(last) <= 'z': - current_dev = '%s%c' % (prefix, last) + while chr(last) <= "z": + current_dev = "%s%c" % (prefix, last) if current_dev not in used: used.add(current_dev) return current_dev last += 1 - raise Exception('Failed to generate disk name: %r used=%r' % (base_dev, used)) + raise Exception("Failed to generate disk name: %r used=%r" % (base_dev, used)) def vm_create_start(self): """Create instance. @@ -2086,44 +1647,44 @@ def vm_create_start(self): """ client = self.get_ec2_client() - image_type = self.cf.get('image_type') - image_id = self.cf.get(image_type + '_image_id', '') + image_type = self.cf.get("image_type") + image_id = self.cf.get(image_type + "_image_id", "") if image_id: image_name = "" else: - image_name = self.cf.get('image_name') + image_name = self.cf.get("image_name") image_id = self.get_image_id(image_name) if not image_id: eprintf("ERROR: no image for name: %r" % image_name) sys.exit(1) - key_name = self.cf.get('key_name') - vm_type = self.cf.get('vm_type') - sg_list = self.cf.getlist('security_groups') - zone = self.cf.get('zone', '') - cpu_credits = self.cf.get('cpu_credits', '') - cpu_count = self.cf.getint('cpu_count', 0) - cpu_thread_count = self.cf.getint('cpu_thread_count', 0) - aws_extra_tags = self.cf.getdict('aws_extra_tags', {}) - xname = 'vm.' + self.env_name + key_name = self.cf.get("key_name") + vm_type = self.cf.get("vm_type") + sg_list = self.cf.getlist("security_groups") + zone = self.cf.get("zone", "") + cpu_credits = self.cf.get("cpu_credits", "") + cpu_count = self.cf.getint("cpu_count", 0) + cpu_thread_count = self.cf.getint("cpu_thread_count", 0) + aws_extra_tags = self.cf.getdict("aws_extra_tags", {}) + xname = "vm." + self.env_name if self.role_name: - xname += '.' + self.role_name + xname += "." + self.role_name if not zone: zone = None - ebs_optimized = self.cf.getboolean('ebs_optimized', False) - disk_type = self.cf.get('disk_type', 'gp2') + ebs_optimized = self.cf.getboolean("ebs_optimized", False) + disk_type = self.cf.get("disk_type", "gp2") disk_map = self.get_disk_map() if not disk_map: - disk_map = {'root': {'size': 12}} + disk_map = {"root": {"size": 12}} # device name may be different for different AMIs res = client.describe_images(ImageIds=[image_id]) - if not res.get('Images'): + if not res.get("Images"): eprintf("ERROR: no image: %r" % image_id) sys.exit(1) - for img in res['Images']: - root_device_name = img['RootDeviceName'] + for img in res["Images"]: + root_device_name = img["RootDeviceName"] devlog = [] bdm = [] @@ -2132,68 +1693,68 @@ def vm_create_start(self): used_raw_devs = set() for dev in disk_map: - bdev = {'DeviceName': dev} + bdev = {"DeviceName": dev} count = 1 ebs = {} for k, v in disk_map[dev].items(): - if k == 'size': - ebs['VolumeSize'] = int(v) - elif k == 'iops': - ebs['Iops'] = int(v) - elif k == 'throughput': - ebs['Throughput'] = int(v) - elif k == 'count': + if k == "size": + ebs["VolumeSize"] = int(v) + elif k == "iops": + ebs["Iops"] = int(v) + elif k == "throughput": + ebs["Throughput"] = int(v) + elif k == "count": count = int(v) - elif k == 'type': + elif k == "type": # https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/EBSVolumeTypes.html # Values: standard, gp2, io1, st1, sc1 - ebs['VolumeType'] = v - elif k == 'local': - bdev['VirtualName'] = v - elif k == 'encrypted': - if v == 'encrypted': - v = '1' - ebs['Encrypted'] = bool(int(v)) + ebs["VolumeType"] = v + elif k == "local": + bdev["VirtualName"] = v + elif k == "encrypted": + if v == "encrypted": + v = "1" + ebs["Encrypted"] = bool(int(v)) elif k in self.VOL_TYPES: - ebs['VolumeType'] = k + ebs["VolumeType"] = k elif k in self.VOL_ENC_TYPES: - ebs['VolumeType'] = k.split('-')[1] - ebs['Encrypted'] = True - elif k == 'ephemeral': - bdev['VirtualName'] = v + ebs["VolumeType"] = k.split("-")[1] + ebs["Encrypted"] = True + elif k == "ephemeral": + bdev["VirtualName"] = v else: eprintf("ERROR: unknown disk param: %r", k) sys.exit(1) - if bdev.get('VirtualName'): - ebs.pop('VolumeSize', 0) + if bdev.get("VirtualName"): + ebs.pop("VolumeSize", 0) if ebs: eprintf("ERROR: ephemeral device cannot have EBS params: %r", ebs) sys.exit(1) elif ebs: - if 'VolumeSize' not in ebs: - ebs['VolumeSize'] = 10 - if 'VolumeType' not in ebs: - ebs['VolumeType'] = disk_type - ebs['DeleteOnTermination'] = True + if "VolumeSize" not in ebs: + ebs["VolumeSize"] = 10 + if "VolumeType" not in ebs: + ebs["VolumeType"] = disk_type + ebs["DeleteOnTermination"] = True - bdev['Ebs'] = ebs + bdev["Ebs"] = ebs for _ in range(count): bdev = bdev.copy() # fill DeviceName, mainly used for root selection, otherwise mostly useless - if dev in ROOT_DEV_NAMES: - bdev['DeviceName'] = root_device_name - if root_device_name in ['/dev/sda1']: + if dev in self.ROOT_DEV_NAMES: + bdev["DeviceName"] = root_device_name + if root_device_name in ["/dev/sda1"]: used_raw_devs.add(root_device_name[:-1]) else: used_raw_devs.add(root_device_name) elif ebs: - bdev['DeviceName'] = self.get_next_raw_device('/dev/sdf', used_raw_devs) + bdev["DeviceName"] = self.get_next_raw_device("/dev/sdf", used_raw_devs) else: - bdev['DeviceName'] = self.get_next_raw_device('/dev/sdb', used_raw_devs) + bdev["DeviceName"] = self.get_next_raw_device("/dev/sdb", used_raw_devs) if "VirtualName" in bdev: bdev["VirtualName"] = "ephemeral%d" % (ephemeral_idx,) @@ -2201,17 +1762,17 @@ def vm_create_start(self): bdm.append(bdev) - devlog.append('%s=%s' % (dev, bdev['DeviceName'])) + devlog.append("%s=%s" % (dev, bdev["DeviceName"])) time_printf("AWS=%s Env=%s Role=%s Key=%s Image=%s(%s) AZ=%d", - self.cf.get('aws_main_account'), - self.env_name, self.role_name or '-', key_name, + self.cf.get("aws_main_account"), + self.env_name, self.role_name or "-", key_name, image_name, image_id, self.availability_zone) - time_printf("Creating VM, storage: %s" % ', '.join(devlog)) + time_printf("Creating VM, storage: %s" % ", ".join(devlog)) # lookup subnet - subnet_id = self.cf.get('subnet_id') + subnet_id = self.cf.get("subnet_id") # manual lookup for sgs sg_ids = self.sgroups_lookup(sg_list) @@ -2219,65 +1780,65 @@ def vm_create_start(self): eprintf("ERROR: failed to resolve security groups: %r" % sg_list) sys.exit(1) - instance_profile_arn = self.cf.get('instance_profile_arn', '') + instance_profile_arn = self.cf.get("instance_profile_arn", "") if not instance_profile_arn: instance_profile_arn = None - instance_associate_public_ip = self.cf.getboolean('instance_associate_public_ip', False) + instance_associate_public_ip = self.cf.getboolean("instance_associate_public_ip", False) user_data = self.gen_user_data() main_iface = { - 'DeviceIndex': 0, - 'Description': '%s' % self.full_role, - 'SubnetId': subnet_id, - 'AssociatePublicIpAddress': instance_associate_public_ip, - 'DeleteOnTermination': True, - 'Groups': sg_ids, + "DeviceIndex": 0, + "Description": "%s" % self.full_role, + "SubnetId": subnet_id, + "AssociatePublicIpAddress": instance_associate_public_ip, + "DeleteOnTermination": True, + "Groups": sg_ids, } args = { - 'ImageId': image_id, - 'InstanceType': vm_type, - 'KeyName': key_name, - 'BlockDeviceMappings': bdm, - 'MinCount': 1, - 'MaxCount': 1, - 'NetworkInterfaces': [main_iface] + "ImageId": image_id, + "InstanceType": vm_type, + "KeyName": key_name, + "BlockDeviceMappings": bdm, + "MinCount": 1, + "MaxCount": 1, + "NetworkInterfaces": [main_iface] } if zone: - args['Placement'] = {'AvailabilityZone': zone} + args["Placement"] = {"AvailabilityZone": zone} if instance_profile_arn: - args['IamInstanceProfile'] = {'Arn': instance_profile_arn} + args["IamInstanceProfile"] = {"Arn": instance_profile_arn} if ebs_optimized: - args['EbsOptimized'] = True + args["EbsOptimized"] = True if user_data: - args['UserData'] = user_data + args["UserData"] = user_data if cpu_credits: # standard / unlimited (for t2.* instances) - args['CreditSpecification'] = {'CpuCredits': cpu_credits} + args["CreditSpecification"] = {"CpuCredits": cpu_credits} if cpu_count or cpu_thread_count: - args['CpuOptions'] = {} + args["CpuOptions"] = {} if cpu_count: - args['CpuOptions']['CoreCount'] = cpu_count + args["CpuOptions"]["CoreCount"] = cpu_count if cpu_thread_count: - args['CpuOptions']['ThreadsPerCore'] = cpu_thread_count + args["CpuOptions"]["ThreadsPerCore"] = cpu_thread_count # pre-fill tags self.new_commit = self.get_stamp() tags = [ - {'Key': 'Name', 'Value': xname}, - {'Key': 'Env', 'Value': self.env_name}, - {'Key': 'Commit', 'Value': self.new_commit}, - {'Key': 'Date', 'Value': time.strftime("%Y%m%d")}, - {'Key': 'VmState', 'Value': VmState.SECONDARY}, + {"Key": "Name", "Value": xname}, + {"Key": "Env", "Value": self.env_name}, + {"Key": "Commit", "Value": self.new_commit}, + {"Key": "Date", "Value": time.strftime("%Y%m%d")}, + {"Key": "VmState", "Value": VmState.SECONDARY}, ] if self.role_name: - tags.append({'Key': 'Role', 'Value': self.role_name}) + tags.append({"Key": "Role", "Value": self.role_name}) for k, v in aws_extra_tags.items(): - tags.append({'Key': k, 'Value': v}) - args['TagSpecifications'] = [ - {'ResourceType': 'instance', 'Tags': tags}, - {'ResourceType': 'volume', 'Tags': tags}, + tags.append({"Key": k, "Value": v}) + args["TagSpecifications"] = [ + {"ResourceType": "instance", "Tags": tags}, + {"ResourceType": "volume", "Tags": tags}, ] # actual launch @@ -2287,17 +1848,17 @@ def vm_create_start(self): # collect ids ids = [] - for vm in res['Instances']: - vm_id = vm['InstanceId'] + for vm in res["Instances"]: + vm_id = vm["InstanceId"] ids.append(vm_id) - time_printf("Created: %s", ' '.join(ids)) + time_printf("Created: %s", " ".join(ids)) show_first = True - while 1: + while True: ok = True vm_list = [] for vm in self.ec2_iter_instances(InstanceIds=ids): - if vm['State']['Name'] != 'running': + if vm["State"]["Name"] != "running": ok = False vm_list.append(vm) @@ -2365,7 +1926,7 @@ def cmd_create_secondary(self): """Create secondary vm. Group: vm """ - self.cf.set('vm_state', VmState.SECONDARY) + self.cf.set("vm_state", VmState.SECONDARY) start = time.time() self.modcmd_init(VmCmd.PREP) @@ -2382,7 +1943,7 @@ def cmd_create_secondary(self): printf("Total time: %d", int(end - start)) # reset vm state - self.cf.set('vm_state', VmState.PRIMARY) + self.cf.set("vm_state", VmState.PRIMARY) return first def cmd_add_key(self, vm_id): @@ -2404,18 +1965,18 @@ def cmd_tag(self): raise Exception("No role_name") tags = [] - aws_extra_tags = self.cf.getdict('aws_extra_tags', {}) + aws_extra_tags = self.cf.getdict("aws_extra_tags", {}) for k, v in aws_extra_tags.items(): - tags.append({'Key': k, 'Value': v}) + tags.append({"Key": k, "Value": v}) if tags: client = self.get_ec2_client() for vm in self.ec2_iter_instances(Filters=self.get_env_filters()): - client.create_tags(Resources=[vm['InstanceId']], Tags=tags) - for bdm in vm.get('BlockDeviceMappings', []): - ebs = bdm.get('Ebs') + client.create_tags(Resources=[vm["InstanceId"]], Tags=tags) + for bdm in vm.get("BlockDeviceMappings", []): + ebs = bdm.get("Ebs") if ebs: - client.create_tags(Resources=[ebs['VolumeId']], Tags=tags) + client.create_tags(Resources=[ebs["VolumeId"]], Tags=tags) def cmd_start(self, *ids): """Start instance. @@ -2446,12 +2007,12 @@ def cmd_terminate(self, *ids): stopped = set() for vm in self.ec2_iter_instances(Filters=self.get_env_filters()): - if vm['State']['Name'] != 'stopped': + if vm["State"]["Name"] != "stopped": continue if not self.options.all: - if not self._check_tags(vm.get('Tags')): + if not self._check_tags(vm.get("Tags")): continue - stopped.add(str(vm['InstanceId'])) + stopped.add(str(vm["InstanceId"])) bad = [] for vm_id in ids: @@ -2468,32 +2029,32 @@ def cmd_gc(self): Group: vm """ - gc_keep_count = self.cf.getint('gc_keep_count', 0) - gc_keep_days = self.cf.getint('gc_keep_days', 0) + gc_keep_count = self.cf.getint("gc_keep_count", 0) + gc_keep_days = self.cf.getint("gc_keep_days", 0) s_max_time = None if gc_keep_days > 0: max_time = datetime.datetime.utcnow() - datetime.timedelta(days=gc_keep_days) s_max_time = max_time.isoformat() if gc_keep_days or gc_keep_count: - print('gc: gc_keep_days: %d gc_keep_count: %d maxtime: %r' % ( + print("gc: gc_keep_days: %d gc_keep_count: %d maxtime: %r" % ( gc_keep_days, gc_keep_count, s_max_time)) client = self.get_ec2_client() garbage = [] vms_iter = self.ec2_iter_instances(Filters=self.get_env_filters()) - vms_sorted = sorted(vms_iter, key=lambda vm: vm['LaunchTime']) + vms_sorted = sorted(vms_iter, key=lambda vm: vm["LaunchTime"]) keep_count = 0 for vm in vms_sorted: - if vm['State']['Name'] != 'stopped': + if vm["State"]["Name"] != "stopped": continue if not self.options.all: - if not self._check_tags(vm.get('Tags')): + if not self._check_tags(vm.get("Tags")): continue - vm_launchtime = vm['LaunchTime'].isoformat() + vm_launchtime = vm["LaunchTime"].isoformat() if s_max_time and vm_launchtime >= s_max_time: keep_count += 1 continue - garbage.append(str(vm['InstanceId'])) + garbage.append(str(vm["InstanceId"])) # remove some if necessary while garbage and gc_keep_count > keep_count: @@ -2545,196 +2106,6 @@ def cmd_rsync(self, *args): raise UsageError("Need source and dest for rsync") self.vm_rsync(*args) - def filter_key_lookup(self, predef, key, fname): - if key in predef: - return predef[key] - - if key == 'MASTER_KEYS': - master_key_list = [] - nr = 1 - while 1: - kname = "master_key_%d" % nr - v = self.cf.get(kname, '') - if not v: - break - master_key_list.append("%s = %s" % (kname, v)) - nr += 1 - if not master_key_list: - raise Exception("No master keys found") - master_key_conf = "\n".join(master_key_list) - return master_key_conf - - if key == 'SYSRANDOM': - blk = os.urandom(3*16) - b64 = binascii.b2a_base64(blk).strip() - return b64.decode('utf8') - - if key == 'AUTHORIZED_KEYS': - auth_users = self.cf.getlist('ssh_authorized_users', []) - pat = self.cf.get('ssh_pubkey_pattern') - keys = [] - for user in sorted(set(auth_users)): - fn = os.path.join(self.keys_dir, pat.replace('USER', user)) - pubkey = open(fn, 'r').read().strip() - keys.append(pubkey) - return '\n'.join(keys) - - if key == 'AUTHORIZED_USER_CREATION': - return self.make_user_creation() - - try: - return self.cf.get(key) - except NoOptionError: - raise UsageError("%s: key not found: %s" % (fname, key)) - - def make_user_creation(self): - auth_groups = self.cf.getlist('authorized_user_groups', []) - auth_users = self.cf.getlist('ssh_authorized_users', []) - pat = self.cf.get('ssh_pubkey_pattern') - script = [] - for user in sorted(set(auth_users)): - fn = os.path.join(self.keys_dir, pat.replace('USER', user)) - pubkey = open(fn).read().strip() - script.append(mk_sshuser_script(user, auth_groups, pubkey)) - return '\n'.join(script) - - def make_tar_filter(self, extra_defs=None): - defs = {} - if extra_defs: - defs.update(extra_defs) - tb = TarFilter(self.filter_key_lookup, defs) - tb.set_live(self.is_live) - return tb - - def conf_func_file(self, arg, sect, kname): - """Returns contents of file, optionally gpg-decrypted. - - Usage: ${FILE ! filename} - """ - if self.options.verbose: - printf("FILE: %s", arg) - fn = os.path.join(self.keys_dir, arg) - if not os.path.isfile(fn): - raise UsageError('%s - FILE missing: %s' % (kname, arg)) - if fn.endswith('.gpg'): - return self.load_gpg_file(fn).rstrip('\n') - return open(fn, 'r').read().rstrip('\n') - - def conf_func_key(self, arg, sect, kname): - """Returns key from Terraform state file. - - Usage: ${KEY ! fn : key} - """ - bfn, subkey = arg.split(':') - if self.options.verbose: - printf("KEY: %s : %s", bfn.strip(), subkey.strip()) - fn = os.path.join(self.keys_dir, bfn.strip()) - if not os.path.isfile(fn): - raise UsageError('%s - KEY file missing: %s' % (kname, fn)) - cf = self.load_gpg_config(fn, 'vm-config') - subkey = as_unicode(subkey.strip()) - try: - return cf.get(subkey) - except: - raise UsageError("%s - Key '%s' unset in '%s'" % (kname, subkey, fn)) - - def conf_func_tf(self, arg, sect, kname): - """Returns key from Terraform state file. - - Usage: ${TF ! tfvar} - """ - if ":" in arg: - state_file, arg = [s.strip() for s in arg.split(":", 1)] - else: - state_file = self.cf.get('tf_state_file') - val = tf_load_output_var(state_file, arg) - - # configparser expects strings - if isinstance(val, str): - # work around tf dots in route53 data - val = val.strip().rstrip('.') - elif isinstance(val, int): - val = str(val) - elif isinstance(val, float): - val = repr(val) - elif isinstance(val, bool): - val = str(val).lower() - else: - raise UsageError("TF function got invalid type: %s - %s" % (kname, type(val))) - return val - - def conf_func_members(self, arg, sect, kname): - """Returns field that match patters. - - Usage: ${MEMBERS ! pat : fn : field} - """ - pats, bfn, field = arg.split(':') - fn = os.path.join(self.keys_dir, bfn.strip()) - if not os.path.isfile(fn): - raise UsageError('%s - MEMBERS file missing: %s' % (kname, fn)) - - idx = int(field, 10) - - findLabels = [] - for p in pats.split(','): - p = p.strip() - if p: - findLabels.append(p) - - res = [] - for ln in open(fn): - ln = ln.strip() - if not ln or ln[0] == '#': - continue - got = False - parts = ln.split(':') - user = parts[0].strip() - for label in parts[idx].split(','): - label = label.strip() - if label and label in findLabels: - got = True - break - if got and user not in res: - res.append(user) - - return ', '.join(res) - - def conf_func_tfaz(self, arg, sect, kname): - """Returns key from Terraform state file. - - Usage: ${TFAZ ! tfvar} - """ - if self.options.verbose: - printf("TFAZ: %s", arg) - if ":" in arg: - state_file, arg = [s.strip() for s in arg.split(":", 1)] - else: - state_file = self.cf.get('tf_state_file') - val = tf_load_output_var(state_file, arg) - if not isinstance(val, list): - raise UsageError("TFAZ function expects list param: %s" % kname) - if self.availability_zone < 0 or self.availability_zone >= len(val): - raise UsageError("AZ value out of range") - return val[self.availability_zone] - - def conf_func_primary_vm(self, arg, sect, kname): - """Lookup primary vm. - - Usage: ${PRIMARY_VM ! ${other_role}} - """ - vm = self.get_primary_for_role(arg) - return vm['InstanceId'] - - def conf_func_network(self, arg, sect, kname): - """Extract network address from CIDR. - """ - return str(ipaddress.ip_network(arg).network_address) - - def conf_func_netmask(self, arg, sect, kname): - """Extract 32-bit netmask from CIDR. - """ - return str(ipaddress.ip_network(arg).netmask) - def do_prep(self, vm_id: str): """Run initialized 'prep' command. """ @@ -2749,298 +2120,18 @@ def load_vm_file(self, vm_id, fn): return self.vm_exec(vm_id, load_cmd, get_output=True) def load_secondary_vars(self, primary_id): - vmap = self.cf.getdict('load_secondary_files', {}) + vmap = self.cf.getdict("load_secondary_files", {}) for vname, primary_file in vmap.items(): eprintf("Loading %s:%s", primary_id, primary_file) data = self.load_vm_file(primary_id, primary_file) self.cf.set(vname, as_unicode(data)) - _PREP_TGZ_CACHE = {} # cmd->tgz - _PREP_STAMP_CACHE = {} # cmd->stamp - - def cmd_mod_test(self, cmd_name): - """Test if payload can be created for command. - - Group: internal - """ - self.modcmd_init(cmd_name) - data = self._PREP_TGZ_CACHE[cmd_name] - print("Data size: %d bytes" % len(data)) - return data - - def cmd_mod_dump(self, cmd_name): - """Write tarball of command payload. - - Group: internal - """ - self.modcmd_init(cmd_name) - data = self._PREP_TGZ_CACHE[cmd_name] - fn = 'data.tgz' - open(fn, 'wb').write(data) - print("%s: %d bytes" % (fn, len(data))) - - def cmd_mod_show(self, cmd_name): - """Show vmlibs used for command. - - Group: internal - """ - cwd = self.git_dir - os.chdir(cwd) - - cmd_cf = self.cf.view_section('cmd.%s' % cmd_name) - - vmlibs = cmd_cf.getlist('vmlibs') - - print("Included libs") - got = set() - for mod in vmlibs: - if mod not in got: - print("+ " + mod) - got.add(mod) - - exc_libs = [] - for mod in xglob('vmlib/**/setup.sh'): - mod = '/'.join(mod.split('/')[1:-1]) - if mod not in got: - exc_libs.append(mod) - exc_libs.sort() - - print("Excluded libs") - for mod in exc_libs: - print("- " + mod) - - def has_modcmd(self, cmd_name: VmCmd): - """Return true if command is configured from config. - """ - return self.cf.has_section('cmd.%s' % cmd_name) - - def modcmd_init(self, cmd_name: VmCmd): - """Run init script for command. - """ - cmd_cf = self.cf.view_section('cmd.%s' % cmd_name) - init_script = cmd_cf.get('init', '') - if init_script: - # let subprocess see current env - subenv = os.environ.copy() - subenv['VMTOOL_ENV_NAME'] = self.full_role - run_successfully([init_script], cwd=self.git_dir, shell=True, env=subenv) - - self.modcmd_prepare(cmd_name) - - def modcmd_prepare(self, cmd_name: VmCmd): - """Prepare data package for command. - """ - cmd_cf = self.cf.view_section('cmd.%s' % cmd_name) - stamp_dirs = cmd_cf.getlist('stamp_dirs', []) - cmd_abbr = cmd_cf.get('command_tag', '') - globs = cmd_cf.getlist('files', []) - use_admin = cmd_cf.getboolean('use_admin', False) - - self._PREP_TGZ_CACHE[cmd_name] = b'' - self.modcmd_build_tgz(cmd_name, globs, cmd_cf) - - self._PREP_STAMP_CACHE[cmd_name] = { - 'cmd_abbr': cmd_abbr, - 'stamp_dirs': stamp_dirs, - 'stamp': self.get_stamp(), - 'use_admin': use_admin, - } - - def modcmd_run(self, cmd_name, vm_ids): - """Send mod data to server and run it. - """ - info = self._PREP_STAMP_CACHE[cmd_name] - data_info = 0 - for vm_id in vm_ids: - data = self._PREP_TGZ_CACHE[cmd_name] - if not data_info: - data_info = 1 - print('RUNNING...') - self.run_mod_data(data, vm_id, use_admin=info['use_admin'], title=cmd_name) - if info['cmd_abbr']: - self.set_stamp(vm_id, info['cmd_abbr'], info['stamp'], *info['stamp_dirs']) - - def process_pkgs(self): - """Merge per-pkg variables into main config. - - Converts: - - [pkg.foo] - pkg_pyinstall_vmlibs = a, b - [pkg.bar] - pkg_pyinstall_vmlibs = c, d - - To: - [vm-config] - pkg_pyinstall_vmlibs = a, b, c, d - """ - cf = self.cf.cf - vmap = {} - for sect in cf.sections(): - if sect.startswith('pkg.'): - for opt in cf.options(sect): - if opt not in vmap: - vmap[opt] = [] - done = set(vmap[opt]) - val = cf.get(sect, opt) - for v in val.split(','): - v = v.strip() - if v and (v not in done): - vmap[opt].append(v) - done.add(v) - for k, v in vmap.items(): - cf.set('vm-config', k, ', '.join(v)) - - # in use - def modcmd_build_tgz(self, cmd_name, globs, cmd_cf=None): - cwd = self.git_dir - os.chdir(cwd) - - defs = {} - mods_ok = True - vmlibs = [] - cert_fns = set() - if cmd_cf: - vmlibs = cmd_cf.getlist('vmlibs', []) - if vmlibs: - done_vmlibs = [] - vmdir = 'vmlib' - globs = list(globs) - for mod in vmlibs: - if mod in done_vmlibs: - continue - if not mod: - continue - mdir = os.path.join(vmdir, mod) - if not os.path.isdir(mdir): - printf("Missing module: %s" % mdir) - mods_ok = False - elif not os.path.isfile(mdir + '/setup.sh'): - printf("Broken module, no setup.sh: %s" % mdir) - mods_ok = False - globs.append('vmlib/%s/**' % mod) - done_vmlibs.append(mod) - - cert_ini = os.path.join(mdir, 'certs.ini') - if os.path.isfile(cert_ini): - cert_fns.add(cert_ini) - defs['vm_modules'] = '\n'.join(done_vmlibs) + '\n' - globs.append('vmlib/runner.*') - globs.append('vmlib/shared/**') - if not mods_ok: - sys.exit(1) - - dst = self.make_tar_filter(defs) - - for tmp in globs: - subdir = '.' - if isinstance(tmp, str): - flist = xglob(tmp) - else: - subdir = tmp[1] - if subdir and subdir != '.': - os.chdir(subdir) - else: - subdir = '.' - flist = xglob(tmp[0]) - if len(tmp) > 2: - exlist = tmp[2:] - flist2 = [] - for fn in flist: - skip = False - for ex in exlist: - if fnmatch(fn, ex): - skip = True - break - if not skip: - flist2.append(fn) - flist = iter(flist2) - if subdir: - os.chdir(cwd) - - for fn in flist: - real_fn = os.path.join(subdir, fn) - if os.path.isdir(real_fn): - #dst.add_dir(item.path, stat.S_IRWXU, item.mtime) - pass - else: - with open(real_fn, 'rb') as f: - st = os.fstat(f.fileno()) - data = f.read() - dst.add_file_data(fn, data, st.st_mode & stat.S_IRWXU, st.st_mtime) - - # pass parameters to cert.ini files - defs = {'env_name': self.env_name} - if self.role_name: - defs['role_name'] = self.role_name - if self.cf.has_section('ca-config'): - items = self.cf.view_section('ca-config').items() - defs.update(items) - - # create keys & certs - for cert_ini in cert_fns: - printf("Processing certs: %s", cert_ini) - mdir = os.path.dirname(cert_ini) - keys = load_cert_config(cert_ini, self.load_ca_keypair, defs) - for kname in keys: - key, cert, _ = keys[kname] - key_fn = '%s/%s.key' % (mdir, kname) - cert_fn = '%s/%s.crt' % (mdir, kname) - dst.add_file_data(key_fn, key, 0o600) - dst.add_file_data(cert_fn, cert, 0o600) - - # finish - dst.close() - tgz = dst.getvalue() - self._PREP_TGZ_CACHE[cmd_name] = tgz - time_printf("%s: tgz bytes: %s", cmd_name, len(tgz)) - - def load_ca_keypair(self, ca_name): - intca_dir = self.cf.get(ca_name + '_dir', '') - if not intca_dir: - intca_dir = self.cf.get('intca_dir') - pat = '%s/%s/%s_*.key.gpg' % (self.keys_dir, intca_dir, ca_name) - res = list(sorted(xglob(pat))) - if not res: - raise UsageError("CA not found: %s - %s" % (ca_name, intca_dir)) - #names = [fn.split('/')[-1] for fn in res] - idx = 0 # -1 - last_key = res[idx] - #printf("CA: using %s from [%s]", names[idx], ', '.join(names)) - last_crt = last_key.replace('.key.gpg', '.crt') - if not os.path.isfile(last_crt): - raise UsageError("CA cert not found: %s" % last_crt) - if not os.path.isfile(last_key): - raise UsageError("CA key not found: %s" % last_key) - return (last_key, last_crt) - - def run_mod_data(self, data, vm_id, use_admin=False, title=None): - - tmp_uuid = str(uuid.uuid4()) - run_user = 'root' - - launcher = './tmp/%s/vmlib/runner.sh "%s"' % (tmp_uuid, vm_id) - rm_cmd = 'rm -rf' - if run_user: - launcher = 'sudo -nH -u %s %s' % (run_user, launcher) - rm_cmd = 'sudo -nH ' + rm_cmd - - time_printf("%s: Sending data - %d bytes", vm_id, len(data)) - decomp_script = 'install -d -m 711 tmp && mkdir -p "tmp/%s" && tar xzf - --warning=no-timestamp -C "tmp/%s"' % ( - tmp_uuid, tmp_uuid - ) - self.vm_exec(vm_id, ["/bin/sh", "-c", decomp_script, 'decomp'], data, use_admin=use_admin) - - time_printf("%s: Running", vm_id) - cmdline = ["/bin/sh", "-c", launcher, 'runit'] - self.vm_exec_tmux(vm_id, cmdline, use_admin=use_admin, title=title) - def cmd_tmux_attach(self, vm_id): """Attach to regular non-admin session. Group: vm """ - cmdline = shlex.split(self.cf.get('tmux_attach')) + cmdline = shlex.split(self.cf.get("tmux_attach")) self.vm_exec(vm_id, cmdline, None, use_admin=False) def cmd_tmux_attach_admin(self, vm_id): @@ -3048,7 +2139,7 @@ def cmd_tmux_attach_admin(self, vm_id): Group: vm """ - cmdline = shlex.split(self.cf.get('tmux_attach')) + cmdline = shlex.split(self.cf.get("tmux_attach")) self.vm_exec(vm_id, cmdline, None, use_admin=True) def cmd_get_output(self, vm_id): @@ -3058,10 +2149,10 @@ def cmd_get_output(self, vm_id): """ client = self.get_ec2_client() res = client.get_console_output(InstanceId=vm_id) - if res.get('Output'): + if res.get("Output"): # py3 still manages to organize codec=ascii errors - f = os.fdopen(sys.stdout.fileno(), 'wb', buffering=0) - v = res['Output'].encode('utf8', 'replace') + f = os.fdopen(sys.stdout.fileno(), "wb", buffering=0) + v = res["Output"].encode("utf8", "replace") f.write(v) def cmd_show_primary(self): @@ -3084,13 +2175,13 @@ def get_private_iface(self, vm_id): last_idx = None iface_id = None for vm in self.ec2_iter_instances(InstanceIds=[vm_id]): - if vm['InstanceId'] != vm_id: + if vm["InstanceId"] != vm_id: continue - for iface in vm['NetworkInterfaces']: - cur_idx = iface['Attachment']['DeviceIndex'] + for iface in vm["NetworkInterfaces"]: + cur_idx = iface["Attachment"]["DeviceIndex"] if last_idx is None or cur_idx < last_idx: - iface_id = iface['NetworkInterfaceId'] - last_idx = iface['Attachment']['DeviceIndex'] + iface_id = iface["NetworkInterfaceId"] + last_idx = iface["Attachment"]["DeviceIndex"] return iface_id def raw_assign_vm_private_ip(self, vm_id, private_ip): @@ -3106,50 +2197,50 @@ def raw_assign_vm_private_ip(self, vm_id, private_ip): def raw_assign_vm(self, vm_id): """Actual assign(). Returns old vm_id. """ - res = res2 = res3 = res4 = None - domain_eip = self.cf.get('domain_eip', '') + res = res2 = None + domain_eip = self.cf.get("domain_eip", "") if domain_eip: res = self.raw_assign_vm_eip(vm_id, domain_eip) - assign_private_ip = self.cf.get('assign_private_ip', '') + assign_private_ip = self.cf.get("assign_private_ip", "") if assign_private_ip: - res4 = self.raw_assign_vm_private_ip(vm_id, assign_private_ip) + res2 = self.raw_assign_vm_private_ip(vm_id, assign_private_ip) - public_dns_zone_id = self.cf.get('public_dns_zone_id', '') - zone_id = self.cf.get('internal_dns_zone_id', '') + public_dns_zone_id = self.cf.get("public_dns_zone_id", "") + zone_id = self.cf.get("internal_dns_zone_id", "") if zone_id or public_dns_zone_id: self.cmd_assign_dns(vm_id) - internal_eni = self.cf.get('internal_eni', '') + internal_eni = self.cf.get("internal_eni", "") if internal_eni: self.cmd_assign_eni(vm_id) - return res or res4 + return res or res2 def cmd_assign_eni(self, vm_id): """Assign Elastic Network Interface to VM. Group: vm """ - internal_eni = self.cf.get('internal_eni') + internal_eni = self.cf.get("internal_eni") client = self.get_ec2_client() res = client.describe_network_interfaces(NetworkInterfaceIds=[internal_eni]) - for iface in res['NetworkInterfaces']: - att = iface.get('Attachment') - if att and att.get('InstanceId'): - att_id = att['AttachmentId'] - old_vm_id = att['InstanceId'] + for iface in res["NetworkInterfaces"]: + att = iface.get("Attachment") + if att and att.get("InstanceId"): + att_id = att["AttachmentId"] + old_vm_id = att["InstanceId"] - printf('detaching %s from %s', att_id, old_vm_id) + printf("detaching %s from %s", att_id, old_vm_id) client.detach_network_interface(AttachmentId=att_id, Force=True) printf("waiting until ENI is detached") while True: time.sleep(5) wres = client.describe_network_interfaces(NetworkInterfaceIds=[internal_eni]) - if wres['NetworkInterfaces'][0]['Status'] == 'available': + if wres["NetworkInterfaces"][0]["Status"] == "available": break printf("attaching ENI") @@ -3166,16 +2257,16 @@ def raw_assign_vm_eip(self, vm_id, ip): alloc_id = None cur_vm_id = None res = client.describe_addresses() # FIXME: filter early - for a in res['Addresses']: - if a.get('PublicIp') == ip: - cur_vm_id = a.get('InstanceId') - alloc_id = a.get('AllocationId') + for a in res["Addresses"]: + if a.get("PublicIp") == ip: + cur_vm_id = a.get("InstanceId") + alloc_id = a.get("AllocationId") break - subnet_id = self.cf.get('subnet_id') + #subnet_id = self.cf.get("subnet_id") args = dict(InstanceId=vm_id) - args['AllocationId'] = alloc_id - args['AllowReassociation'] = True + args["AllocationId"] = alloc_id + args["AllowReassociation"] = True client.associate_address(**args) self.wait_switch(vm_id, ip) time.sleep(10) @@ -3198,28 +2289,28 @@ def assign_vm(self, vm_id, stop_old_vm=False): def wait_switch(self, vm_id, ip, debug=False): printf("waiting for ip switch") - while 1: + while True: time.sleep(10) vm = self.vm_lookup(vm_id, cache=False) - if vm.get('PublicIpAddress') == ip: + if vm.get("PublicIpAddress") == ip: break # reset cache self._vm_map = {} printf("waiting until vm is online") - while 1: + while True: time.sleep(10) # look if SSH works - hdr = b'' + hdr = b"" try: s = socket.create_connection((ip, 22), 10) hdr = s.recv(128) # pylint:disable=no-member s.close() if debug: print(repr(hdr)) - if hdr.find(b'OpenSSH') < 0: + if hdr.find(b"OpenSSH") < 0: continue except Exception as d: if debug: @@ -3229,9 +2320,9 @@ def wait_switch(self, vm_id, ip, debug=False): # check actual instance id if True: return - cmd = ['wget', '-q', '-O-', 'http://169.254.169.254/latest/meta-data/instance-id'] + cmd = ["wget", "-q", "-O-", "http://169.254.169.254/latest/meta-data/instance-id"] cur_id = self.vm_exec(vm_id, cmd, get_output=True, check_error=False) - if cur_id == vm_id.encode('utf8'): + if cur_id == vm_id.encode("utf8"): return def cmd_test_wait(self): @@ -3241,7 +2332,7 @@ def cmd_test_wait(self): """ ids = self.get_primary_vms() vm = self.vm_lookup(ids[0]) - self.wait_switch(vm['InstanceId'], vm['PublicIpAddress'], True) + self.wait_switch(vm["InstanceId"], vm["PublicIpAddress"], True) def cmd_failover(self, secondary_id, *old_primary_ids): """Takeover for dead primary. @@ -3249,7 +2340,7 @@ def cmd_failover(self, secondary_id, *old_primary_ids): Group: vm """ if self.options.tmux: - raise UsageError('This command does not support tmux') + raise UsageError("This command does not support tmux") self.change_cwd_adv() @@ -3259,10 +2350,10 @@ def cmd_failover(self, secondary_id, *old_primary_ids): else: primary_ids = self.get_dead_primary() if len(primary_ids) > 1: - raise UsageError('Dont know how to handle several primaries') + raise UsageError("Dont know how to handle several primaries") primary_id = primary_ids[0] - self.cf.set('primary_vm_id', primary_id) + self.cf.set("primary_vm_id", primary_id) # make sure it exists self.vm_lookup(secondary_id) @@ -3299,7 +2390,7 @@ def cmd_takeover(self, secondary_id): # old primary primary_id = vm_ids[0] - self.cf.set('primary_vm_id', primary_id) + self.cf.set("primary_vm_id", primary_id) cmd = VmCmd.TAKEOVER_PREPARE_PRIMARY if self.has_modcmd(cmd): @@ -3382,80 +2473,6 @@ def cmd_drop_node(self, vm_id): self.cmd_stop(vm_id) - def load_modcmd_args(self, args): - vms = [] - for a in args: - if a.startswith('i-'): - vms.append(a) - else: - raise UsageError("command supports only vmid args") - if vms: - return vms - return self.get_primary_vms() - - def work(self): - cmd = self.options.command - cmdargs = self.options.args - if not cmd: - raise UsageError("Need command") - #eprintf('vmtool - env_name: %s git_dir: %s', self.env_name, self.git_dir) - cmd_section = 'cmd.%s' % cmd - if self.cf.has_section(cmd_section): - cf2 = self.cf.view_section(cmd_section) - if cf2.get('vmlibs', ''): - vms = self.load_modcmd_args(cmdargs) - self.change_cwd_adv() - self.modcmd_init(cmd) - self.modcmd_run(cmd, vms) - else: - self.run_console_cmd(cmd, cmdargs) - else: - super(VmTool, self).work() - - def run_console_cmd(self, cmd, cmdargs): - cmd_cf = self.cf.view_section('cmd.%s' % cmd) - cmdline = cmd_cf.get('vmrun') - argparam = cmd_cf.get('vmrun_arg_param', '') - - fullcmd = shlex.split(cmdline) - vm_ids, args = self.get_vm_args(cmdargs, allow_multi=True) - if args: - if argparam: - fullcmd = fullcmd + [argparam, ' '.join(args)] - else: - fullcmd = fullcmd + args - - if len(vm_ids) > 1 and self.options.tmux: - raise UsageError("Cannot use tmux in parallel") - - for vm_id in vm_ids: - if len(vm_ids) > 1: - time_printf("Running on VM %s", vm_id) - self.vm_exec_tmux(vm_id, fullcmd, title=cmd) - - def change_cwd_adv(self): - # cd .. until there is .git - if not self._change_cwd_gittop(): - os.chdir(self.git_dir) - - def _change_cwd_gittop(self): - vmlib = 'vmlib/runner.sh' - num = 0 - maxstep = 30 - pfx = '.' - while True: - if os.path.isdir(os.path.join(pfx, '.git')): - if os.path.isfile(os.path.join(pfx, vmlib)): - os.chdir(pfx) - return True - else: - break - if num > maxstep: - break - pfx = os.path.join(pfx, '..') - num += 1 - return False - def get_image_id(self, image_name): client = self.get_ec2_client() res = client.describe_images(Owners=['self'], Filters=[{'Name': 'name', 'Values': [image_name]}]) @@ -3564,44 +2581,50 @@ def show_image_list(self, image_list, grprx=None): def show_image(self, img): """Details about single image. """ - printf("%s state=%s owner=%s alias=%s", img['ImageId'], img['State'], img['OwnerId'], img.get('ImageOwnerAlias', '-')) + printf( + "%s state=%s owner=%s alias=%s", + img["ImageId"], + img["State"], + img["OwnerId"], + img.get("ImageOwnerAlias", "-") + ) printf(" type=%s/%s/%s/%s/%s ctime=%s", - img['VirtualizationType'], img['RootDeviceType'], - img['Architecture'], img['Hypervisor'], - img['Public'] and 'public' or 'private', - img['CreationDate']) - printf(" name=%s", img['Name']) - if img.get('Description'): - printf(" desc=%s", img.get('Description')) - printf(" location=%s", img['ImageLocation']) + img["VirtualizationType"], img["RootDeviceType"], + img["Architecture"], img["Hypervisor"], + img["Public"] and "public" or "private", + img["CreationDate"]) + printf(" name=%s", img["Name"]) + if img.get("Description"): + printf(" desc=%s", img.get("Description")) + printf(" location=%s", img["ImageLocation"]) if self.load_tags(img): printf(" tags=%s", self.load_tags(img)) if not self.options.verbose: return printf(" disk_mapping:") - for bdt in img['BlockDeviceMappings']: - ebs = bdt.get('Ebs') or {} - if ebs.get('SnapshotId'): + for bdt in img["BlockDeviceMappings"]: + ebs = bdt.get("Ebs") or {} + if ebs.get("SnapshotId"): printf(" %s: snapshot=%s size=%s", - bdt.get('DeviceName'), ebs.get('SnapshotId'), ebs.get('VolumeSize')) + bdt.get("DeviceName"), ebs.get("SnapshotId"), ebs.get("VolumeSize")) else: printf(" %s: ephemeral=%s", - bdt.get('DeviceName'), bdt.get('VirtualName')) + bdt.get("DeviceName"), bdt.get("VirtualName")) def show_public_images(self, owner_id, namefilter, grprx): """Filtered request for public images. """ client = self.get_ec2_client() res = client.describe_images(Owners=[owner_id], Filters=[ - {'Name': 'state', 'Values': ['available']}, - {'Name': 'is-public', 'Values': ['true']}, - {'Name': 'architecture', 'Values': ['x86_64']}, # x86_64 / i386 / arm - {'Name': 'virtualization-type', 'Values': ['hvm']}, # paravirtual / hvm - {'Name': 'root-device-type', 'Values': ['ebs']}, # ebs / instance-store - {'Name': 'name', 'Values': [namefilter]}, + {"Name": "state", "Values": ["available"]}, + {"Name": "is-public", "Values": ["true"]}, + {"Name": "architecture", "Values": ["x86_64"]}, # x86_64 / i386 / arm + {"Name": "virtualization-type", "Values": ["hvm"]}, # paravirtual / hvm + {"Name": "root-device-type", "Values": ["ebs"]}, # ebs / instance-store + {"Name": "name", "Values": [namefilter]}, ]) - self.show_image_list(res['Images'], grprx) + self.show_image_list(res["Images"], grprx) def cmd_show_image(self, *amis): """Show specific public images @@ -3610,62 +2633,61 @@ def cmd_show_image(self, *amis): """ for ami in amis: region = None - if ':' in ami: + if ":" in ami: region, ami = ami.split(":") client = self.get_ec2_client(region) res = client.describe_images(ImageIds=[ami]) - self.show_image_list(res['Images']) + self.show_image_list(res["Images"]) def cmd_show_images_debian(self, *codes): """Show Debian images Group: image """ - owner_id = '379101102735' # https://wiki.debian.org/Cloud/AmazonEC2Image + owner_id = "379101102735" # https://wiki.debian.org/Cloud/AmazonEC2Image - pat = 'debian-*' + pat = "debian-*" if codes: - pat = 'debian-%s-*' % codes[0] - self.show_public_images(owner_id, pat, r'debian-\w+-') + pat = "debian-%s-*" % codes[0] + self.show_public_images(owner_id, pat, r"debian-\w+-") def cmd_show_images_debian_new(self, *codes): """Show Debian images Group: image """ - owner_id = '136693071363' # https://wiki.debian.org/Cloud/AmazonEC2Image/Buster + owner_id = "136693071363" # https://wiki.debian.org/Cloud/AmazonEC2Image/Buster - pat = 'debian-*' + pat = "debian-*" if codes: - pat = 'debian-%s-*' % codes[0] - self.show_public_images(owner_id, pat, r'debian-\w+-') + pat = "debian-%s-*" % codes[0] + self.show_public_images(owner_id, pat, r"debian-\w+-") def cmd_show_images_ubuntu(self, *codes): """Show Ubuntu images Group: image """ - owner_id = '099720109477' # Owner of images from https://cloud-images.ubuntu.com/ - #owner_id = '679593333241' # Marketplace user 'Canonical Group Limited' + owner_id = "099720109477" # Owner of images from https://cloud-images.ubuntu.com/ + #owner_id = "679593333241" # Marketplace user "Canonical Group Limited" - pat = 'ubuntu/images/*' + pat = "ubuntu/images/*" if codes: - pat += '/ubuntu-%s-*' % codes[0] - self.show_public_images(owner_id, pat, r'.*/ubuntu-\w+-') - + pat += "/ubuntu-%s-*" % codes[0] + self.show_public_images(owner_id, pat, r".*/ubuntu-\w+-") def cmd_show_images_ubuntu_minimal(self, *codes): """Show Ubuntu minimal images Group: image """ - owner_id = '099720109477' # Owner of images from https://cloud-images.ubuntu.com/ - #owner_id = '679593333241' # Marketplace user 'Canonical Group Limited' + owner_id = "099720109477" # Owner of images from https://cloud-images.ubuntu.com/ + #owner_id = "679593333241" # Marketplace user "Canonical Group Limited" - pat = 'ubuntu-minimal/images/*' + pat = "ubuntu-minimal/images/*" if codes: - pat += '/ubuntu-%s-*' % codes[0] - self.show_public_images(owner_id, pat, r'.*/ubuntu-\w+-') + pat += "/ubuntu-%s-*" % codes[0] + self.show_public_images(owner_id, pat, r".*/ubuntu-\w+-") def cmd_show_zones(self): """Show DNS zones set up under Route53. @@ -3674,87 +2696,87 @@ def cmd_show_zones(self): """ client = self.get_route53() res = client.list_hosted_zones() - for zone in res['HostedZones']: - printf('%s - privale=%s desc=%s', zone['Name'], - zone['Config']['PrivateZone'], zone['Config']['Comment']) + for zone in res["HostedZones"]: + printf("%s - privale=%s desc=%s", zone["Name"], + zone["Config"]["PrivateZone"], zone["Config"]["Comment"]) def cmd_show_zone(self): """Show records under one DNS zone. Group: info """ - zone_id = self.cf.get('internal_dns_zone_id') + zone_id = self.cf.get("internal_dns_zone_id") for rres in self.route53_iter_rrsets(HostedZoneId=zone_id): - printf('%s %s', rres['Name'], rres['Type']) - for vrec in rres['ResourceRecords']: - printf(' %s', vrec['Value']) + printf("%s %s", rres["Name"], rres["Type"]) + for vrec in rres["ResourceRecords"]: + printf(" %s", vrec["Value"]) def cmd_assign_dns(self, vm_id): """Assign DNS entries to VM. Group: vm """ - zone_id = self.cf.get('internal_dns_zone_id') - rev_zone_id = self.cf.get('internal_arpa_zone_id', '') - zone_name = self.cf.get('internal_dns_zone_name') - local_name = self.cf.get('internal_dns_vm_name') - public_dns_zone_id = self.cf.get('public_dns_zone_id', '') - public_dns_full_name = self.cf.get('public_dns_full_name', '') - public_dns_ttl = self.cf.get('public_dns_ttl', '60') + zone_id = self.cf.get("internal_dns_zone_id") + rev_zone_id = self.cf.get("internal_arpa_zone_id", "") + zone_name = self.cf.get("internal_dns_zone_name") + local_name = self.cf.get("internal_dns_vm_name") + public_dns_zone_id = self.cf.get("public_dns_zone_id", "") + public_dns_full_name = self.cf.get("public_dns_full_name", "") + public_dns_ttl = self.cf.get("public_dns_ttl", "60") vm = self.vm_lookup(vm_id) - internal_ip = vm['PrivateIpAddress'] - public_ip = vm.get('PublicIpAddress') + internal_ip = vm["PrivateIpAddress"] + public_ip = vm.get("PublicIpAddress") # internal dns - int_full_name = '%s.%s' % (local_name, zone_name) - if not int_full_name.endswith('.'): - int_full_name = int_full_name + '.' + int_full_name = "%s.%s" % (local_name, zone_name) + if not int_full_name.endswith("."): + int_full_name = int_full_name + "." changes = [ - {'Action': 'UPSERT', - 'ResourceRecordSet': { - 'Name': int_full_name, - 'Type': 'A', - 'TTL': int(public_dns_ttl), - 'ResourceRecords': [{'Value': internal_ip}]}}] - batch = {'Comment': 'assign-dns', 'Changes': changes} + {"Action": "UPSERT", + "ResourceRecordSet": { + "Name": int_full_name, + "Type": "A", + "TTL": int(public_dns_ttl), + "ResourceRecords": [{"Value": internal_ip}]}}] + batch = {"Comment": "assign-dns", "Changes": changes} time_printf("Assigning internal dns: %s -> %s", int_full_name, internal_ip) client = self.get_route53() res = client.change_resource_record_sets(HostedZoneId=zone_id, ChangeBatch=batch) - if res['ResponseMetadata']['HTTPStatusCode'] != 200: - eprintf('failed to set internal dns: %r', res) + if res["ResponseMetadata"]["HTTPStatusCode"] != 200: + eprintf("failed to set internal dns: %r", res) sys.exit(1) # internal reverse dns if rev_zone_id: - rev_name = '.'.join(reversed(internal_ip.split('.'))) + '.in-addr.arpa' + rev_name = ".".join(reversed(internal_ip.split("."))) + ".in-addr.arpa" changes = [ - {'Action': 'UPSERT', - 'ResourceRecordSet': { - 'Name': rev_name, 'Type': 'PTR', 'TTL': 60, - 'ResourceRecords': [{'Value': int_full_name}]}}] - batch = {'Comment': 'assign-rdns', 'Changes': changes} + {"Action": "UPSERT", + "ResourceRecordSet": { + "Name": rev_name, "Type": "PTR", "TTL": 60, + "ResourceRecords": [{"Value": int_full_name}]}}] + batch = {"Comment": "assign-rdns", "Changes": changes} time_printf("Assigning reverse dns: %s -> %s", rev_name, int_full_name) res = client.change_resource_record_sets(HostedZoneId=rev_zone_id, ChangeBatch=batch) - if res['ResponseMetadata']['HTTPStatusCode'] != 200: - eprintf('failed to set reverse dns: %r', res) + if res["ResponseMetadata"]["HTTPStatusCode"] != 200: + eprintf("failed to set reverse dns: %r", res) sys.exit(1) # public dns if public_dns_full_name: if not public_ip: - eprintf('request for public dns but vm does not have public ip: %r', vm_id) + eprintf("request for public dns but vm does not have public ip: %r", vm_id) sys.exit(1) - changes = [{'Action': 'UPSERT', - 'ResourceRecordSet': { - 'Name': public_dns_full_name, 'Type': 'A', 'TTL': 60, - 'ResourceRecords': [{'Value': public_ip}]}}] - batch = {'Comment': 'assign-dns', 'Changes': changes} + changes = [{"Action": "UPSERT", + "ResourceRecordSet": { + "Name": public_dns_full_name, "Type": "A", "TTL": 60, + "ResourceRecords": [{"Value": public_ip}]}}] + batch = {"Comment": "assign-dns", "Changes": changes} time_printf("Assigning public dns: %s -> %s", public_dns_full_name, public_ip) res = client.change_resource_record_sets(HostedZoneId=public_dns_zone_id, ChangeBatch=batch) - if res['ResponseMetadata']['HTTPStatusCode'] != 200: - eprintf('failed to set public dns: %r', res) + if res["ResponseMetadata"]["HTTPStatusCode"] != 200: + eprintf("failed to set public dns: %r", res) sys.exit(1) # wait until locally seen @@ -3771,32 +2793,32 @@ def cmd_clean_dns(self): Group: vm """ - zone_id = self.cf.get('internal_dns_zone_id') - rev_zone_id = self.cf.get('internal_arpa_zone_id', '') + zone_id = self.cf.get("internal_dns_zone_id") + rev_zone_id = self.cf.get("internal_arpa_zone_id", "") - internal_subnet_cidr = self.cf.get('internal_subnet_cidr') + internal_subnet_cidr = self.cf.get("internal_subnet_cidr") net = ipaddress.IPv4Network(as_unicode(internal_subnet_cidr)) used_ips = set() for rec in self.route53_iter_rrsets(HostedZoneId=zone_id): - if rec['Type'] != 'A': + if rec["Type"] != "A": continue - for vrec in rec['ResourceRecords']: - ip = vrec['Value'] + for vrec in rec["ResourceRecords"]: + ip = vrec["Value"] addr = ipaddress.IPv4Address(as_unicode(ip)) if addr in net: used_ips.add(ip) for rec in self.route53_iter_rrsets(HostedZoneId=rev_zone_id): - if rec['Type'] != 'PTR': + if rec["Type"] != "PTR": continue - name = rec['Name'] - if not name.endswith('.in-addr.arpa.'): + name = rec["Name"] + if not name.endswith(".in-addr.arpa."): print(repr(rec)) continue - name = name.replace('.in-addr.arpa.', '') - ip = '.'.join(reversed(name.split('.'))) + name = name.replace(".in-addr.arpa.", "") + ip = ".".join(reversed(name.split("."))) addr = ipaddress.IPv4Address(as_unicode(ip)) if addr in net: if ip in used_ips: @@ -3805,50 +2827,50 @@ def cmd_clean_dns(self): print("Old: " + ip) def get_internal_dns_ips(self): - local_name = self.cf.get('internal_dns_vm_name', '') + local_name = self.cf.get("internal_dns_vm_name", "") if not local_name: return [] - zone_id = self.cf.get('internal_dns_zone_id') - zone_name = self.cf.get('internal_dns_zone_name') - full_name = '%s.%s' % (local_name, zone_name) + zone_id = self.cf.get("internal_dns_zone_id") + zone_name = self.cf.get("internal_dns_zone_name") + full_name = "%s.%s" % (local_name, zone_name) iplist = [] for rec in self.route53_iter_rrsets(HostedZoneId=zone_id, StartRecordName=full_name): - if rec['Type'] not in ('A', 'AAAA'): + if rec["Type"] not in ("A", "AAAA"): continue - if not rec['Name'].startswith(full_name): + if not rec["Name"].startswith(full_name): continue - for vrec in rec['ResourceRecords']: - iplist.append(vrec['Value']) + for vrec in rec["ResourceRecords"]: + iplist.append(vrec["Value"]) return iplist def get_dns_map(self, full=False): ipmap = {} - #local_name = self.cf.get('internal_dns_vm_name', '') - zone_id = self.cf.get('internal_dns_zone_id', '') + #local_name = self.cf.get("internal_dns_vm_name", "") + zone_id = self.cf.get("internal_dns_zone_id", "") if zone_id: for rec in self.route53_iter_rrsets(HostedZoneId=zone_id): - if rec['Type'] not in ('A', 'AAAA'): + if rec["Type"] not in ("A", "AAAA"): continue - for vrec in rec['ResourceRecords']: - ipmap[vrec['Value']] = rec['Name'] + for vrec in rec["ResourceRecords"]: + ipmap[vrec["Value"]] = rec["Name"] - #pub_name = self.cf.get('public_dns_full_name', '') - zone_id = self.cf.get('public_dns_zone_id', '') + #pub_name = self.cf.get("public_dns_full_name", "") + zone_id = self.cf.get("public_dns_zone_id", "") if zone_id: for rec in self.route53_iter_rrsets(HostedZoneId=zone_id): - if rec['Type'] not in ('A', 'AAAA'): + if rec["Type"] not in ("A", "AAAA"): continue - for vrec in rec['ResourceRecords']: - ipmap[vrec['Value']] = rec['Name'] + for vrec in rec["ResourceRecords"]: + ipmap[vrec["Value"]] = rec["Name"] # consider other zones - for zone_id in self.cf.getlist('extra_internal_dns_zone_ids', []): + for zone_id in self.cf.getlist("extra_internal_dns_zone_ids", []): for rec in self.route53_iter_rrsets(HostedZoneId=zone_id): - if rec['Type'] not in ('A', 'AAAA'): + if rec["Type"] not in ("A", "AAAA"): continue - for vrec in rec['ResourceRecords']: - ipmap[vrec['Value']] = rec['Name'] + for vrec in rec["ResourceRecords"]: + ipmap[vrec["Value"]] = rec["Name"] return ipmap @@ -3857,92 +2879,30 @@ def cmd_show_tf(self): Group: config """ - state_file = self.cf.get('tf_state_file') + state_file = self.cf.get("tf_state_file") tfvars = tf_load_all_vars(state_file) for k in sorted(tfvars.keys()): - parts = k.split('.') + parts = k.split(".") if len(parts) <= 3 or self.options.all: printf("%s = %s", k, tfvars[k]) - def cmd_show_config(self, *args): - """Show filled config for current VM. - - Group: config - """ - desc = self.env_name - if self.role_name: - desc += '.' + self.role_name - - fail = 0 - for sect in sorted(self.cf.sections()): - sect_header = f'[{sect}]' - for k in sorted(self.cf.cf.options(sect)): - if args and k not in args: - continue - if sect_header: - printf(sect_header) - sect_header = '' - try: - raw = self.cf.cf.get(sect, k, raw=True) - v = self.cf.cf.get(sect, k) - vs = v - if not self.options.verbose: - vs = vs.strip() - if vs.startswith('----') or vs.startswith('{'): - vs = vs.split('\n')[0] - else: - vs = re.sub(r'\n\s*', ' ', vs) - printf("%s = %s", k, vs) - else: - printf("%s = %s [%s] (%s)", k, vs, desc, raw) - except Exception as ex: - fail = 1 - eprintf("### ERROR ### key: '%s.%s' err: %s", sect, k, str(ex)) - if not sect_header: - printf('') - if fail: - sys.exit(fail) - - def cmd_show_config_raw(self, *args): - """Show filled config for current VM. - - Group: config - """ - self.cf.cf.write(sys.stdout) - - def cmd_check_config(self): - """Check if config works. - - Group: config - """ - fail = 0 - for k in self.cf.options(): - try: - self.cf.getlist(k) - except Exception as ex: - fail = 1 - printf("key: '%s' err: %s", k, str(ex)) - if fail: - printf("--problems--") - sys.exit(fail) - def cmd_test(self): """Test both config and initial payload for VM. Group: config """ self.cmd_check_config() - self.cmd_mod_test('prep') + self.cmd_mod_test("prep") def cmd_test_files(self): """Show contents of prep command payload. Group: internal """ - data = self.cmd_mod_test('prep') - rf = gzip.GzipFile(mode='rb', fileobj=io.BytesIO(data)) - tar = tarfile.TarFile(fileobj=rf) - tar.list() + data = self.cmd_mod_test("prep") + with gzip.GzipFile(mode="rb", fileobj=io.BytesIO(data)) as rf: + with tarfile.TarFile(fileobj=rf) as tar: + tar.list() def cmd_sts_decode(self, msg): """Decode payload from UnauthorizedOperation error. @@ -3950,9 +2910,9 @@ def cmd_sts_decode(self, msg): Group: internal """ # req: sts:DecodeAuthorizationMessage - client = self.get_boto3_client('sts') + client = self.get_boto3_client("sts") res = client.decode_authorization_message(EncodedMessage=msg) - dec = res['DecodedMessage'] + dec = res["DecodedMessage"] data = json.loads(dec) print_json(data) @@ -3960,20 +2920,20 @@ def cmd_sts_decode(self, msg): # Gen client certs # - def cmd_list_keys(self, path = ''): + def cmd_list_keys(self, path=""): """List issued keys. Group: kms """ for section_name in self.cf.sections(): - if not section_name.startswith('secrets'): + if not section_name.startswith("secrets"): continue secret_cf = self.cf.view_section(section_name) self._list_keys(secret_cf, path) def _list_keys(self, secret_cf, path): - kind = secret_cf.get('kind') - if path == 'ALL': + kind = secret_cf.get("kind") + if path == "ALL": pass elif path.startswith(kind): pass @@ -3982,17 +2942,17 @@ def _list_keys(self, secret_cf, path): cwd = self.git_dir os.chdir(cwd) - certs_dir = secret_cf.get('certs_dir') - certs_ini = os.path.join(certs_dir, 'certs.ini') + certs_dir = secret_cf.get("certs_dir") + certs_ini = os.path.join(certs_dir, "certs.ini") if not os.path.isfile(certs_ini): - raise ValueError('File not found: %s' % certs_ini) + raise ValueError("File not found: %s" % certs_ini) keys = load_cert_config(certs_ini, self.load_ca_keypair, {}) - client = self.get_boto3_client('secretsmanager') + client = self.get_boto3_client("secretsmanager") for kname, value in keys.items(): - if path == 'ALL': + if path == "ALL": pass - elif f'{kind}.{kname}'.startswith(path): + elif f"{kind}.{kname}".startswith(path): pass else: continue @@ -4000,40 +2960,40 @@ def _list_keys(self, secret_cf, path): self._list_key(client, secret_cf, cert_cf) def _list_key(self, client, secret_cf, cert_cf): - namespace = secret_cf.get('namespace') - stage = secret_cf.get('stage') - kind = secret_cf.get('kind') + namespace = secret_cf.get("namespace") + stage = secret_cf.get("stage") + kind = secret_cf.get("kind") - srvc_type = cert_cf['srvc_type'] - srvc_temp = cert_cf['srvc_temp'] - srvc_name = cert_cf['srvc_name'] - srvc_repo = cert_cf['srvc_repo'] + srvc_type = cert_cf["srvc_type"] + srvc_temp = cert_cf["srvc_temp"] + srvc_name = cert_cf["srvc_name"] + srvc_repo = cert_cf["srvc_repo"] secret_name = f"{namespace}/{stage}/{kind}/{srvc_repo}/{srvc_type}/{srvc_temp}/{srvc_name}" try: - r_description = client.describe_secret(SecretId = secret_name) + r_description = client.describe_secret(SecretId=secret_name) r_value = client.get_secret_value( - SecretId = secret_name) + SecretId=secret_name) printf(secret_name) - printf(pprint.pformat(r_description['Tags'])) - printf(pprint.pformat(json.loads(r_value['SecretString']))) + printf(pprint.pformat(r_description["Tags"])) + printf(pprint.pformat(json.loads(r_value["SecretString"]))) except client.exceptions.ResourceNotFoundException: pass - def cmd_upload_keys(self, path = ''): + def cmd_upload_keys(self, path=""): """Issue new certificates. Group: kms """ for section_name in self.cf.sections(): - if not section_name.startswith('secrets'): + if not section_name.startswith("secrets"): continue secret_cf = self.cf.view_section(section_name) self._upload_certs(secret_cf, path) def _upload_certs(self, secret_cf, path): - kind = secret_cf.get('kind') - if path == 'ALL': + kind = secret_cf.get("kind") + if path == "ALL": pass elif path.startswith(kind): pass @@ -4042,17 +3002,17 @@ def _upload_certs(self, secret_cf, path): cwd = self.git_dir os.chdir(cwd) - certs_dir = secret_cf.get('certs_dir') - certs_ini = os.path.join(certs_dir, 'certs.ini') + certs_dir = secret_cf.get("certs_dir") + certs_ini = os.path.join(certs_dir, "certs.ini") if not os.path.isfile(certs_ini): - raise ValueError('File not found') + raise ValueError("File not found") keys = load_cert_config(certs_ini, self.load_ca_keypair, {}) - client = self.get_boto3_client('secretsmanager') + client = self.get_boto3_client("secretsmanager") for kname, value in keys.items(): - if path == 'ALL': + if path == "ALL": pass - elif f'{kind}.{kname}'.startswith(path): + elif f"{kind}.{kname}".startswith(path): pass else: continue @@ -4061,93 +3021,93 @@ def _upload_certs(self, secret_cf, path): self._upload_cert(client, secret_cf, kname, key, cert, cert_cf) def _upload_cert(self, client, secret_cf, kname, key, cert, cert_cf): - namespace = secret_cf.get('namespace') - stage = secret_cf.get('stage') - kind = secret_cf.get('kind') + namespace = secret_cf.get("namespace") + stage = secret_cf.get("stage") + kind = secret_cf.get("kind") - srvc_type = cert_cf['srvc_type'] - srvc_temp = cert_cf['srvc_temp'] - srvc_name = cert_cf['srvc_name'] - srvc_repo = cert_cf['srvc_repo'] + srvc_type = cert_cf["srvc_type"] + srvc_temp = cert_cf["srvc_temp"] + srvc_name = cert_cf["srvc_name"] + srvc_repo = cert_cf["srvc_repo"] - db_name = cert_cf.get('db_name') - db_user = cert_cf.get('db_user') + db_name = cert_cf.get("db_name") + db_user = cert_cf.get("db_user") - ca_name = cert_cf.get('ca_name') + #ca_name = cert_cf.get("ca_name") root_cert = self._get_root_cert(cert_cf) secret_name = f"{namespace}/{stage}/{kind}/{srvc_repo}/{srvc_type}/{srvc_temp}/{srvc_name}" secret_data = { - 'key': key.decode('utf-8'), - 'crt': cert.decode('utf-8'), + "key": key.decode("utf-8"), + "crt": cert.decode("utf-8"), } - if cert_cf['usage'] == 'client': - secret_data['server_root_crt'] = root_cert.decode('utf-8') - elif cert_cf['usage'] == 'server': - secret_data['client_root_crt'] = root_cert.decode('utf-8') + if cert_cf["usage"] == "client": + secret_data["server_root_crt"] = root_cert.decode("utf-8") + elif cert_cf["usage"] == "server": + secret_data["client_root_crt"] = root_cert.decode("utf-8") else: - raise ValueError('Invalid value for usage: %s' % cert_cf['usage']) + raise ValueError("Invalid value for usage: %s" % cert_cf["usage"]) if db_name: - secret_data['db_name'] = db_name + secret_data["db_name"] = db_name if db_user: - secret_data['db_user'] = db_user - #if self.cf.has_option('%s_url' % ca_name): - # base_url = self.cf.get('%s_url' % ca_name) + secret_data["db_user"] = db_user + #if self.cf.has_option("%s_url" % ca_name): + # base_url = self.cf.get("%s_url" % ca_name) # srvc_url = f"https://{kind}-{srvc_type}-{srvc_temp}.{base_url}" - # secret_data['url'] = srvc_url + # secret_data["url"] = srvc_url secret_str = json.dumps(secret_data) secret_tags = [ - {'Key': 'namespace', 'Value': namespace}, - {'Key': 'stage', 'Value': stage}, - {'Key': 'kind', 'Value': kind}, - {'Key': 'srvc_type', 'Value': srvc_type}, - {'Key': 'srvc_temp', 'Value': srvc_temp}, - {'Key': 'srvc_name', 'Value': srvc_name}, - {'Key': 'srvc_repo', 'Value': srvc_repo}, + {"Key": "namespace", "Value": namespace}, + {"Key": "stage", "Value": stage}, + {"Key": "kind", "Value": kind}, + {"Key": "srvc_type", "Value": srvc_type}, + {"Key": "srvc_temp", "Value": srvc_temp}, + {"Key": "srvc_name", "Value": srvc_name}, + {"Key": "srvc_repo", "Value": srvc_repo}, ] - sec_extra_tags = self.cf.getdict('sec_extra_tags', {}) + sec_extra_tags = self.cf.getdict("sec_extra_tags", {}) for k, v in sec_extra_tags.items(): - secret_tags.append({'Key': k, 'Value': v}) + secret_tags.append({"Key": k, "Value": v}) try: - client.describe_secret(SecretId = secret_name) + client.describe_secret(SecretId=secret_name) is_existing_secret = True except client.exceptions.ResourceNotFoundException: is_existing_secret = False if is_existing_secret: response = client.update_secret( - SecretId = secret_name, - Description = secret_name, - KmsKeyId = secret_cf.get('kms_key_id'), - SecretString = secret_str) - printf('Updated secret: %s' % secret_name) + SecretId=secret_name, + Description=secret_name, + KmsKeyId=secret_cf.get("kms_key_id"), + SecretString=secret_str) + printf("Updated secret: %s" % response["Name"]) else: response = client.create_secret( - Name = secret_name, - Description = secret_name, - KmsKeyId = secret_cf.get('kms_key_id'), - SecretString = secret_str, - Tags = secret_tags) - printf('Created secret: %s' % secret_name) + Name=secret_name, + Description=secret_name, + KmsKeyId=secret_cf.get("kms_key_id"), + SecretString=secret_str, + Tags=secret_tags) + printf("Created secret: %s" % response["Name"]) def _get_root_cert(self, cf): - ca_dir = self.cf.get('%s_dir' % cf['ca_name']) - if cf['usage'] == 'client': - root_crt_fname = cf['server_root_crt'] - elif cf['usage'] == 'server': - root_crt_fname = cf['client_root_crt'] + ca_dir = self.cf.get("%s_dir" % cf["ca_name"]) + if cf["usage"] == "client": + root_crt_fname = cf["server_root_crt"] + elif cf["usage"] == "server": + root_crt_fname = cf["client_root_crt"] else: - raise ValueError('Invalid value for usage: %s' % cf['usage']) + raise ValueError("Invalid value for usage: %s" % cf["usage"]) - root_crt = '%s/%s/%s' % (self.keys_dir, ca_dir, root_crt_fname) - with open(root_crt, 'rb') as f: + root_crt = "%s/%s/%s" % (self.keys_dir, ca_dir, root_crt_fname) + with open(root_crt, "rb") as f: return f.read() def cmd_log_keys(self): @@ -4159,27 +3119,27 @@ def cmd_log_keys(self): os.chdir(cwd) for section_name in self.cf.sections(): - if not section_name.startswith('secrets'): + if not section_name.startswith("secrets"): continue secret_cf = self.cf.view_section(section_name) self._log_keys(secret_cf) def _log_keys(self, secret_cf): - namespace = secret_cf.get('namespace') - stage = secret_cf.get('stage') - kind = secret_cf.get('kind') + namespace = secret_cf.get("namespace") + stage = secret_cf.get("stage") + kind = secret_cf.get("kind") - client = self.get_boto3_client('secretsmanager') + client = self.get_boto3_client("secretsmanager") list_secrets_pager = self.pager(client, "list_secrets", "SecretList") for secret in list_secrets_pager(): - if not secret['Name'].startswith(f'{namespace}/{stage}/{kind}'): + if not secret["Name"].startswith(f"{namespace}/{stage}/{kind}"): continue - name = secret['Name'] - tags = secret['Tags'] + name = secret["Name"] + tags = secret["Tags"] srvc_name = None for tag in tags: - if tag['Key'] == 'srvc_name': - srvc_name = tag['Value'] + if tag["Key"] == "srvc_name": + srvc_name = tag["Value"] break if srvc_name is None: @@ -4189,12 +3149,12 @@ def _log_keys(self, secret_cf): os.makedirs(name) r_value = client.get_secret_value( - SecretId = name) + SecretId=name) - timestamp = r_value['CreatedDate'].strftime("%Y%m%d-%H%M%S") + timestamp = r_value["CreatedDate"].strftime("%Y%m%d-%H%M%S") - crt = json.loads(r_value['SecretString'])['crt'].encode('utf-8') - with open(f'{name}/{timestamp}.crt', 'wb') as f: + crt = json.loads(r_value["SecretString"])["crt"].encode("utf-8") + with open(f"{name}/{timestamp}.crt", "wb") as f: f.write(crt) def cmd_tag_keys(self): @@ -4206,22 +3166,22 @@ def cmd_tag_keys(self): raise Exception("No env_name") tags = [] - sec_extra_tags = self.cf.getdict('sec_extra_tags', {}) + sec_extra_tags = self.cf.getdict("sec_extra_tags", {}) for k, v in sec_extra_tags.items(): - tags.append({'Key': k, 'Value': v}) + tags.append({"Key": k, "Value": v}) if tags: - client = self.get_boto3_client('secretsmanager') + client = self.get_boto3_client("secretsmanager") pager = self.pager(client, "list_secrets", "SecretList") for secret in pager(): - if secret['Name'].startswith(f'dp/{self.env_name}/'): - client.tag_resource(SecretId=secret['Name'], Tags=tags) + if secret["Name"].startswith(f"dp/{self.env_name}/"): + client.tag_resource(SecretId=secret["Name"], Tags=tags) def fetch_disk_info(self, vm_ids): args = {} - args['Filters'] = self.get_env_filters() + args["Filters"] = self.get_env_filters() if vm_ids: - args['InstanceIds'] = vm_ids + args["InstanceIds"] = vm_ids vm_list = [] for vm in self.ec2_iter_instances(**args): @@ -4235,36 +3195,36 @@ def fetch_disk_info(self, vm_ids): # load disks from config disk_map = self.get_disk_map() - vm_disk_names_size_order = self.cf.getlist('vm_disk_names_size_order') + vm_disk_names_size_order = self.cf.getlist("vm_disk_names_size_order") final_list = [] for vm in vm_list: - if vm['State']['Name'] != 'running': + if vm["State"]["Name"] != "running": continue final_info = { - 'vm': vm, - 'config_disk_map': disk_map, - 'volume_map': {}, # name -> volume - 'device_map': {}, # name -> DeviceName + "vm": vm, + "config_disk_map": disk_map, + "volume_map": {}, # name -> volume + "device_map": {}, # name -> DeviceName } # load disk from running vm root_vol_id = None cur_vol_list = [] dev_map = {} # vol_id->dev_name - for bdev in vm.get('BlockDeviceMappings', []): - ebs = bdev.get('Ebs') + for bdev in vm.get("BlockDeviceMappings", []): + ebs = bdev.get("Ebs") if not ebs: continue - vol = vol_map[ebs['VolumeId']] - vol_info = (vol['Size'], ebs['VolumeId']) - dev_name = bdev.get('DeviceName') - dev_map[ebs['VolumeId']] = dev_name - if dev_name in ROOT_DEV_NAMES: - root_vol_id = ebs['VolumeId'] - final_info['volume_map']['root'] = vol - final_info['device_map']['root'] = dev_name + vol = vol_map[ebs["VolumeId"]] + vol_info = (vol["Size"], ebs["VolumeId"]) + dev_name = bdev.get("DeviceName") + dev_map[ebs["VolumeId"]] = dev_name + if dev_name in self.ROOT_DEV_NAMES: + root_vol_id = ebs["VolumeId"] + final_info["volume_map"]["root"] = vol + final_info["device_map"]["root"] = dev_name else: cur_vol_list.append(vol_info) @@ -4274,7 +3234,7 @@ def fetch_disk_info(self, vm_ids): # insert local disks ephemeral_nr = 0 for disk_name, disk_conf in disk_map.items(): - eph_name = disk_conf.get('ephemeral') + eph_name = disk_conf.get("ephemeral") if not eph_name: continue for nr in range(disk_conf["count"]): @@ -4294,47 +3254,47 @@ def fetch_disk_info(self, vm_ids): cur_vol_list.sort() for nr, (size, vol_id) in enumerate(cur_vol_list): vol_name = vm_disk_names_size_order[nr] - final_info['volume_map'][vol_name] = vol_map[vol_id] - final_info['device_map'][vol_name] = dev_map[vol_id] + final_info["volume_map"][vol_name] = vol_map[vol_id] + final_info["device_map"][vol_name] = dev_map[vol_id] final_list.append(final_info) return final_list def lookup_disk_config(self, vm_disk_info, vol_name): - if vol_name.split('.')[-1].isdigit(): - xvol_name = '.'.join(vol_name.split('.')[:-1]) - vol_conf = vm_disk_info['config_disk_map'][xvol_name] + if vol_name.split(".")[-1].isdigit(): + xvol_name = ".".join(vol_name.split(".")[:-1]) + vol_conf = vm_disk_info["config_disk_map"][xvol_name] else: - vol_conf = vm_disk_info['config_disk_map'][vol_name] + vol_conf = vm_disk_info["config_disk_map"][vol_name] return vol_conf def show_disk_info(self, vm_disk_info, vol_name): - vm = vm_disk_info['vm'] - vol_info = vm_disk_info['volume_map'][vol_name] - dev_name = vm_disk_info['device_map'][vol_name] + vm = vm_disk_info["vm"] + vol_info = vm_disk_info["volume_map"][vol_name] + dev_name = vm_disk_info["device_map"][vol_name] vol_conf = self.lookup_disk_config(vm_disk_info, vol_name) - cursize = vol_info['Size'] - newsize = vol_conf['size'] + cursize = vol_info["Size"] + newsize = vol_conf["size"] # ephemeral disks do not have VolumeType property if vol_info.get("VolumeType"): - curtype = vol_info['VolumeType'] + curtype = vol_info["VolumeType"] else: - curtype = vol_info['State'] + curtype = vol_info["State"] newtype = None for k, v in vol_conf.items(): - if k == 'type': + if k == "type": newtype = v elif k in self.VOL_TYPES: newtype = k elif k in self.VOL_ENC_TYPES: - newtype = k.split('-')[1] + newtype = k.split("-")[1] if not newtype: newtype = curtype - print("{vm_id}/{vol_id}".format(vm_id=vm['InstanceId'], vol_id=vol_info['VolumeId'])) + print("{vm_id}/{vol_id}".format(vm_id=vm["InstanceId"], vol_id=vol_info["VolumeId"])) print(f" name: {vol_name}, device: {dev_name}") flag = "" @@ -4348,19 +3308,19 @@ def show_disk_info(self, vm_disk_info, vol_name): # attachement state xlist = [] - for att in vol_info.get('Attachments', []): + for att in vol_info.get("Attachments", []): # State/Device/InstanceId/VolumeId/DeleteOnTermination - xlist.append(att['State']) + xlist.append(att["State"]) attinfo = "" if xlist: attinfo = " / " + ",".join(xlist) - # state: 'creating'|'available'|'in-use'|'deleting'|'deleted'|'error', + # state: "creating"|"available"|"in-use"|"deleting"|"deleted"|"error", print(f" state: {vol_info['State']}{attinfo}") - if vol_info.get('Iops'): - curiops = vol_info['Iops'] - newiops = vol_conf.get('iops') + if vol_info.get("Iops"): + curiops = vol_info["Iops"] + newiops = vol_conf.get("iops") if not newiops: newiops = curiops @@ -4370,9 +3330,9 @@ def show_disk_info(self, vm_disk_info, vol_name): flag = " !!!" print(f" curiops: {curiops}, newiops: {newiops}{flag}") - if vol_info.get('Throughput') or vol_conf.get('throughput'): - curthroughput = vol_info.get('Throughput') - newthroughput = vol_conf.get('throughput') + if vol_info.get("Throughput") or vol_conf.get("throughput"): + curthroughput = vol_info.get("Throughput") + newthroughput = vol_conf.get("throughput") if not newthroughput: newthroughput = curthroughput @@ -4385,7 +3345,6 @@ def show_disk_info(self, vm_disk_info, vol_name): flag = " !!!" print(f" curthroughput: {curthroughput}, newthroughput: {newthroughput}{flag}") - def cmd_show_disks(self, *vm_ids): """Show detailed volume info. @@ -4395,11 +3354,11 @@ def cmd_show_disks(self, *vm_ids): last_vm_id = None for vm_disk_info in vm_disk_list: - if last_vm_id and last_vm_id != vm_disk_info['vm']['InstanceId']: + if last_vm_id and last_vm_id != vm_disk_info["vm"]["InstanceId"]: print("") - last_vm_id = vm_disk_info['vm']['InstanceId'] + last_vm_id = vm_disk_info["vm"]["InstanceId"] - for vol_name in vm_disk_info['volume_map']: + for vol_name in vm_disk_info["volume_map"]: self.show_disk_info(vm_disk_info, vol_name) def cmd_modify_disks(self, vm_id): @@ -4411,7 +3370,7 @@ def cmd_modify_disks(self, vm_id): client = self.get_ec2_client() modified_vol_ids = [] for vm_info in vm_disk_list: - volume_map = vm_info['volume_map'] + volume_map = vm_info["volume_map"] for vol_name in volume_map: modify_args = {} skip = False @@ -4422,76 +3381,79 @@ def cmd_modify_disks(self, vm_id): vol_conf = self.lookup_disk_config(vm_info, vol_name) for k, v in vol_conf.items(): - if k == 'size': - if v < vol_info['Size']: - eprintf("WARNING: cannot decrease size: vol_name=%s old=%r new=%r", vol_name, vol_info['Size'], v) + if k == "size": + if v < vol_info["Size"]: + eprintf( + "WARNING: cannot decrease size: vol_name=%s old=%r new=%r", + vol_name, + vol_info["Size"], + v) skip = True - if v != vol_info['Size']: - modify_args['Size'] = v + if v != vol_info["Size"]: + modify_args["Size"] = v logmsg += ", newsize=%d" % (v) - elif k == 'iops': - if v != vol_info['Iops']: - modify_args['Iops'] = v + elif k == "iops": + if v != vol_info["Iops"]: + modify_args["Iops"] = v logmsg += ", newiops=%d" % (v) - elif k == 'throughput': - if v != vol_info.get('Throughput'): - modify_args['Throughput'] = v + elif k == "throughput": + if v != vol_info.get("Throughput"): + modify_args["Throughput"] = v logmsg += ", newthroughput=%d" % (v) - elif k == 'type': + elif k == "type": newtype = v elif k in self.VOL_TYPES: newtype = k elif k in self.VOL_ENC_TYPES: - newtype = k.split('-')[1] + newtype = k.split("-")[1] - if newtype and newtype != vol_info['VolumeType']: - modify_args['VolumeType'] = newtype + if newtype and newtype != vol_info["VolumeType"]: + modify_args["VolumeType"] = newtype logmsg += ", newtype=%s" % (newtype) # Iops is required input for io1 and io2, regardless what boto3 documentation says - if newtype in ('io1', 'io2') and not modify_args.get('Iops'): - if not vol_conf.get('iops'): + if newtype in ("io1", "io2") and not modify_args.get("Iops"): + if not vol_conf.get("iops"): eprintf("WARNING: cannot modify to %s without specifying IOPS: vol_name=%s", newtype, vol_name) skip = True continue - modify_args['Iops'] = vol_conf['iops'] - + modify_args["Iops"] = vol_conf["iops"] if skip or not modify_args: continue - modify_args['VolumeId'] = vol_info['VolumeId'] + modify_args["VolumeId"] = vol_info["VolumeId"] self.show_disk_info(vm_info, vol_name) # request size increase - printf("Modifying %s%s", vol_info['VolumeId'], logmsg) + printf("Modifying %s%s", vol_info["VolumeId"], logmsg) client.modify_volume(**modify_args) printf("Done") - modified_vol_ids.append(vol_info['VolumeId']) + modified_vol_ids.append(vol_info["VolumeId"]) iter_vol_modifications = self.pager( - client, 'describe_volumes_modifications', 'VolumesModifications' + client, "describe_volumes_modifications", "VolumesModifications" ) # wait until complete while modified_vol_ids: incomplete = 0 for mod in iter_vol_modifications(VolumeIds=modified_vol_ids): - mstate = mod.get('ModificationState') + mstate = mod.get("ModificationState") if not mstate: continue if mstate not in ( - 'completed', 'failed', + "completed", "failed", # takes very long time but the volume is immediately usable - 'optimizing', + "optimizing", ): incomplete += 1 - msgstatus = '' + msgstatus = "" if mod.get("StatusMessage"): msgstatus = " msg={StatusMessage}".format(**mod) # throughputs are not available for some volume types - msgtarget = '' + msgtarget = "" if mod.get("TargetThroughput") and mod.get("OriginalThroughput"): msgtarget = " oldthroughput={OriginalThroughput} newthroughput={TargetThroughput}".format(**mod) printf( @@ -4503,7 +3465,7 @@ def cmd_modify_disks(self, vm_id): if not incomplete: break - printf('') + printf("") time.sleep(2) time_printf("Finished") @@ -4523,10 +3485,11 @@ def cmd_tag_vmstate(self): client = self.get_ec2_client() for vm in self.ec2_iter_instances(Filters=self.get_env_filters()): - if vm['InstanceId'] in primary_vms: + if vm["InstanceId"] in primary_vms: vm_state = VmState.PRIMARY else: vm_state = VmState.SECONDARY - tags = [{'Key': 'VmState', 'Value': vm_state}] - client.create_tags(Resources=[vm['InstanceId']], Tags=tags) + tags = [{"Key": "VmState", "Value": vm_state}] + client.create_tags(Resources=[vm["InstanceId"]], Tags=tags) + diff --git a/vmtool/base.py b/vmtool/base.py new file mode 100644 index 0000000..a2b6245 --- /dev/null +++ b/vmtool/base.py @@ -0,0 +1,1107 @@ +"""Common logic. +""" + +import argparse +import binascii +import enum +import fnmatch +import ipaddress +import logging +import os.path +import re +import shlex +import stat +import subprocess +import sys +import uuid + +from vmtool.certs import load_cert_config +from vmtool.config import Config, NoOptionError +from vmtool.envconfig import find_gittop, load_env_config +from vmtool.scripting import EnvScript, UsageError +from vmtool.tarfilter import TarFilter +from vmtool.terra import tf_load_output_var +from vmtool.util import ( + as_unicode, eprintf, local_cmd, printf, + rsh_quote, run_successfully, time_printf, +) +from vmtool.xglob import xglob + +SSH_USER_CREATION = """\ +if ! grep -q '^{user}:' /etc/passwd; then + echo "Adding user {user}" + adduser -q --gecos {user} --disabled-password {user} < /dev/null + install -d -o {user} -g {user} -m 700 ~{user}/.ssh + echo "{pubkey}" > ~{user}/.ssh/authorized_keys + chmod 600 ~{user}/.ssh/authorized_keys + chown {user}:{user} ~{user}/.ssh/authorized_keys + for grp in {auth_groups}; do + adduser -q {user} $grp + done +fi +""" + + +def mk_sshuser_script(user, auth_groups, pubkey): + return SSH_USER_CREATION.format(user=user, auth_groups=" ".join(auth_groups), pubkey=pubkey) + + +class VmCmd(enum.Enum): + """Sub-command names used internally in vmtool. + """ + PREP = "prep" + FAILOVER_PROMOTE_SECONDARY = "failover_promote_secondary" + + TAKEOVER_PREPARE_PRIMARY = "takeover_prepare_primary" + TAKEOVER_PREPARE_SECONDARY = "takeover_prepare_secondary" + TAKEOVER_FINISH_PRIMARY = "takeover_finish_primary" + TAKEOVER_FINISH_SECONDARY = "takeover_finish_secondary" + + DROP_NODE_PREPARE = "drop_node_prepare" + + +class VmToolBase(EnvScript): + # replace those with root specified by image + ROOT_DEV_NAMES = ("root",) + + log = logging.getLogger("vmtool") + + role_name = None + env_name = None # name of current env + full_role = None + ssh_dir = None + git_dir = None + keys_dir = None + is_live = False + availability_zone = None + ssh_known_hosts = None + + new_commit = None + old_commit = None + + def init_argparse(self, parser=None): + if parser is None: + parser = argparse.ArgumentParser(prog="vmtool") + p = super().init_argparse(parser) + #doc = self.__doc__.strip() + #p.set_usage(doc) + p.add_argument("--env", help="Set environment name (default comes from VMTOOL_ENV_NAME)") + p.add_argument("--role", help="Set role name (default: None)") + p.add_argument("--host", help="Use host instead detecting") + p.add_argument("--all", action="store_true", help="Make command work over all envs") + p.add_argument("--ssh-key", help="Use different SSH key") + p.add_argument("--all-role-vms", action="store_true", help="Run command on all vms for role") + p.add_argument("--all-role-fo-vms", action="store_true", help="Run command on all failover vms for role") + p.add_argument("--earlier-fo-vms", action="store_true", help="Run command on earlier failover vms for role") + p.add_argument("--latest-fo-vm", action="store_true", help="Run command on latest failover vm for rolw") + p.add_argument("--running", action="store_true", help="Show only running instances") + p.add_argument("--az", type=int, help="Set availability zone") + p.add_argument("--tmux", action="store_true", help="Wrap session in tmux") + return p + + def reload(self): + """Reload config. + """ + self.git_dir = find_gittop() + + # ~/.vmtool + ssh_dir = os.path.expanduser("~/.vmtool") + if not os.path.isdir(ssh_dir): + os.mkdir(ssh_dir, stat.S_IRWXU) + + keys_dir = os.environ.get("VMTOOL_KEY_DIR", os.path.join(self.git_dir, "keys")) + if not keys_dir or not os.path.isdir(keys_dir): + raise UsageError("Set vmtool config dir: VMTOOL_KEY_DIR") + + ca_log_dir = os.environ.get("VMTOOL_CA_LOG_DIR") + if not ca_log_dir or not os.path.isdir(ca_log_dir): + raise UsageError("Set vmtool config dir: VMTOOL_CA_LOG_DIR") + + env = os.environ.get("VMTOOL_ENV_NAME", "") + if self.options.env: + env = self.options.env + if not env: + raise UsageError("No envronment set: either set VMTOOL_ENV_NAME or give --env=ENV") + + env_name = env + self.full_role = env + if "." in env: + env_name, self.role_name = env.split(".") + self.env_name = env_name + if self.options.role: + self.role_name = self.options.role + self.full_role = "%s.%s" % (self.env_name, self.role_name) + + self.ca_log_dir = ca_log_dir + self.keys_dir = keys_dir + self.ssh_dir = ssh_dir + + self.cf = load_env_config(self.full_role, { + "FILE": self.conf_func_file, + "KEY": self.conf_func_key, + "TF": self.conf_func_tf, + "TFAZ": self.conf_func_tfaz, + "PRIMARY_VM": self.conf_func_primary_vm, + "NETWORK": self.conf_func_network, + "NETMASK": self.conf_func_netmask, + "MEMBERS": self.conf_func_members, + }) + self.process_pkgs() + + self._region = self.cf.get("region") + self.ssh_known_hosts = os.path.join(self.ssh_dir, "known_hosts") + self.is_live = self.cf.getint("is_live", 0) + + if self.options.az is not None: + self.availability_zone = self.options.az + else: + self.availability_zone = self.cf.getint("availability_zone", 0) + + # fill vm_ordered_disk_names + disk_map = self.get_disk_map() + if disk_map: + api_order = [] + size_order = [] + for dev in disk_map: + size = disk_map[dev]["size"] + count = disk_map[dev]["count"] + if size and dev not in self.ROOT_DEV_NAMES: + for i in range(count): + name = f"{dev}.{i}" if count > 1 else dev + size_order.append((size, i, name)) + api_order.append(name) + size_order.sort() + self.cf.set("vm_disk_names_size_order", ", ".join([elem[2] for elem in size_order])) + self.cf.set("vm_disk_names_api_order", ", ".join(api_order)) + + def get_disk_map(self): + """Parse disk_map option. + """ + disk_map = self.cf.getdict("disk_map", {}) + if not disk_map: + disk_map = {"root": "size=12"} + + res_map = {} + for dev in disk_map: + val = disk_map[dev] + local = {} + for opt in val.split(":"): + if "=" in opt: + k, v = opt.split("=", 1) + k = k.strip() + v = v.strip() + else: + k = v = opt.strip() + if not k: + continue + if k.startswith("ephemeral"): + k = "ephemeral" + if k in ("size", "count", "iops", "throughput"): + v = int(v) + local[k] = v + if "count" not in local: + local["count"] = 1 + if "size" not in local: + raise UsageError("Each element in disk_map needs size") + res_map[dev] = local + + # sanity check if requested + disk_require_order = self.cf.getlist("disk_require_order", []) + if disk_require_order: + # order from disk_map + got_order = sorted([ + (res_map[name]["size"], res_map[name]["count"], name) + for name in res_map + ]) + names_order = [f"{name}:{count}" for size, count, name in got_order] + + # order from disk_require_order + counted_order = [ + key if ":" in key else key + ":1" + for key in disk_require_order + ] + + if names_order != counted_order: + raise UsageError("Order mismatch:\n require=%r\n got=%r" % (counted_order, names_order)) + + return res_map + + _gpg_cache = None + def load_gpg_file(self, fn): + if self._gpg_cache is None: + self._gpg_cache = {} + if fn in self._gpg_cache: + return self._gpg_cache[fn] + if self.options.verbose: + printf("GPG: %s", fn) + # file data directly + if not os.path.isfile(fn): + raise UsageError("GPG file not found: %s" % fn) + data = self.popen(["gpg", "-q", "-d", "--batch", fn]) + res = as_unicode(data) + self._gpg_cache[fn] = res + return res + + def load_gpg_config(self, fn, main_section): + realfn = os.path.join(self.keys_dir, fn) + if not os.path.isfile(realfn): + raise UsageError("GPG file not found: %s" % realfn) + data = self.load_gpg_file(realfn) + cf = Config(main_section, None) + cf.cf.read_string(data, source=realfn) + return cf + + def popen(self, cmd, input_data=None, **kwargs): + """Read command stdout, check for exit code. + """ + pipes = {"stdout": subprocess.PIPE, "stderr": subprocess.PIPE} + if input_data is not None: + pipes["stdin"] = subprocess.PIPE + with subprocess.Popen(cmd, **kwargs, **pipes) as p: + out, err = p.communicate(input_data) + if p.returncode != 0: + raise Exception("command failed: %r - %r" % (cmd, err.strip())) + return out + + def load_command_docs(self): + doc = self.__doc__.strip() + doc = "" + grc = re.compile(r"Group: *(\w+)") + cmds = [] + + for fn in sorted(dir(self)): + if fn.startswith("cmd_"): + fobj = getattr(self, fn) + docstr = (getattr(fobj, "__doc__", "") or "").strip() + mgrp = grc.search(docstr) + grpname = mgrp and mgrp.group(1) or "" + lines = docstr.split("\n") + fdoc = lines[0] + cmd = fn[4:].replace("_", "-") + cmds.append((grpname, cmd, fdoc)) + + for sect in self.cf.sections(): + if sect.startswith("cmd.") or sect.startswith("alias."): + cmd = sect.split(".", 1)[1] + desc = "" + grpname = "" + if self.cf.cf.has_option(sect, "desc"): + desc = self.cf.cf.get(sect, "desc") + if self.cf.cf.has_option(sect, "group"): + grpname = self.cf.cf.get(sect, "group") + fdoc = desc.strip().split("\n")[0] + cmds.append((grpname, cmd, desc)) + + cmds.sort() + last_grp = None + sep = "" + for grpname, cmd, fdoc in cmds: + if grpname != last_grp: + doc += sep + "%s commands:\n" % (grpname or "ungrouped") + last_grp = grpname + sep = "\n" + doc += " %-30s - %s\n" % (cmd, fdoc) + return doc + + def cmd_help(self): + """Show help about commands. + + Group: info + """ + doc = self.load_command_docs() + printf(doc) + + def filter_key_lookup(self, predef, key, fname): + if key in predef: + return predef[key] + + if key == "MASTER_KEYS": + master_key_list = [] + nr = 1 + while True: + kname = "master_key_%d" % nr + v = self.cf.get(kname, "") + if not v: + break + master_key_list.append("%s = %s" % (kname, v)) + nr += 1 + if not master_key_list: + raise Exception("No master keys found") + master_key_conf = "\n".join(master_key_list) + return master_key_conf + + if key == "SYSRANDOM": + blk = os.urandom(3 * 16) + b64 = binascii.b2a_base64(blk).strip() + return b64.decode("utf8") + + if key == "AUTHORIZED_KEYS": + auth_users = self.cf.getlist("ssh_authorized_users", []) + pat = self.cf.get("ssh_pubkey_pattern") + keys = [] + for user in sorted(set(auth_users)): + fn = os.path.join(self.keys_dir, pat.replace("USER", user)) + with open(fn, "r", encoding="utf8") as f: + pubkey = f.read().strip() + keys.append(pubkey) + return "\n".join(keys) + + if key == "AUTHORIZED_USER_CREATION": + return self.make_user_creation() + + try: + return self.cf.get(key) + except NoOptionError: + raise UsageError("%s: key not found: %s" % (fname, key)) from None + + def make_user_creation(self): + auth_groups = self.cf.getlist("authorized_user_groups", []) + auth_users = self.cf.getlist("ssh_authorized_users", []) + pat = self.cf.get("ssh_pubkey_pattern") + script = [] + for user in sorted(set(auth_users)): + fn = os.path.join(self.keys_dir, pat.replace("USER", user)) + with open(fn, encoding="utf8") as f: + pubkey = f.read().strip() + script.append(mk_sshuser_script(user, auth_groups, pubkey)) + return "\n".join(script) + + def make_tar_filter(self, extra_defs=None): + defs = {} + if extra_defs: + defs.update(extra_defs) + tb = TarFilter(self.filter_key_lookup, defs) + tb.set_live(self.is_live) + return tb + + def conf_func_file(self, arg, sect, kname): + """Returns contents of file, optionally gpg-decrypted. + + Usage: ${FILE ! filename} + """ + if self.options.verbose: + printf("FILE: %s", arg) + fn = os.path.join(self.keys_dir, arg) + if not os.path.isfile(fn): + raise UsageError("%s - FILE missing: %s" % (kname, arg)) + if fn.endswith(".gpg"): + return self.load_gpg_file(fn).rstrip("\n") + with open(fn, "r", encoding="utf8") as f: + return f.read().rstrip("\n") + + def conf_func_key(self, arg, sect, kname): + """Returns key from Terraform state file. + + Usage: ${KEY ! fn : key} + """ + bfn, subkey = arg.split(":") + if self.options.verbose: + printf("KEY: %s : %s", bfn.strip(), subkey.strip()) + fn = os.path.join(self.keys_dir, bfn.strip()) + if not os.path.isfile(fn): + raise UsageError("%s - KEY file missing: %s" % (kname, fn)) + cf = self.load_gpg_config(fn, "vm-config") + subkey = as_unicode(subkey.strip()) + try: + return cf.get(subkey) + except BaseException: + raise UsageError("%s - Key '%s' unset in '%s'" % (kname, subkey, fn)) from None + + def conf_func_tf(self, arg, sect, kname): + """Returns key from Terraform state file. + + Usage: ${TF ! tfvar} + """ + if ":" in arg: + state_file, arg = [s.strip() for s in arg.split(":", 1)] + else: + state_file = self.cf.get("tf_state_file") + val = tf_load_output_var(state_file, arg) + + # configparser expects strings + if isinstance(val, str): + # work around tf dots in route53 data + val = val.strip().rstrip(".") + elif isinstance(val, int): + val = str(val) + elif isinstance(val, float): + val = repr(val) + elif isinstance(val, bool): + val = str(val).lower() + else: + raise UsageError("TF function got invalid type: %s - %s" % (kname, type(val))) + return val + + def conf_func_members(self, arg, sect, kname): + """Returns field that match patters. + + Usage: ${MEMBERS ! pat : fn : field} + """ + pats, bfn, field = arg.split(":") + fn = os.path.join(self.keys_dir, bfn.strip()) + if not os.path.isfile(fn): + raise UsageError("%s - MEMBERS file missing: %s" % (kname, fn)) + + idx = int(field, 10) + + findLabels = [] + for p in pats.split(","): + p = p.strip() + if p: + findLabels.append(p) + + res = [] + with open(fn, "r", encoding="utf8") as f: + for ln in f: + ln = ln.strip() + if not ln or ln[0] == "#": + continue + got = False + parts = ln.split(":") + user = parts[0].strip() + for label in parts[idx].split(","): + label = label.strip() + if label and label in findLabels: + got = True + break + if got and user not in res: + res.append(user) + + return ", ".join(res) + + def conf_func_tfaz(self, arg, sect, kname): + """Returns key from Terraform state file. + + Usage: ${TFAZ ! tfvar} + """ + if self.options.verbose: + printf("TFAZ: %s", arg) + if ":" in arg: + state_file, arg = [s.strip() for s in arg.split(":", 1)] + else: + state_file = self.cf.get("tf_state_file") + val = tf_load_output_var(state_file, arg) + if not isinstance(val, list): + raise UsageError("TFAZ function expects list param: %s" % kname) + if self.availability_zone < 0 or self.availability_zone >= len(val): + raise UsageError("AZ value out of range") + return val[self.availability_zone] + + def conf_func_primary_vm(self, arg, sect, kname): + """Lookup primary vm. + + Usage: ${PRIMARY_VM ! ${other_role}} + """ + raise NotImplementedError + + def conf_func_network(self, arg, sect, kname): + """Extract network address from CIDR. + """ + return str(ipaddress.ip_network(arg).network_address) + + def conf_func_netmask(self, arg, sect, kname): + """Extract 32-bit netmask from CIDR. + """ + return str(ipaddress.ip_network(arg).netmask) + + def get_ssh_kfile(self): + # load encrypted key + if self.options.ssh_key: + gpg_fn = self.options.ssh_key + else: + gpg_fn = self.cf.get("ssh_privkey_file") + gpg_fn = os.path.join(self.keys_dir, gpg_fn) + kdata = self.load_gpg_file(gpg_fn).strip() + + raw_fn = os.path.basename(gpg_fn).replace(".gpg", "") + + fn = os.path.join(self.ssh_dir, raw_fn) + + # check existing key + if os.path.isfile(fn): + with open(fn, "r", encoding="utf8") as f: + curdata = f.read().strip() + if curdata == kdata: + return fn + os.remove(fn) + + printf("Extracting keyfile %s to %s", gpg_fn, fn) + fd = os.open(fn, os.O_CREAT | os.O_WRONLY, stat.S_IRUSR | stat.S_IWUSR) + with os.fdopen(fd, "w") as f: + f.write(kdata + "\n") + return fn + + def get_ssh_known_hosts_file(self, vm_id): + return self.ssh_known_hosts + "_" + vm_id + + def ssh_cmdline(self, vm_id, use_admin=False, check_tty=False): + if self.cf.getboolean("ssh_admin_user_disabled", False): + ssh_user = self.cf.get("user") + elif use_admin: + ssh_user = self.cf.get("ssh_admin_user") + else: + ssh_user = self.cf.get("user") + + ssh_debug = "-q" + if self.options.verbose: + ssh_debug = "-v" + + ssh_options = shlex.split(self.cf.get("ssh_options", "")) + + if check_tty and sys.stdout.isatty(): # pylint:disable=no-member + ssh_options.append("-t") + + return ["ssh", ssh_debug, "-i", self.get_ssh_kfile(), "-l", ssh_user, + "-o", "UserKnownHostsFile=" + self.get_ssh_known_hosts_file(vm_id)] + ssh_options + + def vm_exec_tmux(self, vm_id, cmdline, use_admin=False, title=None): + if self.options.tmux: + tmux_command = shlex.split(self.cf.get("tmux_command")) + if title: + tmux_command = [a.replace("{title}", title) for a in tmux_command] + cmdline = tmux_command + cmdline + self.vm_exec(vm_id, cmdline, use_admin=use_admin) + + def vm_exec(self, vm_id, cmdline, stdin=None, get_output=False, check_error=True, use_admin=False): + self.log.debug("EXEC@%s: %s", vm_id, cmdline) + self.put_known_host_from_tags(vm_id) + + # only image default user works? + if not self.cf.getboolean("ssh_user_access_works", False): + use_admin = True + + if self.options.host: + # use host directly, dangerous + hostname = self.options.host + elif self.cf.getboolean("ssh_internal_ip_works", False): + vm = self.vm_lookup(vm_id) + hostname = vm.get("PrivateIpAddress") + else: + # FIXME: vm with ENI + vm = self.vm_lookup(vm_id) + #hostname = vm.get("PublicDnsName") + hostname = vm.get("PublicIpAddress") + last_idx = 600 * 1024 * 1024 * 1024 + if len(vm["NetworkInterfaces"]) > 1: + for iface in vm["NetworkInterfaces"]: + #print_json(iface) + idx = iface["Attachment"]["DeviceIndex"] + if 1 or idx < last_idx: + assoc = iface.get("Association") + if assoc: + hostname = assoc["PublicIp"] + last_idx = idx + break + eprintf("SSH to %s", hostname) + if not hostname: + self.log.error("Public DNS nor ip not yet available for node %r", vm_id) + #print_json(vm) + sys.exit(1) + + check_tty = not stdin and not get_output + ssh = self.ssh_cmdline(vm_id, use_admin=use_admin, check_tty=check_tty) + ssh.append(hostname) + if isinstance(cmdline, str): + ssh += [cmdline] + elif self.cf.getboolean("ssh_disable_quote", False): + ssh += cmdline + else: + ssh += rsh_quote(cmdline) + out = None + kwargs = {} + if stdin is not None: + kwargs["stdin"] = subprocess.PIPE + if get_output: + kwargs["stdout"] = subprocess.PIPE + self.log.debug("EXEC: cmd=%r", ssh) + self.log.debug("EXEC: kwargs=%r", kwargs) + if kwargs: + with subprocess.Popen(ssh, **kwargs) as p: + out, err = p.communicate(stdin) + ret = p.returncode + else: + ret = subprocess.call(ssh) + if ret != 0: + if check_error: + raise UsageError("Errorcode: %r" % ret) + return None + return out + + def vm_rsync(self, *args, use_admin=False): + primary_id = None + nargs = [] + ids = [] + vm_id = "?" + for a in args: + t = a.split(":", 1) + if len(t) == 1: + nargs.append(a) + continue + if t[0]: + vm_id = t[0] + elif primary_id: + vm_id = primary_id + else: + vm_id = primary_id = self.get_primary_vms()[0] + vm = self.vm_lookup(vm_id) + self.put_known_host_from_tags(vm_id) + vm = self.vm_lookup(vm_id) + if self.cf.getboolean("ssh_internal_ip_works", False): + hostname = vm.get("PrivateIpAddress") + else: + hostname = vm.get("PublicIpAddress") + a = "%s:%s" % (hostname, t[1]) + nargs.append(a) + ids.append(vm_id) + + ssh_list = self.ssh_cmdline(vm_id, use_admin=use_admin) + ssh_cmd = " ".join(rsh_quote(ssh_list)) + + cmd = ["rsync", "-rtz", "-e", ssh_cmd] + if self.options.verbose: + cmd.append("-P") + cmd += nargs + self.log.debug("rsync: %r", cmd) + run_successfully(cmd) + + _PREP_TGZ_CACHE = {} # cmd->tgz + _PREP_STAMP_CACHE = {} # cmd->stamp + + def cmd_mod_test(self, cmd_name): + """Test if payload can be created for command. + + Group: internal + """ + self.modcmd_init(cmd_name) + data = self._PREP_TGZ_CACHE[cmd_name] + print("Data size: %d bytes" % len(data)) + return data + + def cmd_mod_dump(self, cmd_name): + """Write tarball of command payload. + + Group: internal + """ + self.modcmd_init(cmd_name) + data = self._PREP_TGZ_CACHE[cmd_name] + fn = "data.tgz" + with open(fn, "wb", encoding="utf8") as f: + f.write(data) + print("%s: %d bytes" % (fn, len(data))) + + def cmd_mod_show(self, cmd_name): + """Show vmlibs used for command. + + Group: internal + """ + cwd = self.git_dir + os.chdir(cwd) + + cmd_cf = self.cf.view_section("cmd.%s" % cmd_name) + + vmlibs = cmd_cf.getlist("vmlibs") + + print("Included libs") + got = set() + for mod in vmlibs: + if mod not in got: + print("+ " + mod) + got.add(mod) + + exc_libs = [] + for mod in xglob("vmlib/**/setup.sh"): + mod = "/".join(mod.split("/")[1:-1]) + if mod not in got: + exc_libs.append(mod) + exc_libs.sort() + + print("Excluded libs") + for mod in exc_libs: + print("- " + mod) + + def has_modcmd(self, cmd_name: VmCmd): + """Return true if command is configured from config. + """ + return self.cf.has_section("cmd.%s" % cmd_name) + + def load_modcmd_args(self, args): + vms = [] + for a in args: + if a.startswith('i-'): + vms.append(a) + else: + raise UsageError("command supports only vmid args") + if vms: + return vms + return self.get_primary_vms() + + def modcmd_init(self, cmd_name: VmCmd): + """Run init script for command. + """ + cmd_cf = self.cf.view_section("cmd.%s" % cmd_name) + init_script = cmd_cf.get("init", "") + if init_script: + # let subprocess see current env + subenv = os.environ.copy() + subenv["VMTOOL_ENV_NAME"] = self.full_role + run_successfully([init_script], cwd=self.git_dir, shell=True, env=subenv) + + self.modcmd_prepare(cmd_name) + + def modcmd_prepare(self, cmd_name: VmCmd): + """Prepare data package for command. + """ + cmd_cf = self.cf.view_section("cmd.%s" % cmd_name) + stamp_dirs = cmd_cf.getlist("stamp_dirs", []) + cmd_abbr = cmd_cf.get("command_tag", "") + globs = cmd_cf.getlist("files", []) + use_admin = cmd_cf.getboolean("use_admin", False) + + self._PREP_TGZ_CACHE[cmd_name] = b"" + self.modcmd_build_tgz(cmd_name, globs, cmd_cf) + + self._PREP_STAMP_CACHE[cmd_name] = { + "cmd_abbr": cmd_abbr, + "stamp_dirs": stamp_dirs, + "stamp": self.get_stamp(), + "use_admin": use_admin, + } + + def modcmd_run(self, cmd_name, vm_ids): + """Send mod data to server and run it. + """ + info = self._PREP_STAMP_CACHE[cmd_name] + data_info = 0 + for vm_id in vm_ids: + data = self._PREP_TGZ_CACHE[cmd_name] + if not data_info: + data_info = 1 + print("RUNNING...") + self.run_mod_data(data, vm_id, use_admin=info["use_admin"], title=cmd_name) + if info["cmd_abbr"]: + self.set_stamp(vm_id, info["cmd_abbr"], info["stamp"], *info["stamp_dirs"]) + + def process_pkgs(self): + """Merge per-pkg variables into main config. + + Converts: + + [pkg.foo] + pkg_pyinstall_vmlibs = a, b + [pkg.bar] + pkg_pyinstall_vmlibs = c, d + + To: + [vm-config] + pkg_pyinstall_vmlibs = a, b, c, d + """ + cf = self.cf.cf + vmap = {} + for sect in cf.sections(): + if sect.startswith("pkg."): + for opt in cf.options(sect): + if opt not in vmap: + vmap[opt] = [] + done = set(vmap[opt]) + val = cf.get(sect, opt) + for v in val.split(","): + v = v.strip() + if v and (v not in done): + vmap[opt].append(v) + done.add(v) + for k, v in vmap.items(): + cf.set("vm-config", k, ", ".join(v)) + + # in use + def modcmd_build_tgz(self, cmd_name, globs, cmd_cf=None): + cwd = self.git_dir + os.chdir(cwd) + + defs = {} + mods_ok = True + vmlibs = [] + cert_fns = set() + if cmd_cf: + vmlibs = cmd_cf.getlist("vmlibs", []) + if vmlibs: + done_vmlibs = [] + vmdir = "vmlib" + globs = list(globs) + for mod in vmlibs: + if mod in done_vmlibs: + continue + if not mod: + continue + mdir = os.path.join(vmdir, mod) + if not os.path.isdir(mdir): + printf("Missing module: %s" % mdir) + mods_ok = False + elif not os.path.isfile(mdir + "/setup.sh"): + printf("Broken module, no setup.sh: %s" % mdir) + mods_ok = False + globs.append("vmlib/%s/**" % mod) + done_vmlibs.append(mod) + + cert_ini = os.path.join(mdir, "certs.ini") + if os.path.isfile(cert_ini): + cert_fns.add(cert_ini) + defs["vm_modules"] = "\n".join(done_vmlibs) + "\n" + globs.append("vmlib/runner.*") + globs.append("vmlib/shared/**") + if not mods_ok: + sys.exit(1) + + dst = self.make_tar_filter(defs) + + for tmp in globs: + subdir = "." + if isinstance(tmp, str): + flist = xglob(tmp) + else: + subdir = tmp[1] + if subdir and subdir != ".": + os.chdir(subdir) + else: + subdir = "." + flist = xglob(tmp[0]) + if len(tmp) > 2: + exlist = tmp[2:] + flist2 = [] + for fn in flist: + skip = False + for ex in exlist: + if fnmatch.fnmatch(fn, ex): + skip = True + break + if not skip: + flist2.append(fn) + flist = iter(flist2) + if subdir: + os.chdir(cwd) + + for fn in flist: + real_fn = os.path.join(subdir, fn) + if os.path.isdir(real_fn): + #dst.add_dir(item.path, stat.S_IRWXU, item.mtime) + pass + else: + with open(real_fn, "rb") as f: + st = os.fstat(f.fileno()) + data = f.read() + dst.add_file_data(fn, data, st.st_mode & stat.S_IRWXU, st.st_mtime) + + # pass parameters to cert.ini files + defs = {"env_name": self.env_name} + if self.role_name: + defs["role_name"] = self.role_name + if self.cf.has_section("ca-config"): + items = self.cf.view_section("ca-config").items() + defs.update(items) + + # create keys & certs + for cert_ini in cert_fns: + printf("Processing certs: %s", cert_ini) + mdir = os.path.dirname(cert_ini) + keys = load_cert_config(cert_ini, self.load_ca_keypair, defs) + for kname in keys: + key, cert, _ = keys[kname] + key_fn = "%s/%s.key" % (mdir, kname) + cert_fn = "%s/%s.crt" % (mdir, kname) + dst.add_file_data(key_fn, key, 0o600) + dst.add_file_data(cert_fn, cert, 0o600) + + # finish + dst.close() + tgz = dst.getvalue() + self._PREP_TGZ_CACHE[cmd_name] = tgz + time_printf("%s: tgz bytes: %s", cmd_name, len(tgz)) + + def load_ca_keypair(self, ca_name): + intca_dir = self.cf.get(ca_name + "_dir", "") + if not intca_dir: + intca_dir = self.cf.get("intca_dir") + pat = "%s/%s/%s_*.key.gpg" % (self.keys_dir, intca_dir, ca_name) + res = list(sorted(xglob(pat))) + if not res: + raise UsageError("CA not found: %s - %s" % (ca_name, intca_dir)) + #names = [fn.split("/")[-1] for fn in res] + idx = 0 # -1 + last_key = res[idx] + #printf("CA: using %s from [%s]", names[idx], ", ".join(names)) + last_crt = last_key.replace(".key.gpg", ".crt") + if not os.path.isfile(last_crt): + raise UsageError("CA cert not found: %s" % last_crt) + if not os.path.isfile(last_key): + raise UsageError("CA key not found: %s" % last_key) + return (last_key, last_crt) + + def run_mod_data(self, data, vm_id, use_admin=False, title=None): + + tmp_uuid = str(uuid.uuid4()) + run_user = "root" + + launcher = './tmp/%s/vmlib/runner.sh "%s"' % (tmp_uuid, vm_id) + rm_cmd = "rm -rf" + if run_user: + launcher = "sudo -nH -u %s %s" % (run_user, launcher) + rm_cmd = "sudo -nH " + rm_cmd + + time_printf("%s: Sending data - %d bytes", vm_id, len(data)) + decomp_script = 'install -d -m 711 tmp && mkdir -p "tmp/%s" && tar xzf - --warning=no-timestamp -C "tmp/%s"' % ( + tmp_uuid, tmp_uuid + ) + self.vm_exec(vm_id, ["/bin/sh", "-c", decomp_script, "decomp"], data, use_admin=use_admin) + + time_printf("%s: Running", vm_id) + cmdline = ["/bin/sh", "-c", launcher, "runit"] + self.vm_exec_tmux(vm_id, cmdline, use_admin=use_admin, title=title) + + def get_stamp(self): + commit_id = local_cmd(["git", "rev-parse", "HEAD"]) + commit_id = commit_id[:7] # same length as git log --abbrev-commit + return commit_id + + def put_known_host_from_tags(self, vm_id): + pass + + def change_cwd_adv(self): + # cd .. until there is .git + if not self._change_cwd_gittop(): + os.chdir(self.git_dir) + + def _change_cwd_gittop(self): + vmlib = "vmlib/runner.sh" + num = 0 + maxstep = 30 + pfx = "." + while True: + if os.path.isdir(os.path.join(pfx, ".git")): + if os.path.isfile(os.path.join(pfx, vmlib)): + os.chdir(pfx) + return True + else: + break + if num > maxstep: + break + pfx = os.path.join(pfx, "..") + num += 1 + return False + + def run_console_cmd(self, cmd, cmdargs): + cmd_cf = self.cf.view_section("cmd.%s" % cmd) + cmdline = cmd_cf.get("vmrun") + argparam = cmd_cf.get("vmrun_arg_param", "") + + fullcmd = shlex.split(cmdline) + vm_ids, args = self.get_vm_args(cmdargs, allow_multi=True) + if args: + if argparam: + fullcmd = fullcmd + [argparam, " ".join(args)] + else: + fullcmd = fullcmd + args + + if len(vm_ids) > 1 and self.options.tmux: + raise UsageError("Cannot use tmux in parallel") + + for vm_id in vm_ids: + if len(vm_ids) > 1: + time_printf("Running on VM %s", vm_id) + self.vm_exec_tmux(vm_id, fullcmd, title=cmd) + + def cmd_show_config(self, *args): + """Show filled config for current VM. + + Group: config + """ + desc = self.env_name + if self.role_name: + desc += "." + self.role_name + + fail = 0 + for sect in sorted(self.cf.sections()): + sect_header = f"[{sect}]" + for k in sorted(self.cf.cf.options(sect)): + if args and k not in args: + continue + if sect_header: + printf(sect_header) + sect_header = "" + try: + raw = self.cf.cf.get(sect, k, raw=True) + v = self.cf.cf.get(sect, k) + vs = v + if not self.options.verbose: + vs = vs.strip() + if vs.startswith("----") or vs.startswith("{"): + vs = vs.split("\n")[0] + else: + vs = re.sub(r"\n\s*", " ", vs) + printf("%s = %s", k, vs) + else: + printf("%s = %s [%s] (%s)", k, vs, desc, raw) + except Exception as ex: + fail = 1 + eprintf("### ERROR ### key: '%s.%s' err: %s", sect, k, str(ex)) + if not sect_header: + printf("") + if fail: + sys.exit(fail) + + def cmd_show_config_raw(self, *args): + """Show filled config for current VM. + + Group: config + """ + self.cf.cf.write(sys.stdout) + + def cmd_check_config(self): + """Check if config works. + + Group: config + """ + fail = 0 + for k in self.cf.options(): + try: + self.cf.getlist(k) + except Exception as ex: + fail = 1 + printf("key: '%s' err: %s", k, str(ex)) + if fail: + printf("--problems--") + sys.exit(fail) + + def work(self): + cmd = self.options.command + cmdargs = self.options.args + if not cmd: + raise UsageError("Need command") + #eprintf("vmtool - env_name: %s git_dir: %s", self.env_name, self.git_dir) + cmd_section = "cmd.%s" % cmd + if self.cf.has_section(cmd_section): + cf2 = self.cf.view_section(cmd_section) + if cf2.get("vmlibs", ""): + vms = self.load_modcmd_args(cmdargs) + self.change_cwd_adv() + self.modcmd_init(cmd) + self.modcmd_run(cmd, vms) + else: + self.run_console_cmd(cmd, cmdargs) + else: + super().work() + + def set_stamp(self, vm_id, name, commit_id, *dirs): + raise NotImplementedError + + def vm_lookup(self, vm_id, ignore_env=False, cache=True): + raise NotImplementedError + + def get_primary_vms(self): + raise NotImplementedError + + def get_vm_args(self, args, allow_multi=False): + """Check if args start with VM ID. + + returns: (vm-id, args) + """ + raise NotImplementedError + diff --git a/vmtool/certs.py b/vmtool/certs.py index 6bf01fd..7bbdbb7 100644 --- a/vmtool/certs.py +++ b/vmtool/certs.py @@ -2,15 +2,17 @@ """ from configparser import ConfigParser, ExtendedInterpolation + from sysca import api as sysca + from .util import as_bytes -__all__ = ['load_cert_config'] +__all__ = ["load_cert_config"] def load_cert_config(fn, load_ca, defs): cf = ConfigParser(defaults=defs, interpolation=ExtendedInterpolation(), - delimiters=['='], comment_prefixes=['#'], inline_comment_prefixes=['#']) + delimiters=["="], comment_prefixes=["#"], inline_comment_prefixes=["#"]) cf.read([fn]) return process_config(cf, load_ca) @@ -22,59 +24,59 @@ def process_config(cf, load_ca): for kname in cf.sections(): sect = dict(cf.items(kname)) - days = int(sect.get('days', '730')) - ktype = sect.get('ktype', 'ec') - alt_names = sect.get('alt_names') + days = int(sect.get("days", "730")) + ktype = sect.get("ktype", "ec") + alt_names = sect.get("alt_names") - subject = sect.get('subject') + subject = sect.get("subject") if not subject: subject = {} - common_name = sect.get('common_name') + common_name = sect.get("common_name") if not common_name: common_name = kname - common_name = common_name.rstrip('.') - subject['CN'] = common_name + common_name = common_name.rstrip(".") + subject["CN"] = common_name - sysfe_grants = sect.get('sysfe_grants') + sysfe_grants = sect.get("sysfe_grants") if sysfe_grants: sysfe_clean = [] - for rpcname in sysfe_grants.split(','): + for rpcname in sysfe_grants.split(","): rpcname = rpcname.strip() if rpcname: sysfe_clean.append(rpcname) - subject['OU'] = ':'.join(sysfe_clean) + subject["OU"] = ":".join(sysfe_clean) if not alt_names: - if '.' in common_name: - if '@' not in common_name: - alt_names = ['dns:' + common_name] + if "." in common_name: + if "@" not in common_name: + alt_names = ["dns:" + common_name] - ca_name = sect['ca_name'] + ca_name = sect["ca_name"] ca_key_fn, ca_cert_fn = load_ca(ca_name) ca_key = sysca.load_key(ca_key_fn) ca_cert = sysca.load_cert(ca_cert_fn) - usage = sect.get('usage') + usage = sect.get("usage") if not usage: - usage = ['client'] + usage = ["client"] inf = sysca.CertInfo(subject=subject, usage=usage, alt_names=alt_names) - tmp = ktype.split(':', 1) + tmp = ktype.split(":", 1) ktype = tmp[0] kparam = None if len(tmp) > 1: kparam = tmp[1] - if ktype == 'ec': - key = sysca.new_ec_key(kparam or 'secp256r1') - elif ktype == 'rsa': + if ktype == "ec": + key = sysca.new_ec_key(kparam or "secp256r1") + elif ktype == "rsa": bits = 2048 if kparam: bits = int(kparam) key = sysca.new_rsa_key(bits) else: - raise Exception('unknown key type: ' + ktype) + raise Exception("unknown key type: " + ktype) cert = sysca.create_x509_cert(ca_key, key.public_key(), inf, ca_cert, days) pem_key = as_bytes(sysca.serialize(key)) diff --git a/vmtool/config.py b/vmtool/config.py index c662dbf..934fbc4 100644 --- a/vmtool/config.py +++ b/vmtool/config.py @@ -1,20 +1,25 @@ """Nicer config class.""" +from typing import Union import os import os.path import re import socket - +from configparser import MAX_INTERPOLATION_DEPTH, ConfigParser +from configparser import Error as ConfigError from configparser import ( - NoOptionError, NoSectionError, InterpolationError, InterpolationDepthError, InterpolationSyntaxError, - Error as ConfigError, ConfigParser, MAX_INTERPOLATION_DEPTH, - Interpolation) + Interpolation, InterpolationDepthError, InterpolationError, + InterpolationSyntaxError, NoOptionError, NoSectionError, +) + +__all__ = ["Config", "NoOptionError", "ConfigError", "AdvancedInterpolation"] +class _UnSet: + pass -__all__ = ['Config', 'NoOptionError', 'ConfigError', 'AdvancedInterpolation'] +_UNSET = _UnSet() -_UNSET = object() class Config(object): """Bit improved ConfigParser. @@ -38,13 +43,13 @@ def __init__(self, main_section, filename, user_defs=None, override=None, ignore self.defs = {} else: self.defs = { - 'job_name': job_name, - 'service_name': main_section, - 'host_name': socket.gethostname(), + "job_name": job_name, + "service_name": main_section, + "host_name": socket.gethostname(), } if filename: - self.defs['config_dir'] = os.path.dirname(filename) - self.defs['config_file'] = filename + self.defs["config_dir"] = os.path.dirname(filename) + self.defs["config_file"] = filename if user_defs: self.defs.update(user_defs) @@ -52,12 +57,12 @@ def __init__(self, main_section, filename, user_defs=None, override=None, ignore self.filename = filename self.override = override or {} self.cf = ConfigParser(interpolation=AdvancedInterpolation(func_map=func_map), - delimiters=['='], comment_prefixes=['#'], inline_comment_prefixes=['#']) + delimiters=["="], comment_prefixes=["#"], inline_comment_prefixes=["#"]) if filename is None: self.cf.add_section(main_section) elif not os.path.isfile(filename): - raise ConfigError('Config file not found: ' + filename) + raise ConfigError("Config file not found: " + filename) self.reload() @@ -78,11 +83,11 @@ def reload(self): for k, v in self.override.items(): self.cf.set(self.main_section, k, v) - def get(self, key, default=_UNSET): + def get(self, key: str, default: Union[str, _UnSet] =_UNSET) -> str: """Reads string value, if not set then default.""" if not self.cf.has_option(self.main_section, key): - if default is _UNSET: + if isinstance(default, _UnSet): raise NoOptionError(key, self.main_section) return default @@ -156,7 +161,7 @@ def getdict(self, key, default=_UNSET): kv = kv.strip() if not kv: continue - tmp = kv.split(':', 1) + tmp = kv.split(":", 1) if len(tmp) > 1: k = tmp[0].strip() v = tmp[1].strip() @@ -164,7 +169,7 @@ def getdict(self, key, default=_UNSET): k = kv v = k if k in res: - raise KeyError('Duplicate key not allowed: %r' % k) + raise KeyError("Duplicate key not allowed: %r" % k) res[k] = v return res @@ -214,33 +219,33 @@ def view_section(self, section): return cf -_NEW_VAR_OPEN_RX = re.compile(r'\$\$|\$\{') -_NEW_VAR_BOTH_RX = re.compile(r'\$\$|\$\{|}') +_NEW_VAR_OPEN_RX = re.compile(r"\$\$|\$\{") +_NEW_VAR_BOTH_RX = re.compile(r"\$\$|\$\{|}") def _scan_key(cur_sect, cur_key, value, pos, lookup_func): dst = [] - while 1: + while True: m = _NEW_VAR_BOTH_RX.search(value, pos) if not m: - raise Exception('Closing brace not found') + raise Exception("Closing brace not found") pos2 = m.start() if pos2 > pos: dst.append(value[pos:pos2]) pos = m.end() tok = m.group(0) - if tok == '}': - subkey = ''.join(dst) + if tok == "}": + subkey = "".join(dst) subval = lookup_func(cur_sect, subkey) return subval, pos - elif tok == '$$': - dst.append('$') - elif tok == '${': + elif tok == "$$": + dst.append("$") + elif tok == "${": subval, pos = _scan_key(cur_sect, cur_key, value, pos, lookup_func) dst.append(subval) else: break - raise Exception('bad token') + raise Exception("bad token") def new_interpolate(cur_sect, cur_key, value, lookup_func): @@ -269,27 +274,27 @@ def new_interpolate(cur_sect, cur_key, value, lookup_func): dst.append(value[pos:pos2]) pos = m.end() tok = m.group(0) - if tok == '$$': - dst.append('$') - elif tok == '${': + if tok == "$$": + dst.append("$") + elif tok == "${": subval, pos = _scan_key(cur_sect, cur_key, value, pos, lookup_func) dst.append(subval) else: - raise InterpolationSyntaxError(cur_key, cur_sect, 'Interpolation parse error') - return ''.join(dst) + raise InterpolationSyntaxError(cur_key, cur_sect, "Interpolation parse error") + return "".join(dst) class AdvancedInterpolation(Interpolation): _func_map = None def __init__(self, func_map=None): - super(AdvancedInterpolation, self).__init__() + super().__init__() self._func_map = func_map def before_get(self, parser, section, option, value, defaults): dst = [] self._interpolate_ext_new(dst, parser, section, option, value, defaults, set()) - return ''.join(dst) + return "".join(dst) def before_set(self, parser, section, option, value): # cannot validate complex interpolation with regex @@ -304,22 +309,22 @@ def _interpolate_ext_new(self, dst, parser, section, option, rawval, defaults, l xloop = (section, option) if xloop in loop_detect: - raise InterpolationError(option, section, 'Loop detected: %r in %r' % (xloop, loop_detect)) + raise InterpolationError(option, section, "Loop detected: %r in %r" % (xloop, loop_detect)) loop_detect.add(xloop) def lookup_helper(lk_section, lk_option): - if '!' in lk_option: - funcname, val = lk_option.split('!', 1) + if "!" in lk_option: + funcname, val = lk_option.split("!", 1) func = None if self._func_map: func = self._func_map.get(funcname.strip()) if not func: - raise InterpolationError(option, section, 'Unknown interpolation function: %r' % funcname) + raise InterpolationError(option, section, "Unknown interpolation function: %r" % funcname) return func(val.strip(), lk_section, lk_option) # normal fetch - if ':' in lk_option: - ksect, key = lk_option.split(':', 1) + if ":" in lk_option: + ksect, key = lk_option.split(":", 1) ksect, key = ksect.strip(), key.strip() use_vars = None else: @@ -328,10 +333,10 @@ def lookup_helper(lk_section, lk_option): key = parser.optionxform(key) newpart = parser.get(ksect, key, raw=True, vars=use_vars) if newpart is None: - raise InterpolationError(key, ksect, 'Key referenced is None') + raise InterpolationError(key, ksect, "Key referenced is None") dst = [] self._interpolate_ext_new(dst, parser, ksect, key, newpart, defaults, loop_detect) - return ''.join(dst) + return "".join(dst) val = new_interpolate(section, option, rawval, lookup_helper) dst.append(val) diff --git a/vmtool/envconfig.py b/vmtool/envconfig.py index 9702a74..b79bcc0 100644 --- a/vmtool/envconfig.py +++ b/vmtool/envconfig.py @@ -3,24 +3,24 @@ import os import sys - from configparser import ConfigParser -from vmtool.config import Config, AdvancedInterpolation -__all__ = ['load_env', 'load_env_config'] +from vmtool.config import AdvancedInterpolation, Config + +__all__ = ["load_env", "load_env_config"] def find_gittop(): - vmlib = 'vmlib/runner.sh' + vmlib = "vmlib/runner.sh" pos = os.getcwd() - while pos != '/': - if os.path.isdir(os.path.join(pos, '.git')): + while pos != "/": + if os.path.isdir(os.path.join(pos, ".git")): if os.path.isfile(os.path.join(pos, vmlib)): return pos - pos = os.path.normpath(os.path.join(pos, '..')) + pos = os.path.normpath(os.path.join(pos, "..")) - pos = os.environ.get('VMTOOL_GIT_DIR', '') - if pos and os.path.isdir(os.path.join(pos, '.git')): + pos = os.environ.get("VMTOOL_GIT_DIR", "") + if pos and os.path.isdir(os.path.join(pos, ".git")): if os.path.isfile(os.path.join(pos, vmlib)): return pos @@ -32,28 +32,28 @@ def load_env(args): env_name = None role_name = None for a in args: - if a[0] != '-': + if a[0] != "-": break - elif a.startswith('--env'): - tmp = a.split('=', 1) + elif a.startswith("--env"): + tmp = a.split("=", 1) if len(tmp) != 2: print("Cannot parse --env arg") sys.exit(1) env_name = tmp[1] - elif a.startswith('--role'): - tmp = a.split('=', 1) + elif a.startswith("--role"): + tmp = a.split("=", 1) if len(tmp) != 2: print("Cannot parse --role arg") sys.exit(1) role_name = tmp[1] if not env_name: - env_name = os.environ.get('VMTOOL_ENV_NAME') + env_name = os.environ.get("VMTOOL_ENV_NAME") if not env_name: print("Need to use --env or set VMTOOL_ENV_NAME") sys.exit(1) if role_name: - return env_name.split('.')[0] + '.' + role_name + return env_name.split(".")[0] + "." + role_name return env_name @@ -61,16 +61,16 @@ def load_deps(section_name, fn, defs, seen_files): basedir = os.path.dirname(fn) cf = ConfigParser(interpolation=AdvancedInterpolation()) cf.read([fn]) - if cf.has_option(section_name, 'config_depends'): - deps = cf.get(section_name, 'config_depends') - for dep_fn in deps.split(','): + if cf.has_option(section_name, "config_depends"): + deps = cf.get(section_name, "config_depends") + for dep_fn in deps.split(","): dep_fn = dep_fn.strip() if not dep_fn: continue fqfn = os.path.normpath(os.path.join(basedir, dep_fn)) if fqfn not in seen_files: if not os.path.isfile(fqfn): - raise IOError('load_deps: config missing: %s' % dep_fn) + raise IOError("load_deps: config missing: %s" % dep_fn) seen_files.add(fqfn) @@ -82,35 +82,35 @@ def load_deps(section_name, fn, defs, seen_files): def load_env_config(env_name, func_map=None): if not env_name: - raise Exception('load_env_config: env missing') + raise Exception("load_env_config: env missing") git_dir = find_gittop() - conf_dir = os.path.join(git_dir, 'conf') + conf_dir = os.path.join(git_dir, "conf") vmcf_fn = os.path.join(conf_dir, "config_%s.ini" % env_name) if not os.path.isfile(vmcf_fn): print("Config not found: %s" % vmcf_fn) sys.exit(1) - for k in ('VMTOOL_USERNAME', 'USER', 'LOGNAME'): + for k in ("VMTOOL_USERNAME", "USER", "LOGNAME"): fl_user = os.environ.get(k) if fl_user: break if not fl_user: - fl_user = 'please_set_VMTOOL_USERNAME' + fl_user = "please_set_VMTOOL_USERNAME" - role_name = '' - if '.' in env_name: - env_name, role_name = env_name.split('.') + role_name = "" + if "." in env_name: + env_name, role_name = env_name.split(".") defs = { - 'env_name': env_name, - 'role_name': role_name, - 'git_dir': git_dir, - 'conf_dir': conf_dir, - 'user': fl_user, + "env_name": env_name, + "role_name": role_name, + "git_dir": git_dir, + "conf_dir": conf_dir, + "user": fl_user, } - main_section = 'vm-config' + main_section = "vm-config" deps = list(load_deps(main_section, vmcf_fn, defs, set())) deps.append(vmcf_fn) diff --git a/vmtool/run.py b/vmtool/run.py index c283d6f..750615d 100644 --- a/vmtool/run.py +++ b/vmtool/run.py @@ -5,9 +5,9 @@ Tool for managing AWS instances. """ -import sys import importlib import shlex +import sys from vmtool.envconfig import load_env, load_env_config @@ -15,9 +15,9 @@ def run_command(cf, args): """Command is implemented by class specified in vmtool_profile. """ - mod_name = cf.get('vmtool_profile') + mod_name = cf.get("vmtool_profile") mod = importlib.import_module(mod_name) - script = mod.VmTool('vmtool', args) + script = mod.VmTool("vmtool", args) script.start() sys.stdout.flush() sys.stderr.flush() @@ -40,12 +40,12 @@ def run_alias(env_name, alias, cmd, cmdpos, args, options, is_cmd): """ cmd_prefix = args[:cmdpos] - cmd_self = args[cmdpos:cmdpos+1] - cmd_suffix = args[cmdpos+1:] - for elem in alias.split(','): + cmd_self = args[cmdpos:cmdpos + 1] + cmd_suffix = args[cmdpos + 1:] + for elem in alias.split(","): elem = elem.strip() - if ':' in elem: - role, acmd = elem.split(':', 1) + if ":" in elem: + role, acmd = elem.split(":", 1) role, acmd = role.strip(), acmd.strip() xcmd = shlex.split(acmd) elif is_cmd: @@ -57,11 +57,11 @@ def run_alias(env_name, alias, cmd, cmdpos, args, options, is_cmd): xargs = cmd_prefix + xcmd + cmd_suffix if role: - xargs = ['--role=' + role] + xargs + xargs = ["--role=" + role] + xargs - extra = '' + extra = "" if options: - extra = ' [%s]' % ' '.join(options) + extra = " [%s]" % " ".join(options) xargs = options + xargs env_name = load_env(xargs) @@ -80,11 +80,11 @@ def main(): cmdpos = None for i, a in enumerate(args): if cmd is None: - if a[0] != '-': + if a[0] != "-": cmd = a cmdpos = i - elif a[0] == '-': - args.insert(cmdpos + 1, '--') + elif a[0] == "-": + args.insert(cmdpos + 1, "--") break # load config @@ -92,23 +92,23 @@ def main(): cf = load_env_config(env_name) # does role need replacing - alias_sect = 'alias.%s' % cmd + alias_sect = "alias.%s" % cmd if cmd and cf.has_section(alias_sect): - if cf.cf.has_option(alias_sect, 'roles'): - alias = cf.cf.get(alias_sect, 'roles') + if cf.cf.has_option(alias_sect, "roles"): + alias = cf.cf.get(alias_sect, "roles") is_cmd = False else: - alias = cf.cf.get(alias_sect, 'commands') + alias = cf.cf.get(alias_sect, "commands") is_cmd = True options = [] - if cf.cf.has_option(alias_sect, 'options'): - options = shlex.split(cf.cf.get(alias_sect, 'options')) + if cf.cf.has_option(alias_sect, "options"): + options = shlex.split(cf.cf.get(alias_sect, "options")) run_alias(env_name, alias, cmd, cmdpos, args, options, is_cmd) else: run_command(cf, args) -if __name__ == '__main__': +if __name__ == "__main__": try: main() except KeyboardInterrupt: diff --git a/vmtool/scripting.py b/vmtool/scripting.py index be9e8b7..8c139cb 100644 --- a/vmtool/scripting.py +++ b/vmtool/scripting.py @@ -2,16 +2,16 @@ """Useful functions and classes for Python command-line tools. """ -import sys +import argparse import inspect import logging import logging.config import logging.handlers -import argparse +import sys from vmtool.config import Config -__all__ = ['EnvScript', 'UsageError'] +__all__ = ["EnvScript", "UsageError"] class UsageError(Exception): @@ -36,7 +36,7 @@ class EnvScript(object): cf_defaults = {} # setup logger here, this allows override by subclass - log = logging.getLogger('EnvScript') + log = logging.getLogger("EnvScript") def __init__(self, service_name, args): """Script setup. @@ -79,7 +79,7 @@ def __init__(self, service_name, args): self.cf_override = {} if self.options.set: for a in self.options.set: - k, v = a.split('=', 1) + k, v = a.split("=", 1) self.cf_override[k.strip()] = v.strip() # read config file @@ -118,7 +118,7 @@ def init_argparse(self, parser=None): def reload(self): """Reload config. """ - self.log.debug('reload') + self.log.debug("reload") # avoid double loading on startup if not self.cf: self.cf = self.load_config() @@ -142,17 +142,17 @@ def run_func_safely(self, func, prefer_looping=False): return func() except UsageError as d: self.log.error(str(d)) - except MemoryError as d: + except MemoryError: try: # complex logging may not succeed self.log.exception("Job %s out of memory, exiting", self.job_name) except MemoryError: self.log.fatal("Out of memory") except SystemExit as d: raise d - except KeyboardInterrupt as d: + except KeyboardInterrupt: sys.exit(1) - except Exception as d: - self.log.exception('Command failed') + except Exception: + self.log.exception("Command failed") # done sys.exit(1) @@ -169,12 +169,13 @@ def work(self): cmdargs = self.options.args # find function - fname = "cmd_" + cmd.replace('-', '_') + fname = "cmd_" + cmd.replace("-", "_") if not hasattr(self, fname): - self.log.error('bad subcommand, see --help for usage') + self.log.error("bad subcommand, see --help for usage") sys.exit(1) fn = getattr(self, fname) b = inspect.signature(fn).bind(*cmdargs) fn(*b.args, **b.kwargs) + diff --git a/vmtool/simple.py b/vmtool/simple.py new file mode 100644 index 0000000..93d9985 --- /dev/null +++ b/vmtool/simple.py @@ -0,0 +1,225 @@ +"""Simple backend for vmtool. +""" + +import shlex +import sys + +from vmtool.base import VmToolBase +from vmtool.scripting import UsageError +from vmtool.util import eprintf, time_printf + + +class VmTool(VmToolBase): + __doc__ = __doc__ + + _vm_map = None + + def conf_func_tf(self, arg, sect, kname): + return "" + + def conf_func_primary_vm(self, arg, sect, kname): + """Lookup primary vm. + + Usage: ${PRIMARY_VM ! ${other_role}} + """ + vm = self.get_primary_for_role(arg) + return vm["InstanceId"] + + def vm_lookup(self, vm_id, ignore_env=False, cache=True): + if self._vm_map is None: + self._vm_map = {} + if vm_id in self._vm_map and cache: + return self._vm_map[vm_id] + + #res = self.cf.cf.get("primary-vms", vm_id) + vm = {"id": vm_id, "PublicIpAddress": vm_id, "PrivateIpAddress": vm_id, "NetworkInterfaces": []} + self._vm_map[vm_id] = vm + return vm + + def get_env_filters(self): + """Return default filters based on command-line swithces. + """ + return self.make_env_filters(role_name=self.role_name, running=self.options.running, allenvs=self.options.all) + + def make_env_filters(self, role_name=None, running=True, allenvs=False): + """Return filters for instance listing. + """ + filters = [] + + if not allenvs: + filters.append({"Name": "tag:Env", "Values": [self.env_name]}) + if role_name or self.role_name: + filters.append({"Name": "tag:Role", "Values": [role_name or self.role_name]}) + + if running: + filters.append({"Name": "instance-state-name", "Values": ["running"]}) + + return filters + + def get_primary_for_role(self, role_name, instance_id=None): + vm_id = self.cf.cf.get("primary-vms", role_name) + if vm_id: + return self.vm_lookup(vm_id) + raise UsageError("Primary VM not found: %s" % role_name) + + def get_primary_vms(self): + if self.options.all_role_vms: + return self.get_all_role_vms() + if self.options.all_role_fo_vms or self.options.earlier_fo_vms or self.options.latest_fo_vm: + return self.get_all_role_fo_vms() + + main_vms = self._get_primary_vms() + if main_vms: + eprintf("Primary VM for %s is %s", self.full_role, ",".join(main_vms)) + return main_vms + raise UsageError("Primary VM not found") + + def _get_primary_vms(self): + #return [self.role_name] + vm_id = self.cf.cf.get("primary-vms", self.role_name) + if vm_id: + return [vm_id] + return [] + + def get_all_role_vms(self): + if not self.role_name: + raise UsageError("Not in a role-based env") + + all_vms = self._get_primary_vms() + if not all_vms: + eprintf("No running VMs for %s", self.full_role) + else: + eprintf("Running VMs for %s: %s", self.full_role, " ".join(all_vms)) + return all_vms + + def get_all_role_fo_vms(self): + if not self.role_name: + raise UsageError("Not in a role-based env") + + eprintf("No running failover VMs for %s", self.full_role) + return [] + + def _check_tags(self, taglist, force_role=False, role_name=None): + if role_name is None: + role_name = self.role_name + if not taglist: + return False + + gotenv = gotrole = False + for tag in taglist: + if tag["Key"] == "Env": + gotenv = True + if tag["Value"] != self.env_name: + return False + if tag["Key"] == "Role": + gotrole = True + if role_name and tag["Value"] != role_name: + return False + if not gotenv: + return False + if not gotrole and role_name: + return False + elif force_role and not role_name: + return False + return True + + def get_vm_args(self, args, allow_multi=False): + """Check if args start with VM ID. + + returns: (vm-id, args) + """ + if args and args[0][:2] == "i-": + vm_list = [args[0]] + args = args[1:] + else: + vm_list = self.get_primary_vms() + + if allow_multi: + return vm_list, args + + if len(vm_list) != 1: + raise UsageError("Command does not support multiple vms") + return vm_list[0], args + + def ssh_cmdline(self, vm_id, use_admin=False, check_tty=False): + if check_tty and sys.stdout.isatty(): # pylint:disable=no-member + cmd = self.cf.get("ssh_tty_cmd") + else: + cmd = self.cf.get("ssh_connect_cmd") + return shlex.split(cmd) + + def cmd_ssh(self, *args): + """SSH to VM and run command (optional). + + Group: admin + """ + vm_ids, args = self.get_vm_args(args, allow_multi=True) + for vm_id in vm_ids: + if len(vm_ids) > 1: + time_printf("Running on VM %s", vm_id) + if len(args) == 1: + self.vm_exec_tmux(vm_id, args[0], title="ssh") + else: + self.vm_exec_tmux(vm_id, args or ["bash", "-l"], title="ssh") + + def cmd_ssh_admin(self, *args): + """SSH to VM and run command (optional). + + Group: admin + """ + vm_ids, args = self.get_vm_args(args, allow_multi=True) + for vm_id in vm_ids: + if len(vm_ids) > 1: + time_printf("Running on VM %s", vm_id) + if len(args) == 1: + self.vm_exec_tmux(vm_id, args[0], use_admin=True, title="ssh-admin") + else: + self.vm_exec_tmux(vm_id, args or [], use_admin=True, title="ssh-admin") + + def cmd_rsync(self, *args): + """Use rsync to transport files. + + Group: admin + """ + if len(args) < 2: + raise UsageError("Need source and dest for rsync") + self.vm_rsync(*args) + + def cmd_tmux_attach(self, vm_id): + """Attach to regular non-admin session. + + Group: vm + """ + cmdline = shlex.split(self.cf.get("tmux_attach")) + self.vm_exec(vm_id, cmdline, None, use_admin=False) + + def cmd_tmux_attach_admin(self, vm_id): + """Attach to admin session. + + Group: vm + """ + cmdline = shlex.split(self.cf.get("tmux_attach")) + self.vm_exec(vm_id, cmdline, None, use_admin=True) + + def cmd_show_primary(self): + """Show primary VM id. + + Group: internal + """ + ids = self.get_primary_vms() + print(ids[0]) + + def load_modcmd_args(self, args): + vms = [] + for a in args: + if a.startswith("i-"): + vms.append(a) + else: + raise UsageError("command supports only vmid args") + if vms: + return vms + return self.get_primary_vms() + + def set_stamp(self, vm_id, name, commit_id, *dirs): + return + diff --git a/vmtool/tarball.py b/vmtool/tarball.py index 148abf6..263da12 100644 --- a/vmtool/tarball.py +++ b/vmtool/tarball.py @@ -1,14 +1,14 @@ """Tarball creation with data filter. """ -import sys import io import os +import stat +import sys import tarfile import time -import stat -__all__ = ['TarBall'] +__all__ = ["TarBall"] # mode for normal files TAR_FILE_MODE = stat.S_IRUSR | stat.S_IWUSR | stat.S_IRGRP | stat.S_IROTH @@ -22,8 +22,9 @@ class TarBall(object): def __init__(self): + # pylint:disable=consider-using-with self.buf = io.BytesIO() - self.tf = tarfile.open('buf.tgz', 'w|gz', self.buf, format=tarfile.PAX_FORMAT) + self.tf = tarfile.open("buf.tgz", "w|gz", self.buf, format=tarfile.PAX_FORMAT) def filter_data(self, fname, data): """Overridable function.""" @@ -44,7 +45,7 @@ def add_file(self, fpath): if st.st_mode & stat.S_IXUSR > 0: mode = TAR_EXEC_MODE - with open(fpath, 'rb') as f: + with open(fpath, "rb") as f: data = f.read() self.add_file_data(fpath, data, mode) @@ -52,7 +53,7 @@ def add_file(self, fpath): def add_file_data(self, fpath, data, mode=TAR_FILE_MODE, mtime=None): """Add data as filename.""" origdata = data - fpath = fpath.replace('\\', '/') + fpath = fpath.replace("\\", "/") fpath, data = self.filter_data(fpath, data) if not fpath: return @@ -62,17 +63,17 @@ def add_file_data(self, fpath, data, mode=TAR_FILE_MODE, mtime=None): inf.mtime = mtime or time.time() inf.uid = 1000 inf.gid = 1000 - inf.uname = 'nobody' - inf.gname = 'nobody' + inf.uname = "nobody" + inf.gname = "nobody" inf.mode = mode - base = fpath.split('/')[-1] + base = fpath.split("/")[-1] ext = None - if '.' in base: - ext = base.split('.')[-1] - if ext == 'sh': + if "." in base: + ext = base.split(".")[-1] + if ext == "sh": inf.mode = TAR_EXEC_MODE - elif base in ('fl_start_services', 'job_setup', 'user_setup'): + elif base in ("fl_start_services", "job_setup", "user_setup"): inf.mode = TAR_EXEC_MODE inf.size = len(data) @@ -81,17 +82,17 @@ def add_file_data(self, fpath, data, mode=TAR_FILE_MODE, mtime=None): def add_dir(self, dpath, mode=TAR_DIR_MODE, mtime=None): """Add directory entry.""" - dpath = dpath.replace('\\', '/') + dpath = dpath.replace("\\", "/") dpath, data = self.filter_data(dpath, None) if not dpath: return - inf = tarfile.TarInfo(dpath + '/') + inf = tarfile.TarInfo(dpath + "/") inf.mtime = mtime or time.time() inf.uid = 1000 inf.gid = 1000 - inf.uname = 'nobody' - inf.gname = 'nobody' + inf.uname = "nobody" + inf.gname = "nobody" inf.mode = mode inf.type = tarfile.DIRTYPE self.tf.addfile(inf) @@ -112,10 +113,12 @@ def main(): for fn in sys.argv[1:]: tb.add_path(fn) tb.close() - sys.stdout.write(tb.getvalue()) - sys.stdout.flush() + + with os.fdopen(sys.stdout.fileno(), "wb", closefd=False) as out: + out.write(tb.getvalue()) + out.flush() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/vmtool/tarfilter.py b/vmtool/tarfilter.py index a582501..1a94724 100644 --- a/vmtool/tarfilter.py +++ b/vmtool/tarfilter.py @@ -1,33 +1,33 @@ """TarBall that filters input data. """ -import sys -import re -import os -import subprocess import binascii import ipaddress import logging +import os +import re +import subprocess +import sys -from vmtool.util import hmac_sha256, as_bytes from vmtool.tarball import TarBall +from vmtool.util import as_bytes, hmac_sha256 def gen_password(username, password_master): if not password_master: raise Exception("password_master not configured") - src = b"\000%s\377" % username.encode('utf8') + src = b"\000%s\377" % username.encode("utf8") h = hmac_sha256(password_master, src) - return binascii.b2a_base64(h).decode('ascii').rstrip().rstrip('=') + return binascii.b2a_base64(h).decode("ascii").rstrip().rstrip("=") class TarFilter(TarBall): - tag = re.compile(b'{{ ( [^{}]+ ) }}', re.X) + tag = re.compile(b"{{ ( [^{}]+ ) }}", re.X) _password_master = None def __init__(self, key_lookup_func, key_lookup_arg): - super(TarFilter, self).__init__() + super().__init__() self.live = 0 self.key_lookup_func = key_lookup_func self.key_lookup_arg = key_lookup_arg @@ -38,36 +38,36 @@ def set_live(self, is_live): def add_output(self, fpath, cmd): """Read command stdout, check for exit code. """ - p = subprocess.Popen(cmd, stdout=subprocess.PIPE) - buf = p.communicate()[0] - if p.returncode != 0: - raise Exception("command failed: %r" % cmd) + with subprocess.Popen(cmd, stdout=subprocess.PIPE) as p: + buf = p.communicate()[0] + if p.returncode != 0: + raise Exception("command failed: %r" % cmd) self.add_file_data(fpath, buf) def filter_data(self, fname, data): - #pfx = 'setup-final' - if fname[:2] == './': + #pfx = "setup-final" + if fname[:2] == "./": fname = fname[2:] - if fname == '.': + if fname == ".": return fname, None - #fname2 = 'setup-final/' + fname + #fname2 = "setup-final/" + fname fname2 = fname p, ext = os.path.splitext(fname) - if ext in ('.pyc', '.doctree', '.pickle', '.swp'): + if ext in (".pyc", ".doctree", ".pickle", ".swp"): return None, None - if fname.startswith('doc/html'): + if fname.startswith("doc/html"): return None, None - if fname.startswith('doc/tmp'): + if fname.startswith("doc/tmp"): return None, None - if ext in ('.png', '.jpg', '.gif'): - if fname.startswith('doc/'): + if ext in (".png", ".jpg", ".gif"): + if fname.startswith("doc/"): return None, None return fname2, data if not data: return fname2, data # seems text data, check if it uses templating - if ext == '.tmpl': + if ext == ".tmpl": fname2 = p elif data.find(b"VMTMPL", 0, 100) < 0: return fname2, data @@ -85,13 +85,13 @@ def filter_data(self, fname, data): k = m.group(1).strip() try: v = self.key_lookup(k, fname) - except: + except BaseException: logging.exception("%s: %s", fname, k) sys.exit(1) res.append(as_bytes(v)) res.append(data[pos:]) - data = b''.join(res) + data = b"".join(res) return fname2, data def _lazy_lookup(self, key, fname): @@ -102,56 +102,56 @@ def _lazy_lookup(self, key, fname): def _gen_password(self, username, fname): if self._password_master is None: - self._password_master = self._lazy_lookup('password_master', fname) + self._password_master = self._lazy_lookup("password_master", fname) return gen_password(username, self._password_master) def key_lookup(self, key, fname): if isinstance(key, bytes): - key = key.decode('utf8') - t = key.split(':', 1) + key = key.decode("utf8") + t = key.split(":", 1) if len(t) == 1: return self._lazy_lookup(key, fname) kfunc = t[0].strip() arg = t[1].strip() - if kfunc == 'PSW': + if kfunc == "PSW": psw = self._gen_password(arg, fname) return psw - elif kfunc == 'LIVE': - v1, v2 = arg.split('|', 1) + elif kfunc == "LIVE": + v1, v2 = arg.split("|", 1) if self.live: return v1.strip() else: return v2.strip() - elif kfunc == 'ALT': - v1, v2 = arg.split('|', 1) + elif kfunc == "ALT": + v1, v2 = arg.split("|", 1) v1, v2 = v1.strip(), v2.strip() try: val = self.key_lookup(v1, fname).strip() - except: + except BaseException: val = None return val or v2 - elif kfunc == 'CLEAN': + elif kfunc == "CLEAN": v = self.key_lookup(arg, fname) - v = ' '.join(v.split()) + v = " ".join(v.split()) return v.strip() - elif kfunc == 'CLEANWS': + elif kfunc == "CLEANWS": v = self.key_lookup(arg, fname) - v = ''.join(v.split()) + v = "".join(v.split()) return v.strip() - elif kfunc == 'STRIP': + elif kfunc == "STRIP": v = self.key_lookup(arg, fname) return v.strip() - elif kfunc == 'RXESC': + elif kfunc == "RXESC": v = self.key_lookup(arg, fname).strip() - return v.replace('\\', '\\\\').replace('.', '\\.') - elif kfunc == 'SPLIST': + return v.replace("\\", "\\\\").replace(".", "\\.") + elif kfunc == "SPLIST": v = self.key_lookup(arg, fname).strip() - vals = [e.strip() for e in v.split(',') if e.strip()] - return ' '.join(vals) - elif kfunc == 'NETWORK': + vals = [e.strip() for e in v.split(",") if e.strip()] + return " ".join(vals) + elif kfunc == "NETWORK": v = self.key_lookup(arg, fname).strip() return str(ipaddress.ip_network(v).network_address) - elif kfunc == 'NETMASK': + elif kfunc == "NETMASK": v = self.key_lookup(arg, fname).strip() return str(ipaddress.ip_network(v).netmask) else: diff --git a/vmtool/terra.py b/vmtool/terra.py index 8b29d72..feadca0 100644 --- a/vmtool/terra.py +++ b/vmtool/terra.py @@ -3,59 +3,65 @@ import json -__all__ = ['tf_load_output_var', 'tf_load_all_vars'] +__all__ = ["tf_load_output_var", "tf_load_all_vars"] def tf_load_output_var(state_file, name): keys = tf_load_all_vars(state_file) if name not in keys: - raise KeyError('%s: TF module does not have output: %s' % (state_file, name)) + raise KeyError("%s: TF module does not have output: %s" % (state_file, name)) return keys[name] + def _load_state_v3(state): res = {} - for mod in state['modules']: - path = mod['path'] - if path == ['root']: + for mod in state["modules"]: + path = mod["path"] + if path == ["root"]: # top-level resource - resmap = mod.get('resources', {}) + resmap = mod.get("resources", {}) for resname in resmap: - attmap = resmap[resname].get('primary', {}).get('attributes', {}) + attmap = resmap[resname].get("primary", {}).get("attributes", {}) for attname in attmap: - fqname = '%s.%s' % (resname, attname) + fqname = "%s.%s" % (resname, attname) res[fqname] = attmap[attname] - elif path[0] == 'root': + elif path[0] == "root": # module - mpath = '.'.join(path[1:]) - modvars = mod.get('outputs', {}) + mpath = ".".join(path[1:]) + modvars = mod.get("outputs", {}) for keyname in modvars: - fqname = 'module.%s.%s' % (mpath, keyname) - res[fqname] = modvars[keyname]['value'] + fqname = "module.%s.%s" % (mpath, keyname) + res[fqname] = modvars[keyname]["value"] return res + def flatten(dst, k, v): if isinstance(v, dict): for kx, vx in v.items(): - flatten(dst, '%s.%s' % (k, kx), vx) + flatten(dst, "%s.%s" % (k, kx), vx) else: dst[k] = v return dst + def _load_state_v4(state): res = {} - for k, v in state['outputs'].items(): - flatten(res, k, v['value']) + for k, v in state["outputs"].items(): + flatten(res, k, v["value"]) return res + _tf_cache = {} + def tf_load_all_vars(state_file): if state_file in _tf_cache: return _tf_cache[state_file] - state = json.load(open(state_file)) - if state['version'] == 3: + with open(state_file, encoding="utf8") as f: + state = json.load(f) + if state["version"] == 3: res = _load_state_v3(state) - elif state['version'] == 4: + elif state["version"] == 4: res = _load_state_v4(state) else: raise TypeError("Unsupported version of state") diff --git a/vmtool/util.py b/vmtool/util.py index ec42bb4..e3a3a2d 100644 --- a/vmtool/util.py +++ b/vmtool/util.py @@ -1,7 +1,8 @@ """Utility functions for vmtool. """ -import sys +import binascii +import datetime import errno import gzip import hashlib @@ -12,24 +13,22 @@ import os import re import subprocess +import sys import time -import binascii -import datetime - -__all__ = ['hash_known_host', 'ssh_add_known_host', 'parse_console', 'fmt_dur', - 'gz_compress', 'rsh_quote', 'printf', 'eprintf'] +__all__ = ["hash_known_host", "ssh_add_known_host", "parse_console", "fmt_dur", + "gz_compress", "rsh_quote", "printf", "eprintf"] def as_unicode(s): if not isinstance(s, bytes): return s - return s.decode('utf8') + return s.decode("utf8") def as_bytes(s): if not isinstance(s, bytes): - return s.encode('utf8') + return s.encode("utf8") return s @@ -45,13 +44,19 @@ def hash_known_host(host, old_entry=None): """Hash hostname or ip for SSH known_hosts file. """ if old_entry: - salt = binascii.a2b_base64(old_entry[3:].split('|', 1)[0]) + salt = binascii.a2b_base64(old_entry[3:].split("|", 1)[0]) else: salt = os.urandom(20) h = hmac.new(salt, as_bytes(host), hashlib.sha1).digest() s64 = as_unicode(encode_base64(salt)) h64 = as_unicode(encode_base64(h)) - return '|1|%s|%s' % (s64, h64) + return "|1|%s|%s" % (s64, h64) + + +def load_lines(fn): + with open(fn, encoding="utf8") as f: + for ln in f.readlines(): + yield ln def ssh_add_known_host(kh_file, dns, ip, ktype, kval, vm_id, hash_hosts=True): @@ -59,7 +64,7 @@ def ssh_add_known_host(kh_file, dns, ip, ktype, kval, vm_id, hash_hosts=True): if not os.path.isdir(fdir): os.makedirs(fdir, 0o700, exist_ok=True) - space_rc = re.compile('[ \t]+') + space_rc = re.compile("[ \t]+") new_file = [] if os.path.isfile(kh_file): found_ip = False @@ -67,10 +72,10 @@ def ssh_add_known_host(kh_file, dns, ip, ktype, kval, vm_id, hash_hosts=True): drops = False cur_key = (ktype, kval) lines = 0 - for ln in open(kh_file).readlines(): + for ln in load_lines(kh_file): lines += 1 xln = ln.strip() - if not xln or xln[0] == '#': + if not xln or xln[0] == "#": new_file.append(ln) continue t = space_rc.split(xln) @@ -78,9 +83,9 @@ def ssh_add_known_host(kh_file, dns, ip, ktype, kval, vm_id, hash_hosts=True): kt = t[1].strip() kv = t[2].strip() old_key = (kt, kv) - if kt != ktype and ktype != 'ecdsa-sha2-nistp256': + if kt != ktype and ktype != "ecdsa-sha2-nistp256": pass - elif adr.startswith('|1|'): + elif adr.startswith("|1|"): if ip and adr == hash_known_host(ip, adr): if old_key == cur_key: found_ip = True @@ -130,7 +135,7 @@ def ssh_add_known_host(kh_file, dns, ip, ktype, kval, vm_id, hash_hosts=True): ipln = "%s %s %s %s\n" % (ip, ktype, kval, vm_id) new_file.append(ipln) - write_atomic(kh_file, ''.join(new_file)) + write_atomic(kh_file, "".join(new_file)) # @@ -138,7 +143,7 @@ def ssh_add_known_host(kh_file, dns, ip, ktype, kval, vm_id, hash_hosts=True): # -def parse_console(vm_console, key_types=('ssh-ed25519', 'ecdsa-sha2-nistp256')): +def parse_console(vm_console, key_types=("ssh-ed25519", "ecdsa-sha2-nistp256")): """Parse SSH keys from AWS vm console. """ begin = "-----BEGIN SSH HOST KEY KEYS-----" @@ -158,14 +163,14 @@ def parse_console(vm_console, key_types=('ssh-ed25519', 'ecdsa-sha2-nistp256')): # parse lines klines = vm_console[p1 + len(begin):p2] - for kln in klines.split('\n'): - pos = kln.find('ecdsa-') + for kln in klines.split("\n"): + pos = kln.find("ecdsa-") if pos < 0: - pos = kln.find('ssh-') + pos = kln.find("ssh-") if pos < 0: continue kln = kln[pos:].strip() - ktype, kcert, kname = kln.split(' ') + ktype, kcert, kname = kln.split(" ") if ktype not in key_types: continue keys.append((ktype, kcert)) @@ -191,9 +196,9 @@ def gz_compress(filename, data): def rsh_quote(args): if not isinstance(args, (tuple, list)): - raise ValueError('rsh_quote needs list of args') + raise ValueError("rsh_quote needs list of args") res = [] - rc_bad = re.compile(r'[^\-\w.,:_=/]') + rc_bad = re.compile(r"[^\-\w.,:_=/]") for a in args: if rc_bad.search(a): a = "'%s'" % a.replace("'", "'\\''") @@ -211,14 +216,14 @@ def hmac_sha256(key, data): def printf(msg, *args): if args: msg = msg % args - sys.stdout.write(msg + '\n') + sys.stdout.write(msg + "\n") sys.stdout.flush() def eprintf(msg, *args): if args: msg = msg % args - sys.stderr.write(msg + '\n') + sys.stderr.write(msg + "\n") sys.stderr.flush() @@ -227,7 +232,7 @@ def time_printf(msg, *args): tstr = "%02d:%02d:%02d *** " % (t.tm_hour, t.tm_min, t.tm_sec) if args: msg = msg % args - sys.stdout.write(tstr + msg + '\n') + sys.stdout.write(tstr + msg + "\n") sys.stdout.flush() @@ -240,7 +245,7 @@ def run_successfully(cmd, **kwargs): def local_cmd(cmd): - return subprocess.check_output(cmd).decode('utf8') + return subprocess.check_output(cmd).decode("utf8") def _json_default(obj): @@ -253,22 +258,24 @@ def print_json(obj): print(json.dumps(obj, indent=4, default=_json_default, sort_keys=True)) -# non-win32 -def write_atomic(fn, data, bakext=None, mode='b'): +def write_atomic(fn, data, bakext=None, mode="b"): """Write file with rename.""" - if mode not in ['', 'b', 't']: + if mode not in ("", "b", "t"): raise ValueError("unsupported fopen mode") # write new data to tmp file - fn2 = fn + '.new' - f = open(fn2, 'w' + mode) - f.write(as_bytes(data)) - f.close() + fn2 = fn + ".new" + if mode == "b": + with open(fn2, "wb") as f: + f.write(as_bytes(data)) + else: + with open(fn2, "w", encoding="utf8") as f: + f.write(data) # link old data to backup file if bakext: - if bakext.find('/') >= 0: + if bakext.find("/") >= 0: raise ValueError("invalid bakext") fnb = fn + bakext try: @@ -283,10 +290,10 @@ def write_atomic(fn, data, bakext=None, mode='b'): raise # win32 does not like replace - if sys.platform == 'win32': + if sys.platform == "win32": try: os.remove(fn) - except: + except BaseException: pass # atomically replace file @@ -297,21 +304,20 @@ def fmt_dur(dur): """Format time duration. >>> dlong = ((27 * 24 + 2) * 60 + 38) * 60 + 43 - >>> [fmt_dur(v) for v in (0.001, 1.1, dlong, -5)] == ['0s', '1s', '27d2h38m43s', '-5s'] + >>> [fmt_dur(v) for v in (0.001, 1.1, dlong, -5)] == ["0s", "1s", "27d2h38m43s", "-5s"] True """ res = [] if dur < 0: - res.append('-') + res.append("-") dur = -dur tmp, secs = divmod(int(dur), 60) tmp, mins = divmod(tmp, 60) days, hours = divmod(tmp, 24) - for (val, unit) in ((days, 'd'), (hours, 'h'), (mins, 'm'), (secs, 's')): + for (val, unit) in ((days, "d"), (hours, "h"), (mins, "m"), (secs, "s")): if val: - res.append('%d%s' % (val, unit)) + res.append("%d%s" % (val, unit)) if not res: - return '0s' - return ''.join(res) - + return "0s" + return "".join(res) diff --git a/vmtool/xglob.py b/vmtool/xglob.py index 989c40e..3494f62 100644 --- a/vmtool/xglob.py +++ b/vmtool/xglob.py @@ -22,20 +22,21 @@ """ -import sys +import functools import os import os.path import re -import functools +import sys +from typing import Callable, Iterable, Iterator, Match, Optional -__all__ = ['xglob', 'xfilter'] +__all__ = ["xglob", "xfilter"] # special regex symbols -_RXMAGIC = re.compile(r'[][(){}\\.?*+|^$]') +_RXMAGIC = re.compile(r"[][(){}\\.?*+|^$]") # glob magic -_GMAGIC = re.compile(r'[][()*?]') +_GMAGIC = re.compile(r"[][()*?]") # glob tokens _GTOK = re.compile(r""" @@ -46,40 +47,40 @@ # map glob syntax to regex syntax _PARENS = { - '?(': ['(?:', ')?'], - '*(': ['(?:', ')*'], - '+(': ['(?:', ')+'], - '@(': ['(?:', ')'], - '!(': ['(?!', ')'], + "?(": ("(?:", ")?"), + "*(": ("(?:", ")*"), + "+(": ("(?:", ")+"), + "@(": ("(?:", ")"), + "!(": ("(?!", ")"), } -def escape(s): +def escape(s: str) -> str: """Escape glob meta-characters. """ - return _GMAGIC.sub(r'[\g<0>]', s) + return _GMAGIC.sub(r"[\g<0>]", s) -def re_escape(s): +def re_escape(s: str) -> str: """Escape regex meta-characters. """ - return _RXMAGIC.sub(r'\\\g<0>', s) + return _RXMAGIC.sub(r"\\\g<0>", s) -def has_magic(pat): +def has_magic(pat: str) -> bool: """Contains glob magic chars. """ return _GMAGIC.search(pat) is not None -def _nomatch(name): +def _nomatch(name: str) -> Optional[Match[str]]: """Invalid pattern does not match anything. """ return None @functools.lru_cache(maxsize=256, typed=True) -def _compile(pat): +def _compile(pat: str) -> Callable[[str], Optional[Match[str]]]: """Convert glob/fnmatch pattern to compiled regex. """ plen = len(pat) @@ -98,25 +99,25 @@ def _compile(pat): c = m.group(0) if len(c) > 1: - if c[0] == '[': - if c[1] == '!': - x = '[^' + re_escape(c[2:-1]) + ']' + if c[0] == "[": + if c[1] == "!": + x = "[^" + re_escape(c[2:-1]) + "]" else: - x = '[' + re_escape(c[1:-1]) + ']' + x = "[" + re_escape(c[1:-1]) + "]" elif c in _PARENS: x = _PARENS[c][0] parens.append(_PARENS[c][1]) else: x = re_escape(c) - elif c == '?': - x = '.' - elif c == '*': - x = '.*' + elif c == "?": + x = "." + elif c == "*": + x = ".*" if res and res[-1] == x: continue - elif c == ')' and parens: + elif c == ")" and parens: x = parens.pop() - elif c == '|' and parens: + elif c == "|" and parens: x = c else: x = re_escape(c) @@ -125,17 +126,17 @@ def _compile(pat): if parens: return _nomatch - xre = r'\A' + ''.join(res) + r'\Z' + xre = r"\A" + "".join(res) + r"\Z" return re.compile(xre, re.S).match -def xfilter(pat, names): +def xfilter(pat: str, names: Iterable[str]) -> Iterator[str]: """Filter name list based on glob pattern. """ matcher = _compile(pat) - if pat[0] != '.': + if pat[0] != ".": for n in names: - if n[0] != '.' and matcher(n): + if n[0] != "." and matcher(n): yield n else: for n in names: @@ -143,17 +144,17 @@ def xfilter(pat, names): yield n -def dirglob_nopat(dirname, basename, dirs_only): +def _dirglob_nopat(dirname: str, pattern: str, dirs_only: bool) -> Iterator[str]: """File name without pattern. """ - if basename == '': - if os.path.isdir(dirname): - yield basename - elif os.path.lexists(os.path.join(dirname, basename)): - yield basename + if pattern == "": + if os.path.isdir(pattern): + yield pattern + elif os.path.lexists(os.path.join(dirname, pattern)): + yield pattern -def dirglob_pat(dirname, pattern, dirs_only): +def _dirglob_pat(dirname: str, pattern: str, dirs_only: bool) -> Iterator[str]: """File name with pattern. """ if not isinstance(pattern, bytes) and isinstance(dirname, bytes): @@ -165,30 +166,30 @@ def dirglob_pat(dirname, pattern, dirs_only): return xfilter(pattern, names) -def dirglob_subtree(dirname, pattern, dirs_only): - """File name is '**', recurse into subtrees. +def _dirglob_subtree(dirname: str, pattern: str, dirs_only: bool) -> Iterator[str]: + """File name is "**", recurse into subtrees. """ for dp, dnames, fnames in os.walk(dirname, topdown=True): if dp == dirname: - basedir = '' + basedir = "" yield basedir else: basedir = dp[len(dirname) + 1:] + os.path.sep if not dirs_only: for fn in fnames: - if fn[0] != '.': + if fn[0] != ".": yield basedir + fn filtered = [] for dn in dnames: - if dn[0] != '.': + if dn[0] != ".": filtered.append(dn) yield basedir + dn - dnames[:] = filtered + dnames[:] = filtered # os.walk(topdown) next scan -def _xglob(pat, dirs_only=False): +def _xglob(pat: str, dirs_only: bool = False) -> Iterator[str]: """Internal implementation. """ @@ -202,7 +203,7 @@ def _xglob(pat, dirs_only=False): dn, bn = os.path.split(pat) if not dn: # pattern without dir part - for name in dirglob_pat(os.curdir, bn, dirs_only): + for name in _dirglob_pat(os.curdir, bn, dirs_only): yield name return @@ -213,20 +214,20 @@ def _xglob(pat, dirs_only=False): dirs = iter([dn]) # decide how to expand file part - if bn == '**': - dirglob = dirglob_subtree + if bn == "**": + dirglob = _dirglob_subtree elif has_magic(bn): - dirglob = dirglob_pat + dirglob = _dirglob_pat else: - dirglob = dirglob_nopat + dirglob = _dirglob_nopat # loop over files for dn in dirs: for name in dirglob(dn, bn, dirs_only): - yield os.path.join(dn, name).replace(os.path.sep, '/') + yield os.path.join(dn, name).replace(os.path.sep, "/") -def xglob(pat): +def xglob(pat: str) -> Iterator[str]: """Extended glob. Supports ** and extended glob syntax in pattern. @@ -236,12 +237,12 @@ def xglob(pat): return _xglob(pat) -def main(): +def main() -> None: for pat in sys.argv[1:]: for fn in xglob(pat): print(fn) -if __name__ == '__main__': +if __name__ == "__main__": main()