Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 35 additions & 6 deletions beangulp/importers/csvbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import re

from collections import defaultdict
from itertools import islice
from itertools import islice, tee
from beancount.core import data

import beangulp
Expand All @@ -17,6 +17,26 @@
"""Marker to indicate that a value was not specified."""


def _chomp(iterable, head, tail):
"""Return an iterator that yields selected values from an iterable.

Args:
iterable: The iterable to iterate.
head: Number of initial elements to skip.
tail: Number of trailing elements to skip.

>>> list(_chomp(range(10), 2, 3))
[2, 3, 4, 5, 6]

"""
iterator = islice(iterable, head, None)
if not tail:
yield from iterator
iterator, sentinel = tee(iterator)
for _ in islice(sentinel, tail, None):
yield next(iterator)


def _resolve(spec, names):
"""Resolve column specification into column index.

Expand Down Expand Up @@ -234,8 +254,10 @@ class Order(Enum):
class CSVReader(metaclass=CSVMeta):
encoding = "utf8"
"""File encoding."""
skiplines = 0
"""Number of input lines to skip before starting processing."""
header = 0
"""Number of header lines to skip."""
footer = 0
"""Number of footer lines to ignore."""
names = True
"""Whether the data file contains a row with column names."""
dialect = None
Expand All @@ -248,6 +270,12 @@ class CSVReader(metaclass=CSVMeta):
# This is populated by the CSVMeta metaclass.
columns = {}

def __init__(self):
if hasattr(self, 'skiplines'):
# Warn about use of deprecated class attribute, eventually.
# warnings.warn('skiplines is deprecated, use header instead', DeprecationWarning)
self.header = self.skiplines

def read(self, filepath):
"""Read CSV file according to class defined columns specification.

Expand All @@ -265,8 +293,8 @@ def read(self, filepath):
"""

with open(filepath, encoding=self.encoding) as fd:
# Skip header lines.
lines = islice(fd, self.skiplines, None)
# Skip header and footer lines.
lines = _chomp(fd, self.header, self.footer)

# Filter out comment lines.
if self.comments:
Expand Down Expand Up @@ -311,6 +339,7 @@ class Importer(beangulp.Importer, CSVReader):
"""

def __init__(self, account, currency, flag="*"):
super().__init__()
self.importer_account = account
self.currency = currency
self.flag = flag
Expand Down Expand Up @@ -347,7 +376,7 @@ def extract(self, filepath, existing):
default_account = self.account(filepath)

# Compute the line number of the first data line.
offset = int(self.skiplines) + bool(self.names) + 1
offset = int(self.header) + bool(self.names) + 1

for lineno, row in enumerate(self.read(filepath), offset):
# Skip empty lines.
Expand Down
37 changes: 37 additions & 0 deletions beangulp/importers/csvbase_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from beancount.parser import cmptest
from beancount.utils.test_utils import docfile
from beangulp.importers.csvbase import (
_chomp,
Column,
Columns,
Date,
Expand Down Expand Up @@ -429,6 +430,27 @@ class Reader(CSVReader):
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0][0], "a")

@docfile
def test_footer(self, filename):
"""\
Header
First, Second
a, b
Footer
Footer
"""

class Reader(CSVReader):
first = Column(0)
second = Column(1)
header = 1
footer = 2

reader = Reader()
rows = list(reader.read(filename))
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0][0], "a")


class Base(Importer):
def identify(self, filepath):
Expand Down Expand Up @@ -772,3 +794,18 @@ class CSVImporter(Base):
with self.assertRaisesRegex(RuntimeError, msg) as ctx:
importer.extract(filename, [])
self.assertIsInstance(ctx.exception.__cause__, decimal.InvalidOperation)


class TestChomp(unittest.TestCase):

def test_header(self):
self.assertEqual(list(_chomp(range(10), 2, 0)), [2, 3, 4, 5, 6, 7, 8, 9])

def test_footer(self):
self.assertEqual(list(_chomp(range(10), 0, 3)), [0, 1, 2, 3, 4, 5, 6])

def test_header_and_footer(self):
self.assertEqual(list(_chomp(range(10), 2, 3)), [2, 3, 4, 5, 6])

def test_short(self):
self.assertEqual(list(_chomp(range(1), 2, 3)), [])