Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,12 @@ jobs:
- name: Build distribution
run: uv build

- name: Install package with mypy
- name: Install package with test dependencies
run: |
uv pip install --system dist/*.whl
uv pip install --system mypy
uv pip install --system -e ".[test]"

- run: mypy --install-types --non-interactive commcare_export/ tests/ migrations/
- run: mypy commcare_export/ tests/

finish:
needs: test
Expand Down
2 changes: 1 addition & 1 deletion MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
recursive-include migrations *.py *.ini
recursive-include commcare_export/migrations *.py *.ini *.mako
include commcare_export/VERSION
43 changes: 0 additions & 43 deletions commcare_export/__init__.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,5 @@
import logging
import os
import re
from .version import __version__

__all__ = [
'__version__',
'get_logger',
'get_error_logger',
'logger_name_from_filepath',
'Logger',
'repo_root',
]

repo_root = os.path.abspath(os.path.join(__file__, os.pardir, os.pardir))


class Logger:
def __init__(self, logger, level):
self.logger = logger
self.level = level
self.linebuf = ''

def write(self, buf):
for line in buf.rstrip().splitlines():
self.logger.log(self.level, line.rstrip())


def logger_name_from_filepath(filepath):
relative_path = os.path.relpath(filepath, start=repo_root)
cleaned_path = relative_path.replace('/', '.')
return re.sub(r'\.py$', '', cleaned_path)


def get_error_logger():
return Logger(logging.getLogger(), logging.ERROR)


def get_logger(filepath=None):
if filepath:
logger = logging.getLogger(
logger_name_from_filepath(filepath)
)
else:
logger = logging.getLogger()

logger.setLevel(logging.DEBUG)
return logger
22 changes: 11 additions & 11 deletions commcare_export/checkpoint.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import os
import logging
import uuid
from importlib import resources
from contextlib import contextmanager
from operator import attrgetter

Expand All @@ -12,9 +13,8 @@
from commcare_export.commcare_minilinq import PaginationMode
from commcare_export.exceptions import DataExportException
from commcare_export.writers import SqlMixin
from commcare_export import get_logger, repo_root

logger = get_logger(__file__)
logger = logging.getLogger(__name__)
Base = declarative_base()


Expand Down Expand Up @@ -85,7 +85,7 @@ def session_scope(Session):

class CheckpointManager(SqlMixin):
table_name = 'commcare_export_runs'
migrations_repository = os.path.join(repo_root, 'migrations')
migrations_repository = resources.files("commcare_export") / "migrations"

def __init__(
self,
Expand Down Expand Up @@ -200,13 +200,13 @@ def _set_checkpoint(

def create_checkpoint_table(self, revision='head'):
from alembic import command, config
cfg = config.Config(
os.path.join(self.migrations_repository, 'alembic.ini')
)
cfg.set_main_option('script_location', self.migrations_repository)
with self.engine.begin() as connection:
cfg.attributes['connection'] = connection
command.upgrade(cfg, revision)

with resources.as_file(self.migrations_repository) as migrations_path:
cfg = config.Config(str(migrations_path / 'alembic.ini'))
cfg.set_main_option('script_location', str(migrations_path))
with self.engine.begin() as connection:
cfg.attributes['connection'] = connection
command.upgrade(cfg, revision)

def _cleanup(self):
self._validate_tables()
Expand Down
127 changes: 62 additions & 65 deletions commcare_export/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import json
import os
import sys
import logging
import dateutil.parser
import requests
import sqlalchemy
Expand All @@ -28,11 +27,11 @@
from commcare_export.repeatable_iterator import RepeatableIterator
from commcare_export.utils import get_checkpoint_manager
from commcare_export.version import __version__
from commcare_export import get_logger, get_error_logger
import logging

EXIT_STATUS_SUCCESS = 0
EXIT_STATUS_ERROR = 1
logger = get_logger(__file__)
logger = logging.getLogger(__name__)

commcare_hq_aliases = {
'local': 'http://localhost:8000',
Expand Down Expand Up @@ -184,13 +183,13 @@ def add_to_parser(self, parser, **additional_kwargs):
]


def set_up_logging(log_dir=None):
def set_up_file_logging(log_dir=None):
"""
Set up file-based logging.

:param log_dir: Directory where the log file will be written. If
None, uses the current working directory.
:returns tuple: (success, log_file_path, error_msg)
:returns tuple: (success, log_file_path, error_msg, file_handler)
"""
if log_dir is None:
log_dir = os.getcwd()
Expand All @@ -203,15 +202,38 @@ def set_up_logging(log_dir=None):
with open(log_file, 'a'): # Test write permissions
pass

logging.basicConfig(
filename=log_file,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s',
filemode='a',
)
sys.stderr = get_error_logger()
return True, log_file, None
file_handler = logging.FileHandler(log_file, mode='a')
return True, log_file, None, file_handler
except (OSError, IOError, PermissionError) as err:
return False, log_file, f"{type(err).__name__}: {err}"
return False, log_file, f"{type(err).__name__}: {err}", None


def set_up_logging(args):
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter('%(message)s'))
handlers = [stream_handler]
if not args.no_logfile:
success, log_file, error, file_handler = set_up_file_logging(
args.log_dir
)
if success:
file_handler.setFormatter(
logging.Formatter(
'%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
)
)
handlers.append(file_handler)
print(f'Writing logs to {log_file}')
else:
print(f'Warning: Unable to write to log file {log_file}: {error}')
print('Logging to console only.')

log_level = logging.DEBUG if args.verbose else logging.WARN
root_logger = logging.getLogger()
root_logger.handlers.clear()
root_logger.setLevel(log_level)
for handler in handlers:
root_logger.addHandler(handler)


def main(argv):
Expand All @@ -229,24 +251,7 @@ def main(argv):
if errors:
raise Exception(f"Could not proceed. Following issues were found: {', '.join(errors)}.")

if not args.no_logfile:
success, log_file, error = set_up_logging(args.log_dir)
if success:
print(f'Writing logs to {log_file}')
else:
print(f'Warning: Unable to write to log file {log_file}: {error}')
print('Logging to console only.')

if args.verbose:
logging.basicConfig(
level=logging.DEBUG,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
)
else:
logging.basicConfig(
level=logging.WARN,
format='%(asctime)s %(name)-12s %(levelname)-8s %(message)s'
)
set_up_logging(args)

logging.getLogger('alembic').setLevel(logging.WARN)
logging.getLogger('backoff').setLevel(logging.FATAL)
Expand All @@ -258,13 +263,7 @@ def main(argv):

if not args.project:
error_msg = "commcare-export: error: argument --project is required"
# output to log file through sys.stderr
print(
error_msg,
file=sys.stderr
)
# Output to console for debugging
print(error_msg)
logger.error(error_msg)
sys.exit(1)

print("Running export...")
Expand Down Expand Up @@ -364,11 +363,9 @@ def _get_writer(output_format, output, strict_types):
elif output_format == 'csv':
if not output.endswith(".zip"):
print(
"WARNING: csv output is a zip file, but "
f"will be written to {output}"
)
print(
"Consider appending .zip to the file name to avoid confusion."
'WARNING: CSV output is a zip file, but will be written to '
f"'{output}'.\n"
"Consider appending '.zip' to the file name to avoid confusion."
)
return writers.CsvTableWriter(output)
elif output_format == 'json':
Expand All @@ -383,8 +380,9 @@ def _get_writer(output_format, output, strict_types):
charset_split = output.split('charset=')
if len(charset_split) > 1 and charset_split[1] != 'utf8mb4':
raise Exception(
f"The charset '{charset_split[1]}' might cause problems with the export. "
f"It is recommended that you use 'utf8mb4' instead."
f"The charset '{charset_split[1]}' might cause problems "
"with the export. It is recommended that you use "
"'utf8mb4' instead."
)

return writers.SqlTableWriter(output, strict_types)
Expand Down Expand Up @@ -414,8 +412,7 @@ def _get_checkpoint_manager(args):
args.query
):
logger.warning(
"Checkpointing disabled for non builtin, "
"non file-based query"
"Checkpointing disabled for non builtin, non file-based query"
)
elif args.since or args.until:
logger.warning(
Expand All @@ -442,29 +439,30 @@ def evaluate_query(env, query):
lazy_result = query.eval(env)
force_lazy_result(lazy_result)
return 0
except requests.exceptions.RequestException as e:
if e.response and e.response.status_code == 401:
print(
"\nAuthentication failed. Please check your credentials.",
file=sys.stderr
except requests.exceptions.RequestException as err:
if err.response and err.response.status_code == 401:
logger.error(
"Authentication failed. Please check your credentials."
)
return EXIT_STATUS_ERROR
else:
raise
except ResourceRepeatException as e:
print('Stopping because the export is stuck')
print(e.message)
print('Try increasing --batch-size to overcome the error')
except ResourceRepeatException as err:
logger.error(
'Stopping because the export is stuck.\n'
f'{err.message}\n'
f'Try increasing --batch-size to overcome the error'
)
return EXIT_STATUS_ERROR
except (
sqlalchemy.exc.DataError,
sqlalchemy.exc.InternalError,
sqlalchemy.exc.ProgrammingError
) as e:
print('Stopping because of database error:\n', e)
) as err:
logger.error(f'Stopping because of database error:\n{err}')
return EXIT_STATUS_ERROR
except KeyboardInterrupt:
print('\nExport aborted', file=sys.stderr)
logger.error("Export aborted")
return EXIT_STATUS_ERROR


Expand All @@ -473,10 +471,9 @@ def main_with_args(args):
writer = _get_writer(args.output_format, args.output, args.strict_types)

if args.query is None and args.users is False and args.locations is False:
print(
'At least one the following arguments is required: '
'--query, --users, --locations',
file=sys.stderr
logger.error(
"At least one the following arguments is required: "
"--query, --users, --locations"
)
return EXIT_STATUS_ERROR

Expand All @@ -500,8 +497,8 @@ def main_with_args(args):
lp = LocationInfoProvider(api_client, page_size=args.batch_size)
try:
query = get_queries(args, writer, lp, column_enforcer)
except DataExportException as e:
print(e.message, file=sys.stderr)
except DataExportException as err:
logger.error(err.message)
return EXIT_STATUS_ERROR

if args.dump_query:
Expand Down
3 changes: 1 addition & 2 deletions commcare_export/commcare_hq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import commcare_export
from commcare_export.repeatable_iterator import RepeatableIterator
from commcare_export import get_logger

AUTH_MODE_PASSWORD = 'password'
AUTH_MODE_APIKEY = 'apikey'
Expand All @@ -20,7 +19,7 @@
LATEST_KNOWN_VERSION = '0.5'
RESOURCE_REPEAT_LIMIT = 10

logger = get_logger(__file__)
logger = logging.getLogger(__name__)


def on_wait(details):
Expand Down
4 changes: 2 additions & 2 deletions commcare_export/commcare_minilinq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
API directly.
"""
import json
import logging
from enum import Enum
from typing import Any
from urllib.parse import parse_qs, urlparse
Expand All @@ -14,9 +15,8 @@

from commcare_export.env import CannotBind, CannotReplace, DictEnv
from commcare_export.misc import unwrap
from commcare_export import get_logger

logger = get_logger(__file__)
logger = logging.getLogger(__name__)

SUPPORTED_RESOURCES = {
'form',
Expand Down
5 changes: 3 additions & 2 deletions commcare_export/location_info_provider.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@

import logging

from commcare_export.commcare_minilinq import SimplePaginator
from commcare_export.misc import unwrap_val
from commcare_export import get_logger

logger = get_logger(__file__)
logger = logging.getLogger(__name__)

# LocationInfoProvider uses the /location_type/ endpoint of the API to
# retrieve location type data, stores that information in a dictionary
Expand Down
Loading
Loading