Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions beangulp/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

import chardet

from beancount.utils import defdict
from beangulp import mimetypes
from beangulp import utils

# NOTE: See get_file() at the end of this file to create instances of FileMemo.

Expand Down Expand Up @@ -155,7 +155,7 @@ def get_file(filename):
return _CACHE[filename]


_CACHE = defdict.DefaultDictWithKey(_FileMemo)
_CACHE = utils.DefaultDictWithKey(_FileMemo)


def cache(func=None, *, key=None):
Expand Down
12 changes: 12 additions & 0 deletions beangulp/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from os import path
from typing import Iterator, Sequence, Union, Set, Optional, Dict
import datetime
import collections
import decimal
import hashlib
import logging
Expand All @@ -13,6 +14,17 @@
from beangulp import mimetypes


class DefaultDictWithKey(collections.defaultdict):
"""A version of defaultdict whose factory accepts the key as an argument.
Note: collections.defaultdict would be improved by supporting this directly,
this is a common occurrence.
"""

def __missing__(self, key):
self[key] = value = self.default_factory(key)
return value


def getmdate(filepath: str) -> datetime.date:
"""Return file modification date."""
mtime = path.getmtime(filepath)
Expand Down
13 changes: 12 additions & 1 deletion beangulp/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import types
import unittest

from unittest import mock
from shutil import rmtree
from tempfile import mkdtemp

Expand Down Expand Up @@ -111,3 +111,14 @@ def test_idify(self):
)
self.assertEqual("A____B.pdf", utils.idify("A____B_._pdf"))


class TestDefDictWithKey(unittest.TestCase):
def test_defdict_with_key(self):
factory = mock.MagicMock()
testdict = utils.DefaultDictWithKey(factory)

testdict["a"]
testdict["b"]
self.assertEqual(2, len(factory.mock_calls))
self.assertEqual(("a",), factory.mock_calls[0][1])
self.assertEqual(("b",), factory.mock_calls[1][1])
Loading