Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setup(
name="DAWG2",
version="0.9.0",
version="0.9.1",
description="Fast and memory efficient DAWG (DAFSA) for Python",
long_description=open('README.rst').read() + '\n\n' + open('CHANGES.rst').read(),
author='Mikhail Korobov',
Expand Down
65 changes: 34 additions & 31 deletions src/dawg.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,13 @@ cdef class DAWG:
b_step = <bytes>(key[word_pos].encode('utf8'))

if b_step in replace_chars:
next_index = index
b_replace_char, u_replace_char = <tuple>replace_chars[b_step]

if self.dct.Follow(b_replace_char, &next_index):
prefix = current_prefix + key[start_pos:word_pos] + u_replace_char
extra_keys = self._similar_keys(prefix, key, next_index, replace_chars)
res.extend(extra_keys)
for (b_replace_char, u_replace_char) in replace_chars[b_step]:
next_index = index
is_followed = self.dct.Follow(b_replace_char, &next_index)
if is_followed:
prefix = current_prefix + key[start_pos:word_pos] + u_replace_char
extra_keys = self._similar_keys(prefix, key, next_index, replace_chars)
res.extend(extra_keys)

if not self.dct.Follow(b_step, &index):
break
Expand All @@ -225,17 +225,17 @@ cdef class DAWG:

``replaces`` is an object obtained from
``DAWG.compile_replaces(mapping)`` where mapping is a dict
that maps single-char unicode sitrings to another single-char
that maps single-char unicode strings to (one or more) single-char
unicode strings.

This may be useful e.g. for handling single-character umlauts.
"""
return self._similar_keys("", key, self.dct.root(), replaces)

cpdef list prefixes(self, unicode key):
'''
"""
Return a list with keys of this DAWG that are prefixes of the ``key``.
'''
"""
return [p.decode('utf8') for p in self.b_prefixes(<bytes>key.encode('utf8'))]

cpdef list b_prefixes(self, bytes b_key):
Expand All @@ -254,9 +254,9 @@ cdef class DAWG:
return res

def iterprefixes(self, unicode key):
'''
"""
Return a generator with keys of this DAWG that are prefixes of the ``key``.
'''
"""
cdef BaseType index = self.dct.root()
cdef bytes b_key = <bytes>key.encode('utf8')
cdef int pos = 1
Expand All @@ -273,13 +273,16 @@ cdef class DAWG:
def compile_replaces(cls, replaces):

for k,v in replaces.items():
if len(k) != 1 or len(v) != 1:
raise ValueError("Keys and values must be single-char unicode strings.")

if len(k) != 1:
raise ValueError("Keys must be single-char unicode strings.")
if (isinstance(v, str) and len(v) != 1):
raise ValueError("Values must be single-char unicode strings or non-empty lists of such.")
if isinstance(v, list) and (any(len(v_entry) != 1 for v_entry in v) or len(v) < 1):
raise ValueError("Values must be single-char unicode strings or non-empty lists of such.")
return dict(
(
k.encode('utf8'),
(v.encode('utf8'), unicode(v))
[(v_entry.encode('utf8'), unicode(v_entry)) for v_entry in v]
)
for k, v in replaces.items()
)
Expand Down Expand Up @@ -725,13 +728,13 @@ cdef class BytesDAWG(CompletionDAWG):
b_step = <bytes>(key[word_pos].encode('utf8'))

if b_step in replace_chars:
next_index = index
b_replace_char, u_replace_char = <tuple>replace_chars[b_step]

if self.dct.Follow(b_replace_char, &next_index):
prefix = current_prefix + key[start_pos:word_pos] + u_replace_char
extra_items = self._similar_items(prefix, key, next_index, replace_chars)
res.extend(extra_items)
for (b_replace_char, u_replace_char) in replace_chars[b_step]:
next_index = index
is_followed = self.dct.Follow(b_replace_char, &next_index)
if is_followed:
prefix = current_prefix + key[start_pos:word_pos] + u_replace_char
extra_items = self._similar_items(prefix, key, next_index, replace_chars)
res.extend(extra_items)

if not self.dct.Follow(b_step, &index):
break
Expand All @@ -752,7 +755,7 @@ cdef class BytesDAWG(CompletionDAWG):

``replaces`` is an object obtained from
``DAWG.compile_replaces(mapping)`` where mapping is a dict
that maps single-char unicode sitrings to another single-char
that maps single-char unicode strings to (one or more) single-char
unicode strings.
"""
return self._similar_items("", key, self.dct.root(), replaces)
Expand All @@ -772,12 +775,12 @@ cdef class BytesDAWG(CompletionDAWG):
b_step = <bytes>(key[word_pos].encode('utf8'))

if b_step in replace_chars:
next_index = index
b_replace_char, u_replace_char = <tuple>replace_chars[b_step]

if self.dct.Follow(b_replace_char, &next_index):
extra_items = self._similar_item_values(word_pos+1, key, next_index, replace_chars)
res.extend(extra_items)
for (b_replace_char, u_replace_char) in replace_chars[b_step]:
next_index = index
is_followed = self.dct.Follow(b_replace_char, &next_index)
if is_followed:
extra_items = self._similar_item_values(word_pos+1, key, next_index, replace_chars)
res.extend(extra_items)

if not self.dct.Follow(b_step, &index):
break
Expand All @@ -797,7 +800,7 @@ cdef class BytesDAWG(CompletionDAWG):

``replaces`` is an object obtained from
``DAWG.compile_replaces(mapping)`` where mapping is a dict
that maps single-char unicode sitrings to another single-char
that maps single-char unicode strings to (one or more) single-char
unicode strings.
"""
return self._similar_item_values(0, key, self.dct.root(), replaces)
Expand Down
67 changes: 67 additions & 0 deletions tests/test_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,73 @@ class TestPrediction(object):
]


@pytest.mark.parametrize(("word", "prediction"), SUITE)
def test_dawg_prediction(self, word, prediction):
d = dawg.DAWG(self.DATA)
assert d.similar_keys(word, self.REPLACES) == prediction

@pytest.mark.parametrize(("word", "prediction"), SUITE)
def test_record_dawg_prediction(self, word, prediction):
d = dawg.RecordDAWG(str("=H"), self.LENGTH_DATA)
assert d.similar_keys(word, self.REPLACES) == prediction

@pytest.mark.parametrize(("word", "prediction"), SUITE_ITEMS)
def test_record_dawg_items(self, word, prediction):
d = dawg.RecordDAWG(str("=H"), self.LENGTH_DATA)
assert d.similar_items(word, self.REPLACES) == prediction

@pytest.mark.parametrize(("word", "prediction"), SUITE_VALUES)
def test_record_dawg_items_values(self, word, prediction):
d = dawg.RecordDAWG(str("=H"), self.LENGTH_DATA)
assert d.similar_item_values(word, self.REPLACES) == prediction

class TestMultiValuedPrediction(object):
DATA = "хлѣб ёлка ель лѣс лѣсное всё всѣ бѣлёная изобрѣтён лев лёв лѣв вѣнскій".split(" ")
LENGTH_DATA = list(zip(DATA, ((len(w),) for w in DATA)))

REPLACES = dawg.DAWG.compile_replaces({'е': ['ё', 'ѣ'], 'и': 'і'})

SUITE = [
('осел', []),
('ель', ['ель']),
('ёль', []),
('хлеб', ['хлѣб']),
('елка', ['ёлка']),
('лесное', ['лѣсное']),
('лесноё', []),
('лёсное', []),
('изобретен', ['изобрѣтён']),
('беленая', ['бѣлёная']),
('белёная', ['бѣлёная']),
('бѣленая', ['бѣлёная']),
('бѣлёная', ['бѣлёная']),
('белѣная', []),
('бѣлѣная', []),
('все', ['всё', 'всѣ']),
('лев', ['лев', 'лёв', 'лѣв']),
('венский', ['вѣнскій']),
]

SUITE_ITEMS = [
(
it[0], # key
[
(w, [(len(w),)]) # item, value pair
for w in it[1]
]
)
for it in SUITE
]

SUITE_VALUES = [
(
it[0], # key
[[(len(w),)] for w in it[1]]
)
for it in SUITE
]


@pytest.mark.parametrize(("word", "prediction"), SUITE)
def test_dawg_prediction(self, word, prediction):
d = dawg.DAWG(self.DATA)
Expand Down