From cebdbe97e0d87ba27d385111de5352da9efc1914 Mon Sep 17 00:00:00 2001 From: Lasse Borgholt Date: Tue, 21 Oct 2025 10:25:42 +0200 Subject: [PATCH 1/6] Added word-level pass + minor optimization --- evaluation/evaluate_dataset.py | 6 +- poetry.lock | 71 +- pyproject.toml | 1 + src/error_align/backtrace_graph.py | 52 +- .../baselines/optimal_word_alignment.py | 1 + .../baselines/power/power/levenshtein.py | 6 +- src/error_align/baselines/power_alignment.py | 35 +- src/error_align/edit_distance.py | 2 +- src/error_align/error_align.py | 654 ++++++++++++------ src/error_align/func.py | 11 +- src/error_align/utils.py | 23 +- tests/test_default.py | 11 +- 12 files changed, 603 insertions(+), 270 deletions(-) diff --git a/evaluation/evaluate_dataset.py b/evaluation/evaluate_dataset.py index 637139a..9230712 100644 --- a/evaluation/evaluate_dataset.py +++ b/evaluation/evaluate_dataset.py @@ -110,7 +110,7 @@ def get_error_alignments(ref: str, hyp: str, beam_size: int): List[Alignment]: A list of alignment objects. """ - return ErrorAlign(ref=ref, hyp=hyp).align(beam_size=beam_size, pbar=False) + return ErrorAlign(ref=ref, hyp=hyp, word_level_pass=True).align(beam_size=beam_size) def get_optimal_word_alignments(ref: str, hyp: str): @@ -274,8 +274,6 @@ def main(transcript_file: str, only_error_align: bool, beam_size: int, save_resu PronouncerLex("/home/lb/repos/power-asr/lex/cmudict.rep.json").pronounce if language_code == "en" else None ) - # dataset = dataset.select(range(792, 794)) - c_n, p_n = 0, 0 for example in tqdm(dataset): ref, hyp = example["ref"], example["hyp"] @@ -323,7 +321,7 @@ def main(transcript_file: str, only_error_align: bool, beam_size: int, save_resu norm_edits = c_n if edits == "character_edits" else p_n score = norm_edits / abs_edits if abs_edits > 0 else 1.0 duration = method_metrics["duration"] - print(f"{method_name}: score = {score:.4f} | edits = {abs_edits}/{norm_edits} | time = {duration:.1f}s") + print(f"{method_name}: score = {score:.4f} | edits = {abs_edits}/{norm_edits} | time = {duration:.2f}s") method_metrics["score"] = score # Convert edit lists to numpy arrays for statistical tests. 
diff --git a/poetry.lock b/poetry.lock index fa53012..6861b01 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4318,6 +4318,75 @@ files = [ [package.dependencies] numpy = "*" +[[package]] +name = "line-profiler" +version = "5.0.0" +description = "Line-by-line profiler" +optional = false +python-versions = ">=3.8" +groups = ["dev"] +files = [ + {file = "line_profiler-5.0.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5cd1621ff77e1f3f423dcc2611ef6fba462e791ce01fb41c95dce6d519c48ec8"}, + {file = "line_profiler-5.0.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:17a44491d16309bc39fc6197b376a120ebc52adc3f50b0b6f9baf99af3124406"}, + {file = "line_profiler-5.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a36a9a5ea5e37b0969a451f922b4dbb109350981187317f708694b3b5ceac3a5"}, + {file = "line_profiler-5.0.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b67e6e292efaf85d9678fe29295b46efd72c0d363b38e6b424df39b6553c49b3"}, + {file = "line_profiler-5.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b9c92c28ee16bf3ba99966854407e4bc927473a925c1629489c8ebc01f8a640"}, + {file = "line_profiler-5.0.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:51609cc264df6315cd9b9fa76d822a7b73a4f278dcab90ba907e32dc939ab1c2"}, + {file = "line_profiler-5.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:67f9721281655dc2b6763728a63928e3b8a35dfd6160c628a3c599afd0814a71"}, + {file = "line_profiler-5.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:c2c27ac0c30d35ca1de5aeebe97e1d9c0d582e3d2c4146c572a648bec8efcfac"}, + {file = "line_profiler-5.0.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f32d536c056393b7ca703e459632edc327ff9e0fc320c7b0e0ed14b84d342b7f"}, + {file = "line_profiler-5.0.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:a7da04ffc5a0a1f6653f43b13ad2e7ebf66f1d757174b7e660dfa0cbe74c4fc6"}, + {file = "line_profiler-5.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d2746f6b13c19ca4847efd500402d53a5ebb2fe31644ce8af74fbeac5ea4c54c"}, + {file = "line_profiler-5.0.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b4290319a59730c04cbd03755472d10524130065a20a695dc10dd66ffd92172"}, + {file = "line_profiler-5.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7cd168a8af0032e8e3cb2fbb9ffc7694cdcecd47ec356ae863134df07becb3a2"}, + {file = "line_profiler-5.0.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:cbe7b095865d00dda0f53d7d4556c2b1b5d13f723173a85edb206a78779ee07a"}, + {file = "line_profiler-5.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:ff176045ea8a9e33900856db31b0b979357c337862ae4837140c98bd3161c3c7"}, + {file = "line_profiler-5.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:474e0962d02123f1190a804073b308a67ef5f9c3b8379184483d5016844a00df"}, + {file = "line_profiler-5.0.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:729b18c0ac66b3368ade61203459219c202609f76b34190cbb2508b8e13998c8"}, + {file = "line_profiler-5.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:438ed24278c428119473b61a473c8fe468ace7c97c94b005cb001137bc624547"}, + {file = "line_profiler-5.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:920b0076dca726caadbf29f0bfcce0cbcb4d9ff034cd9445a7308f9d556b4b3a"}, + {file = "line_profiler-5.0.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:53326eaad2d807487dcd45d2e385feaaed81aaf72b9ecd4f53c1a225d658006f"}, + {file = 
"line_profiler-5.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e3995a989cdea022f0ede5db19a6ab527f818c59ffcebf4e5f7a8be4eb8e880"}, + {file = "line_profiler-5.0.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:8bf57892a1d3a42273652506746ba9f620c505773ada804367c42e5b4146d6b6"}, + {file = "line_profiler-5.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:43672085f149f5fbf3f08bba072ad7014dd485282e8665827b26941ea97d2d76"}, + {file = "line_profiler-5.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:446bd4f04e4bd9e979d68fdd916103df89a9d419e25bfb92b31af13c33808ee0"}, + {file = "line_profiler-5.0.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:9873fabbae1587778a551176758a70a5f6c89d8d070a1aca7a689677d41a1348"}, + {file = "line_profiler-5.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2cd6cdb5a4d3b4ced607104dbed73ec820a69018decd1a90904854380536ed32"}, + {file = "line_profiler-5.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:34d6172a3bd14167b3ea2e629d71b08683b17b3bc6eb6a4936d74e3669f875b6"}, + {file = "line_profiler-5.0.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5edd859be322aa8252253e940ac1c60cca4c385760d90a402072f8f35e4b967"}, + {file = "line_profiler-5.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d4f97b223105eed6e525994f5653061bd981e04838ee5d14e01d17c26185094"}, + {file = "line_profiler-5.0.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4758007e491bee3be40ebcca460596e0e28e7f39b735264694a9cafec729dfa9"}, + {file = "line_profiler-5.0.0-cp313-cp313-win_amd64.whl", hash = "sha256:213b19c4b65942db5d477e603c18c76126e3811a39d8bab251d930d8ce82ffba"}, + {file = "line_profiler-5.0.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:84c91fdc813e41c7d07ff3d1630a8b9efd54646c144432178f8603424ab06f81"}, + {file = "line_profiler-5.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ebaf17814431f429d76166b7c0e57c6e84925f7b57e348f8edfd8e96968f0d73"}, + {file = "line_profiler-5.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:412efd162a9ad75d80410e58ba80368f587af854c6b373a152a4f858e15f6102"}, + {file = "line_profiler-5.0.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b3b05c9177201f02b18a70039e72bcf5a75288abb362e97e17a83f0db334e368"}, + {file = "line_profiler-5.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c4d3147aa07caa44e05f44db4e27ca4f5392187c0934f887bdb81d7dc1884c9"}, + {file = "line_profiler-5.0.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:6cec60f39d0e72548173bfcd419566221e2c0c6168ecca46678f427a0e21b732"}, + {file = "line_profiler-5.0.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:7d14141fe4376510cc192cd828f357bf276b8297fcda00ebac5adbc9235732f4"}, + {file = "line_profiler-5.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:64b4ce2506d1dac22f05f51692970ecb89741cb6a15bcb4c00212b2c39610ff1"}, + {file = "line_profiler-5.0.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7ba2142d35a3401d348cb743611bac52ba9db9cf026f8aa82c34d13effb98a71"}, + {file = "line_profiler-5.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:17724b2dff0edb3a4ac402bef6381060a4c424fbaa170e651306495f7c95bba8"}, + {file = "line_profiler-5.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d2315baca21a9be299b5a0a89f2ce4ed5cfd12ba039a82784a298dd106d3621d"}, + {file = "line_profiler-5.0.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:febbfc59502984e2cb0deb27cd163ed71847e36bbb82763f2bf3c9432cc440ab"}, + {file = "line_profiler-5.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:213dc34b1abdcafff944c13e62f2f1d254fc1cb30740ac0257e4567c8bea9a03"}, + {file = "line_profiler-5.0.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:011ac8167855513cac266d698b34b8ded9c673640d105a715c989fd5f27a298c"}, + {file = "line_profiler-5.0.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:4646907f588439845d7739d6a5f10ab08a2f8952d65f61145eeb705e8bb4797e"}, + {file = "line_profiler-5.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:2cb6dced51bf906ddf2a8d75eda3523cee4cfb0102f54610e8f849630341a281"}, + {file = "line_profiler-5.0.0.tar.gz", hash = "sha256:a80f0afb05ba0d275d9dddc5ff97eab637471167ff3e66dcc7d135755059398c"}, +] + +[package.extras] +all = ["Cython (>=3.0.3)", "IPython (>=8.12.2) ; python_version < \"3.9.0\" and python_version >= \"3.8.0\"", "IPython (>=8.14.0) ; python_version < \"4.0.0\" and python_version >= \"3.9.0\"", "cibuildwheel (>=2.11.2) ; python_version < \"4.0\" and python_version >= \"3.11\"", "cibuildwheel (>=2.11.2) ; python_version == \"3.10\"", "cibuildwheel (>=2.11.2) ; python_version == \"3.8\"", "cibuildwheel (>=2.11.2) ; python_version == \"3.9\"", "cmake (>=3.21.2)", "coverage[toml] (>=6.5.0) ; python_version < \"3.12\" and python_version >= \"3.10\"", "coverage[toml] (>=6.5.0) ; python_version == \"3.8\"", "coverage[toml] (>=6.5.0) ; python_version == \"3.9\"", "coverage[toml] (>=7.3.0) ; python_version < \"4.0\" and python_version >= \"3.12\"", "ninja (>=1.10.2)", "pytest (>=7.4.4) ; python_version < \"4.0\" and python_version >= \"3.13\"", "pytest (>=7.4.4) ; python_version == \"3.10\"", "pytest (>=7.4.4) ; python_version == \"3.11\"", "pytest (>=7.4.4) ; python_version == \"3.12\"", "pytest (>=7.4.4) ; python_version == \"3.8\"", "pytest (>=7.4.4) ; python_version == \"3.9\"", "pytest-cov (>=3.0.0)", "rich (>=12.3.0)", "scikit-build (>=0.11.1)", "setuptools (>=68.2.2) ; python_version < \"4.0\" and python_version >= \"3.8\"", "tomli ; python_version < \"3.11\"", "ubelt (>=1.3.4)", "xdoctest (>=1.1.3)"] +all-strict = ["Cython (==3.0.3)", "IPython (==8.12.2) ; python_version < \"3.9.0\" and python_version >= \"3.8.0\"", "IPython (==8.14.0) ; python_version < \"4.0.0\" and python_version >= \"3.9.0\"", "cibuildwheel (==2.11.2) ; python_version < \"4.0\" and python_version >= \"3.11\"", "cibuildwheel (==2.11.2) ; python_version == \"3.10\"", "cibuildwheel (==2.11.2) ; python_version == \"3.8\"", "cibuildwheel (==2.11.2) ; python_version == \"3.9\"", "cmake (==3.21.2)", "coverage[toml] (==6.5.0) ; python_version < \"3.12\" and python_version >= \"3.10\"", "coverage[toml] (==6.5.0) ; python_version == \"3.8\"", "coverage[toml] (==6.5.0) ; python_version == \"3.9\"", "coverage[toml] (==7.3.0) ; python_version < \"4.0\" and python_version >= \"3.12\"", "ninja (==1.10.2)", "pytest (==7.4.4) ; python_version < \"4.0\" and python_version >= \"3.13\"", "pytest (==7.4.4) ; python_version == \"3.10\"", "pytest (==7.4.4) ; python_version == \"3.11\"", "pytest (==7.4.4) ; python_version == \"3.12\"", "pytest (==7.4.4) ; python_version == \"3.8\"", "pytest (==7.4.4) ; python_version == \"3.9\"", "pytest-cov (==3.0.0)", "rich (==12.3.0)", "scikit-build (==0.11.1)", "setuptools (==68.2.2) ; python_version < \"4.0\" and python_version >= \"3.8\"", "tomli ; python_version < \"3.11\"", "ubelt (==1.3.4)", "xdoctest (==1.1.3)"] +ipython = ["IPython (>=8.12.2) ; python_version < \"3.9.0\" and 
python_version >= \"3.8.0\"", "IPython (>=8.14.0) ; python_version < \"4.0.0\" and python_version >= \"3.9.0\""] +ipython-strict = ["IPython (==8.12.2) ; python_version < \"3.9.0\" and python_version >= \"3.8.0\"", "IPython (==8.14.0) ; python_version < \"4.0.0\" and python_version >= \"3.9.0\""] +optional = ["IPython (>=8.12.2) ; python_version < \"3.9.0\" and python_version >= \"3.8.0\"", "IPython (>=8.14.0) ; python_version < \"4.0.0\" and python_version >= \"3.9.0\"", "rich (>=12.3.0)"] +optional-strict = ["IPython (==8.12.2) ; python_version < \"3.9.0\" and python_version >= \"3.8.0\"", "IPython (==8.14.0) ; python_version < \"4.0.0\" and python_version >= \"3.9.0\"", "rich (==12.3.0)"] +runtime-strict = ["tomli ; python_version < \"3.11\""] +tests = ["coverage[toml] (>=6.5.0) ; python_version < \"3.12\" and python_version >= \"3.10\"", "coverage[toml] (>=6.5.0) ; python_version == \"3.8\"", "coverage[toml] (>=6.5.0) ; python_version == \"3.9\"", "coverage[toml] (>=7.3.0) ; python_version < \"4.0\" and python_version >= \"3.12\"", "pytest (>=7.4.4) ; python_version < \"4.0\" and python_version >= \"3.13\"", "pytest (>=7.4.4) ; python_version == \"3.10\"", "pytest (>=7.4.4) ; python_version == \"3.11\"", "pytest (>=7.4.4) ; python_version == \"3.12\"", "pytest (>=7.4.4) ; python_version == \"3.8\"", "pytest (>=7.4.4) ; python_version == \"3.9\"", "pytest-cov (>=3.0.0)", "ubelt (>=1.3.4)", "xdoctest (>=1.1.3)"] +tests-strict = ["coverage[toml] (==6.5.0) ; python_version < \"3.12\" and python_version >= \"3.10\"", "coverage[toml] (==6.5.0) ; python_version == \"3.8\"", "coverage[toml] (==6.5.0) ; python_version == \"3.9\"", "coverage[toml] (==7.3.0) ; python_version < \"4.0\" and python_version >= \"3.12\"", "pytest (==7.4.4) ; python_version < \"4.0\" and python_version >= \"3.13\"", "pytest (==7.4.4) ; python_version == \"3.10\"", "pytest (==7.4.4) ; python_version == \"3.11\"", "pytest (==7.4.4) ; python_version == \"3.12\"", "pytest (==7.4.4) ; python_version == \"3.8\"", "pytest (==7.4.4) ; python_version == \"3.9\"", "pytest-cov (==3.0.0)", "ubelt (==1.3.4)", "xdoctest (==1.1.3)"] + [[package]] name = "llvmlite" version = "0.45.1" @@ -11741,4 +11810,4 @@ evaluation = ["backoff", "click", "datasets", "gitpython", "librosa", "matplotli [metadata] lock-version = "2.1" python-versions = ">=3.11,<3.14" -content-hash = "31f5fce1d9649bbc9fc00bf3605e696e575e8c2e8c5d2d6968c7fb598ef75a69" +content-hash = "79908b7edec986c8bf406b8f8331ae7fbd819a8454925eade284337aa1618623" diff --git a/pyproject.toml b/pyproject.toml index 2e1ea15..1518f84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,7 @@ pre-commit = "^4.2.0" build = "^1.3.0" twine = "^6.2.0" ipython = "^9.6.0" +line-profiler = "^5.0.0" [tool.poetry.group.test] optional = true diff --git a/src/error_align/backtrace_graph.py b/src/error_align/backtrace_graph.py index 4af9829..c5cd26c 100644 --- a/src/error_align/backtrace_graph.py +++ b/src/error_align/backtrace_graph.py @@ -1,7 +1,7 @@ import random from collections import Counter -from error_align.utils import OP_TYPE_COMBO_MAP, OpType +from error_align.utils import END_DELIMITER, OP_TYPE_COMBO_MAP, START_DELIMITER, OpType class Node: @@ -168,26 +168,64 @@ def get_path(self, sample=False): return path - def get_unambiguous_matches(self, ref): - """Get word spans that are unambiguously matched (i.e., only one path in backtrace graph). + def get_unambiguous_node_matches(self) -> list[tuple[int, int]]: + """Get nodes that can only be accounted for by a match. 
Returns: - list[Node]: A list of nodes representing the unambiguous path. + list[tuple[int, int]]: A list of index tuples representing the unambiguous node matches. """ - ref = "*" + ref # Index offset + match_indices = set() + match_per_token = { + "ref": Counter(), + "hyp": Counter(), + } + ref_op_types = {OpType.MATCH, OpType.SUBSTITUTE, OpType.DELETE} + hyp_op_types = {OpType.MATCH, OpType.SUBSTITUTE, OpType.INSERT} + + for (hyp_idx, ref_idx), node in self.nodes.items(): + # Identify all nodes at which a match occurs. + if len(node.parents) == 1 and OpType.MATCH in node.parents: + match_indices.add((hyp_idx, ref_idx)) + + # Count number of paths passing through each token. + if ref_op_types.intersection(node.parents): + match_per_token["ref"][ref_idx] += 1 + if hyp_op_types.intersection(node.parents): + match_per_token["hyp"][hyp_idx] += 1 + + # Collect only those matches that are unambiguous on both sides. + unambiguous_matches = [] + for hyp_idx, ref_idx in match_indices: + if match_per_token["ref"][ref_idx] == 1 and match_per_token["hyp"][hyp_idx] == 1: + unambiguous_matches.append((hyp_idx - 1, ref_idx - 1)) # Offset indices + + return sorted(unambiguous_matches, key=lambda n: n[1]) + + def get_unambiguous_token_span_matches(self, ref): + """Get word spans (i.e., <...>) that are unambiguously matched. + + That is, there is only one subpath that can account for the span using MATCH operations. + + Other subpaths that include INSERT, DELETE, SUBSTITUTE operations are not considered. + + Returns: + list[tuple[int, int]]: A list of index tuples representing the end node of unambiguous span matches. + + """ + ref = "_" + ref # NOTE: Implicit index offset for root node. mono_match_end_nodes = set() ref_idxs = Counter() hyp_idxs = Counter() for (hyp_idx, ref_idx), node in self.nodes.items(): - if OpType.MATCH in node.parents and ref[ref_idx] == "<": + if OpType.MATCH in node.parents and ref[ref_idx] == START_DELIMITER: _ref_idx, _hyp_idx = ref_idx + 1, hyp_idx + 1 while True: if (_hyp_idx, _ref_idx) not in self.nodes: break if OpType.MATCH not in self.nodes[(_hyp_idx, _ref_idx)].parents: break - if ref[_ref_idx] == ">": + if ref[_ref_idx] == END_DELIMITER: end_index = (_hyp_idx, _ref_idx) mono_match_end_nodes.add(end_index) ref_idxs[_ref_idx] += 1 diff --git a/src/error_align/baselines/optimal_word_alignment.py b/src/error_align/baselines/optimal_word_alignment.py index bcd8b68..1c8b05b 100644 --- a/src/error_align/baselines/optimal_word_alignment.py +++ b/src/error_align/baselines/optimal_word_alignment.py @@ -28,6 +28,7 @@ def _get_optimal_word_alignment_values(ref_token: str, hyp_token: str): return len(hyp_token), len(ref_token), diag_cost + class OptimalWordAlign: """Optimal word-level alignment based on global-to-local edits (GLE) metric.""" diff --git a/src/error_align/baselines/power/power/levenshtein.py b/src/error_align/baselines/power/power/levenshtein.py index a648ec4..b9d15ed 100644 --- a/src/error_align/baselines/power/power/levenshtein.py +++ b/src/error_align/baselines/power/power/levenshtein.py @@ -206,9 +206,9 @@ def alignment_capacity(self): def hyp_oriented_alignment(self, hyp_only=True): ''' - Returns all alignment tokens. + Returns all alignment tokens. If an S slot is an multiword alignment, duplicates AlignLabels.substitution by the capacity. - TODO: Move to subclass. + TODO: Move to subclass. 
''' alignment = [] ref_align_len, hyp_align_len = zip(*self.alignment_capacity()) @@ -411,7 +411,7 @@ def bestPathsGraph(self, minPos=None, maxPos=None): labels=(rlabel, hlabel, align)) chart.appendleft((prev_i, prev_j)) - if time() - start > 120: + if time() - start > 30: print("\nWarning: Long computation time\n") raise AssertionError("Computation took too long") return G diff --git a/src/error_align/baselines/power_alignment.py b/src/error_align/baselines/power_alignment.py index cbbc816..3994843 100644 --- a/src/error_align/baselines/power_alignment.py +++ b/src/error_align/baselines/power_alignment.py @@ -41,20 +41,35 @@ def align(self): alignments = [] for (_, ref_token), (_, hyp_token), (_, align_token) in zip(s1_args, s2_args, align_args): - - if align_token == "C": - op_type = OpType.MATCH - if align_token == "S": - op_type = OpType.SUBSTITUTE - if align_token == "I": - op_type = OpType.INSERT - if align_token == "D": + + # NOTE: The original Power alignments fail in a few edge cases, so we + # implement a simple fix, where the op_type is based on the tokens instead. + + # if align_token == "C": + # op_type = OpType.MATCH + # if align_token == "S": + # op_type = OpType.SUBSTITUTE + # if align_token == "I": + # op_type = OpType.INSERT + # if align_token == "D": + # op_type = OpType.DELETE + + if not ref_token and not hyp_token: + continue + elif ref_token and not hyp_token: op_type = OpType.DELETE + elif not ref_token and hyp_token: + op_type = OpType.INSERT + elif ref_token and hyp_token: + if ref_token.lower() == hyp_token.lower(): + op_type = OpType.MATCH + else: + op_type = OpType.SUBSTITUTE alignment = Alignment( op_type=op_type, - ref=ref_token, - hyp=hyp_token, + ref=None if ref_token == "" else ref_token, + hyp=None if hyp_token == "" else hyp_token, ) alignments.append(alignment) diff --git a/src/error_align/edit_distance.py b/src/error_align/edit_distance.py index 2254737..4e26ad3 100644 --- a/src/error_align/edit_distance.py +++ b/src/error_align/edit_distance.py @@ -97,7 +97,7 @@ def compute_distance_matrix( # Track possible operations (note that the order of operations matters). 
if backtrace: pos_ops = tuple() - if diag_val == new_val and diag_cost == 0: + if diag_val == new_val and diag_cost <= 0: pos_ops += (OpType.MATCH,) if ins_val == new_val: pos_ops += (OpType.INSERT,) diff --git a/src/error_align/error_align.py b/src/error_align/error_align.py index 56a3bd3..375a6fe 100644 --- a/src/error_align/error_align.py +++ b/src/error_align/error_align.py @@ -1,11 +1,11 @@ -from collections import defaultdict +from dataclasses import dataclass, field +from functools import lru_cache from typing import Union import regex as re -from tqdm import tqdm from error_align.backtrace_graph import BacktraceGraph -from error_align.edit_distance import compute_error_align_distance_matrix +from error_align.edit_distance import compute_error_align_distance_matrix, compute_levenshtein_distance_matrix from error_align.utils import ( END_DELIMITER, START_DELIMITER, @@ -15,10 +15,121 @@ basic_tokenizer, categorize_char, ensure_length_preservation, - get_manhattan_distance, ) +def _embed_tokens(text_tokens: list[str]) -> str: + """Embed tokens with delimiters.""" + return "".join([f"<{t}>" for t in text_tokens]) + + +@lru_cache(maxsize=None) +def _categorize_char_cached(c: str) -> int: + """Cached version of categorize_char for performance.""" + return categorize_char(c) + + +def _get_char_types(text: str) -> list[int]: + """Get character types (0-3) for each character in the text.""" + return [_categorize_char_cached(c) for c in text] + + +def _create_index_map(text_tokens: list[re.Match]) -> list[int]: + """Create an index map for the given tokens. + + The 'index_map' is used to map each aligned character back to its original position in the input text. + + NOTE: -1 is used for delimiter (<>) and indicates no match in the source sequence. + """ + index_map = [] + for match in text_tokens: + index_map.append(-1) # Start delimiter + index_map.extend(range(*match.span())) + index_map.append(-1) # End delimiter + return index_map + + +@dataclass +class SubgraphMetadata: + """Data class to hold information needed for beam search alignment. + + This data class encapsulates all necessary information about a subgraph + derived from the reference and hypothesis texts, including their tokenized + and normalized forms, as well as various derived attributes used during + the alignment process. + + It works as a reference for the `Path` class during beam search alignment. + + Attributes: + ref_raw (str): The full raw reference text. + hyp_raw (str): The full raw hypothesis text. + ref_token_matches (list[re.Match]): List of regex Match objects for reference tokens. + hyp_token_matches (list[re.Match]): List of regex Match objects for hypothesis tokens. + ref_norm (list[str]): List of normalized reference tokens. + hyp_norm (list[str]): List of normalized hypothesis tokens. + ref (str): The embedded reference text with delimiters. + hyp (str): The embedded hypothesis text with delimiters. + ref_max_idx (int): The maximum index in the reference text. + hyp_max_idx (int): The maximum index in the hypothesis text. + ref_char_types (list[int]): List of character types for the reference text. + hyp_char_types (list[int]): List of character types for the hypothesis text. + ref_index_map (list[int]): Index map for the reference text. + hyp_index_map (list[int]): Index map for the hypothesis text. + backtrace_graph (BacktraceGraph): The backtrace graph for the subgraph. + backtrace_node_set (set[tuple[int, int]]): Set of nodes in the backtrace graph. 
+ unambiguous_matches (set[tuple[int, int]]): Set of end node indices for unambiguous token span matches. + """ + + # Init arguments. + ref_raw: str + hyp_raw: str + ref_token_matches: list[re.Match] + hyp_token_matches: list[re.Match] + ref_norm: list[str] + hyp_norm: list[str] + + # NOTE: The *_raw variables corresponds to the full input, even if only a subgraph is being aligned. + # The *_token_matches are computed on the full input so their indices correspond to the full input as well, + # even if only a subset of the tokens is being aligned. + + # Derived attributes. + ref: str = field(init=False) + hyp: str = field(init=False) + ref_max_idx: int = field(init=False) + hyp_max_idx: int = field(init=False) + ref_char_types: list[int] = field(init=False) + hyp_char_types: list[int] = field(init=False) + ref_index_map: list[int] = field(init=False) + hyp_index_map: list[int] = field(init=False) + backtrace_graph: BacktraceGraph = field(init=False) + backtrace_node_set: set[tuple[int, int]] = field(init=False) + unambiguous_matches: set[tuple[int, int]] = field(init=False) + + def __repr__(self): + ref_preview = self.ref if len(self.ref) < 20 else self.ref[:17] + "..." + hyp_preview = self.hyp if len(self.hyp) < 20 else self.hyp[:17] + "..." + return f'SubgraphMetadata(ref="{ref_preview}", hyp="{hyp_preview}")' + + def __post_init__(self): + # Process reference and hypothesis texts and compute derived attributes. + self.ref = _embed_tokens(self.ref_norm) + self.hyp = _embed_tokens(self.hyp_norm) + self.ref_max_idx = len(self.ref) - 1 + self.hyp_max_idx = len(self.hyp) - 1 + self.ref_char_types = _get_char_types(self.ref) + self.hyp_char_types = _get_char_types(self.hyp) + self.ref_index_map = _create_index_map(self.ref_token_matches) + self.hyp_index_map = _create_index_map(self.hyp_token_matches) + + # First pass: Compute backtrace graph. + _, backtrace_matrix = compute_error_align_distance_matrix(self.ref, self.hyp, backtrace=True) + self.backtrace_graph = BacktraceGraph(backtrace_matrix) + # NOTE: Used for backtrace deviation penalty during beam search. + self.backtrace_node_set = self.backtrace_graph.get_node_set() + # NOTE: Used for beam pruning during beam search. + self.unambiguous_matches = self.backtrace_graph.get_unambiguous_token_span_matches(self.ref) + + class ErrorAlign: """Error alignment class that performs a two-pass alignment process.""" @@ -28,6 +139,7 @@ def __init__( hyp: str, tokenizer: callable = basic_tokenizer, normalizer: callable = basic_normalizer, + word_level_pass: bool = False, ): """Initialize the error alignment with reference and hypothesis texts. @@ -47,8 +159,9 @@ def __init__( if not isinstance(hyp, str): raise TypeError("Hypothesis sequence must be a string.") - self.ref = ref - self.hyp = hyp + self.ref_raw = ref + self.hyp_raw = hyp + self.word_level_pass = word_level_pass # Inclusive tokenization: Track the token position in the original text. self._ref_token_matches = tokenizer(ref) @@ -56,71 +169,54 @@ def __init__( # Length-preserving normalization: Ensure that the normalizer preserves token length. normalizer = ensure_length_preservation(normalizer) - self._ref = "".join([f"<{normalizer(r.group())}>" for r in self._ref_token_matches]) - self._hyp = "".join([f"<{normalizer(h.group())}>" for h in self._hyp_token_matches]) - - # Categorize characters. - self._ref_char_types = list(map(categorize_char, self._ref)) - self._hyp_char_types = list(map(categorize_char, self._hyp)) - - # Initialize graph attributes. 
- self._identical_inputs = self._ref == self._hyp - self._ref_max_idx = len(self._ref) - 1 - self._hyp_max_idx = len(self._hyp) - 1 - self.end_index = (self._hyp_max_idx, self._ref_max_idx) - - # Create index maps for reference and hypothesis sequences. - self._ref_index_map = self._create_index_map(self._ref_token_matches) - self._hyp_index_map = self._create_index_map(self._hyp_token_matches) - - # First pass: Extract backtrace graph. - if not self._identical_inputs: - _, backtrace_matrix = compute_error_align_distance_matrix(self._ref, self._hyp, backtrace=True) - self._backtrace_graph = BacktraceGraph(backtrace_matrix) - self._backtrace_node_set = self._backtrace_graph.get_node_set() - self._unambiguous_matches = self._backtrace_graph.get_unambiguous_matches(self._ref) - else: - self._backtrace_graph = None - self._backtrace_node_set = None - self._unambiguous_matches = None + self._ref_norm = [normalizer(r.group()) for r in self._ref_token_matches] + self._hyp_norm = [normalizer(h.group()) for h in self._hyp_token_matches] + self._identical_inputs = self._ref_norm == self._hyp_norm - def __repr__(self): - ref_preview = self.ref if len(self.ref) < 20 else self.ref[:17] + "..." - hyp_preview = self.hyp if len(self.hyp) < 20 else self.hyp[:17] + "..." - return f'ErrorAlign(ref="{ref_preview}", hyp="{hyp_preview}")' + if self._identical_inputs: + self._src = None + elif word_level_pass: + self._src = self._prepare_subspans_with_word_level_pass() + else: + self._src = SubgraphMetadata( + ref_raw=self.ref_raw, + hyp_raw=self.hyp_raw, + ref_token_matches=self._ref_token_matches, + hyp_token_matches=self._hyp_token_matches, + ref_norm=self._ref_norm, + hyp_norm=self._hyp_norm, + ) - def align( - self, - beam_size: int = 100, - pbar: bool = False, - return_path: bool = False, - ) -> Union[list[Alignment], "Path"]: + def align(self, beam_size: int = 100) -> Union[list[Alignment], "Path"]: """Perform beam search to align reference and hypothesis texts. Args: beam_size (int): The size of the beam for beam search. Defaults to 100. - pbar (bool): Whether to display a progress bar. Defaults to False. - return_path (bool): Whether to return the path object or just the alignments. Defaults to False. Returns: list[Alignment]: A list of Alignment objects. """ - # Skip beam search if inputs are identical. if self._identical_inputs: - return self._identical_input_alignments() + return self._align_identical_inputs() + elif self.word_level_pass: + return self._align_post_word_level(self._src, beam_size=beam_size) + else: + return self._beam_search_alignment(self._src, beam_size=beam_size) + + def _beam_search_alignment( + self, + src: SubgraphMetadata, + beam_size: int = 100, + ) -> Union[list[Alignment], "Path"]: + """Perform beam search to align reference and hypothesis texts for a given source.""" # Initialize the beam with a single path starting at the root node. - start_path = Path(self) + start_path = Path(src) beam = {start_path.pid: start_path} - prune_map = defaultdict(lambda: float("inf")) + prune_map = dict() ended = [] - # Setup progress bar, if enabled. - if pbar: - total_mdist = self._ref_max_idx + self._hyp_max_idx + 2 - progress_bar = tqdm(total=total_mdist, desc="Aligning transcripts") - # Expand candidate paths until all have reached the terminal node. while len(beam) > 0: new_beam = {} @@ -133,13 +229,15 @@ def align( # Transition to all child nodes. 
for new_path in path.expand(): - if new_path.pid in prune_map: - if new_path.cost > prune_map[new_path.pid]: + new_path_cost = new_path.cost + new_path_pid = new_path.pid + if new_path_pid in prune_map: + if new_path_cost > prune_map[new_path_pid]: continue - prune_map[new_path.pid] = new_path.cost + prune_map[new_path_pid] = new_path_cost - if new_path.pid not in new_beam or new_path.cost < new_beam[new_path.pid].cost: - new_beam[new_path.pid] = new_path + if new_path_pid not in new_beam or new_path_cost < new_beam[new_path_pid].cost: + new_beam[new_path_pid] = new_path # Update the beam with the newly expanded paths. new_beam = list(new_beam.values()) @@ -149,141 +247,219 @@ def align( # Keep only the best path if, it matches the segment. if len(beam) > 0 and beam[0]._at_unambiguous_match_node: beam = beam[:1] - prune_map = defaultdict(lambda: float("inf")) + prune_map = dict() beam = {p.pid: p for p in beam} # Convert to dict for diversity check. - # Update progress bar, if enabled. - try: - worst_path = next(reversed(beam.values())) - mdist = get_manhattan_distance(worst_path.index, self.end_index) - if pbar: - progress_bar.n = total_mdist - mdist - progress_bar.refresh() - except StopIteration: - if pbar: - progress_bar.n = total_mdist - progress_bar.refresh() - # Return the best path or its alignments. ended.sort(key=lambda p: p.cost) - if return_path: - return ended[0] if len(ended) > 0 else None - return ended[0].alignments if len(ended) > 0 else [] - - def _create_index_map(self, text_tokens: list[re.Match]) -> list[int]: - """Create an index map for the given tokens. + return ended[0].get_alignments() if len(ended) > 0 else [] - The 'index_map' is used to map each aligned character back to its original position in the input text. - - NOTE: -1 is used for delimiter (<>) and indicates no match in the source sequence. - """ - index_map = [] - for match in text_tokens: - index_map.extend([-1]) # Start delimiter - index_map.extend(list(range(*match.span()))) - index_map.extend([-1]) # End delimiter - return index_map - - def _identical_input_alignments(self) -> list[Alignment]: + def _align_identical_inputs(self) -> list[Alignment]: """Return alignments for identical reference and hypothesis pairs.""" assert self._identical_inputs, "Inputs are not identical." 
alignments = [] - for ref_match, hyp_match in zip(self._ref_token_matches, self._hyp_token_matches, strict=False): - ref_slice = slice(*ref_match.span()) - hyp_slice = slice(*hyp_match.span()) - ref_token = self.ref[ref_slice] - hyp_token = self.hyp[hyp_slice] - alignment = Alignment( - op_type=OpType.MATCH, - ref_slice=ref_slice, - hyp_slice=hyp_slice, - ref=ref_token, - hyp=hyp_token, - ) + for i in range(len(self._ref_token_matches)): + alignment = self._get_match_alignment_from_token_indices(i, i) alignments.append(alignment) return alignments + def _align_post_word_level( + self, + src: list[tuple[OpType, Union[SubgraphMetadata, range, tuple[int, int]]]], + beam_size: int, + ) -> list[Alignment]: + """Perform alignment after a word-level pass.""" + alignments = [] + for op_type, src_ in src: + if op_type == OpType.MATCH: + alignment = self._get_match_alignment_from_token_indices(*src_) + alignments.append(alignment) + elif op_type in (OpType.INSERT, OpType.DELETE): + alignment_ = [self._get_insert_or_delete_alignment_from_token_index(op_type, i) for i in src_] + alignments.extend(alignment_) + else: + alignments_ = self._beam_search_alignment(src=src_, beam_size=beam_size) + alignments.extend(alignments_) + + return alignments + + def _get_match_alignment_from_token_indices(self, hyp_idx: int, ref_idx: int) -> Alignment: + """Get a MATCH alignment for the given token indices.""" + ref_token_match = self._ref_token_matches[ref_idx] + hyp_token_match = self._hyp_token_matches[hyp_idx] + ref_slice = slice(*ref_token_match.span()) + hyp_slice = slice(*hyp_token_match.span()) + alignment = Alignment( + op_type=OpType.MATCH, + ref_slice=ref_slice, + hyp_slice=hyp_slice, + ref=self.ref_raw[ref_slice], + hyp=self.hyp_raw[hyp_slice], + ) + return alignment + + def _get_insert_or_delete_alignment_from_token_index( + self, + op_type: Union[OpType.INSERT, OpType.DELETE], + token_idx: int, + ) -> Alignment: + """Get an INSERT or DELETE alignment for the given token index.""" + if op_type == OpType.INSERT: + token_match = self._hyp_token_matches[token_idx] + slice_ = slice(*token_match.span()) + token = self.hyp_raw[slice_] + return Alignment( + op_type=op_type, + hyp_slice=slice_, + hyp=token, + ) + elif op_type == OpType.DELETE: + token_match = self._ref_token_matches[token_idx] + slice_ = slice(*token_match.span()) + token = self.ref_raw[slice_] + return Alignment( + op_type=op_type, + ref_slice=slice_, + ref=token, + ) + else: + raise ValueError(f"Invalid operation type for insert/delete alignment: {op_type}") + + def _prepare_subspans_with_word_level_pass( + self, + ) -> list[tuple[OpType, Union[SubgraphMetadata, range, tuple[int, int]]]]: + """Perform a word-level alignment pass to identify unambiguous matches.""" + + # Extract the word-level backtrace graph. + _, backtrace_matrix = compute_levenshtein_distance_matrix(self._ref_norm, self._hyp_norm, backtrace=True) + backtrace_graph = BacktraceGraph(backtrace_matrix) + match_indices = backtrace_graph.get_unambiguous_node_matches() + match_indices = match_indices + [(len(self._hyp_norm), len(self._ref_norm))] + + # Iterate over the unambiguous matches to extract subspans (i.e., the span of words between two matches). + hyp_start, ref_start = (0, 0) + subspans = [] + end_index = len(match_indices) - 1 + for i, (hyp_end, ref_end) in enumerate(match_indices): + ref_is_empty = ref_start == ref_end + hyp_is_empty = hyp_start == hyp_end + + # NOTE: Subspans where ref xor hyp is empty are guaranteed to be all INSERT or DELETE ops. 
+ if ref_is_empty and hyp_is_empty: + pass + elif not ref_is_empty and not hyp_is_empty: + src = self._get_subgraph_metadata(ref_start, ref_end, hyp_start, hyp_end) + subspans.append((OpType.SUBSTITUTE, src)) + elif ref_is_empty: + subspans.append((OpType.INSERT, range(hyp_start, hyp_end))) + elif hyp_is_empty: + subspans.append((OpType.DELETE, range(ref_start, ref_end))) + + if i < end_index: + subspans.append((OpType.MATCH, (hyp_end, ref_end))) + ref_start, hyp_start = (ref_end + 1, hyp_end + 1) + + return subspans + + def _get_subgraph_metadata(self, ref_start, ref_end, hyp_start, hyp_end) -> SubgraphMetadata: + """Extract subgraph metadata for the given reference and hypothesis slices.""" + return SubgraphMetadata( + ref_raw=self.ref_raw, + hyp_raw=self.hyp_raw, + ref_token_matches=self._ref_token_matches[ref_start:ref_end], + hyp_token_matches=self._hyp_token_matches[hyp_start:hyp_end], + ref_norm=self._ref_norm[ref_start:ref_end], + hyp_norm=self._hyp_norm[hyp_start:hyp_end], + ) + + def __repr__(self): + ref_preview = self.ref_raw if len(self.ref_raw) < 20 else self.ref_raw[:17] + "..." + hyp_preview = self.hyp_raw if len(self.hyp_raw) < 20 else self.hyp_raw[:17] + "..." + return f'ErrorAlign(ref="{ref_preview}", hyp="{hyp_preview}")' + class Path: """Class to represent a graph path.""" - def __init__(self, src: ErrorAlign): + __slots__ = ( + "src", + "ref_idx", + "hyp_idx", + "last_ref_idx", + "last_hyp_idx", + "_closed_cost", + "_open_cost", + "_at_unambiguous_match_node", + "_end_indices", + "_alignments", + "_alignments_index", + ) + + def __init__(self, src: SubgraphMetadata): """Initialize the Path class with a given path.""" self.src = src self.ref_idx = -1 self.hyp_idx = -1 + self.last_hyp_idx = -1 + self.last_ref_idx = -1 self._closed_cost = 0 self._open_cost = 0 self._at_unambiguous_match_node = False - self._last_end_index = (-1, -1) self._end_indices = tuple() self._alignments = None self._alignments_index = None - def __repr__(self): - return f"Path(({self.ref_idx}, {self.hyp_idx}), score={self.cost})" + @property + def pid(self): + """Get the ID of the path used for pruning.""" + return hash((self.hyp_idx, self.ref_idx, self.last_hyp_idx, self.last_ref_idx)) + + @property + def cost(self): + """Get the cost of the path.""" + return self._closed_cost + self._open_cost + self._substitution_penalty(self.hyp_idx, self.ref_idx) + + @property + def norm_cost(self): + """Get the normalized cost of the path.""" + cost = self.cost + if cost == 0: + return 0 + return cost / (self.ref_idx + self.hyp_idx + 3) # NOTE: +3 to avoid zero division. Root = (-1,-1). + + @property + def index(self): + """Get the current node index of the path.""" + return (self.hyp_idx, self.ref_idx) @property - def alignments(self) -> list[Alignment]: + def at_end(self): + """Check if the path has reached the terminal node.""" + return self.hyp_idx == self.src.hyp_max_idx and self.ref_idx == self.src.ref_max_idx + + def get_alignments(self) -> list[Alignment]: """Get the alignments of the path.""" + # Return cached alignments if available and the path has not changed. if self._alignments is not None and self._alignments_index == self.index: return self._alignments + # Compute alignments from the segment end indices. self._alignments_index = self.index alignments = [] start_hyp, start_ref = (0, 0) - for (end_hyp, end_ref), score in self._end_indices: + for end_hyp, end_ref, score in self._end_indices: end_hyp, end_ref = end_hyp + 1, end_ref + 1 - # Construct DELETE alignment. 
if start_hyp == end_hyp: - assert start_ref < end_ref - ref_slice = slice(start_ref, end_ref) - ref_slice = self._translate_slice(ref_slice, self.src._ref_index_map) - assert ref_slice is not None - alignment = Alignment( - op_type=OpType.DELETE, - ref_slice=ref_slice, - ref=self.src.ref[ref_slice], - ) + alignment = self._get_delete_alignment(start_ref, end_ref) alignments.append(alignment) - - # Construct INSERT alignment. elif start_ref == end_ref: - assert start_hyp < end_hyp - hyp_slice = slice(start_hyp, end_hyp) - hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map) - assert hyp_slice is not None - alignment = Alignment( - op_type=OpType.INSERT, - hyp_slice=hyp_slice, - hyp=self.src.hyp[hyp_slice], - left_compound=self.src._hyp_index_map[start_hyp] >= 0, - right_compound=self.src._hyp_index_map[end_hyp - 1] >= 0, - ) + alignment = self._get_insert_alignment(start_hyp, end_hyp) alignments.append(alignment) - - # Construct SUBSTITUTE or MATCH alignment. else: - assert start_hyp < end_hyp and start_ref < end_ref - hyp_slice = slice(start_hyp, end_hyp) - ref_slice = slice(start_ref, end_ref) - hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map) - ref_slice = self._translate_slice(ref_slice, self.src._ref_index_map) - assert hyp_slice is not None and ref_slice is not None - is_match_segment = score == 0 - op_type = OpType.MATCH if is_match_segment else OpType.SUBSTITUTE - alignment = Alignment( - op_type=op_type, - ref_slice=ref_slice, - hyp_slice=hyp_slice, - ref=self.src.ref[ref_slice], - hyp=self.src.hyp[hyp_slice], - left_compound=self.src._hyp_index_map[start_hyp] >= 0, - right_compound=self.src._hyp_index_map[end_hyp - 1] >= 0, - ) + alignment = self._get_match_or_substitution_alignment(start_hyp, end_hyp, start_ref, end_ref, score) alignments.append(alignment) start_hyp, start_ref = end_hyp, end_ref @@ -293,33 +469,6 @@ def alignments(self) -> list[Alignment]: return alignments - @property - def pid(self): - """Get the ID of the path used for pruning.""" - return hash((self.index, self._last_end_index)) - - @property - def cost(self): - """Get the cost of the path.""" - return self._closed_cost + self._open_cost + self._substitution_penalty() - - @property - def norm_cost(self): - """Get the normalized cost of the path.""" - if self.cost == 0: - return 0 - return self.cost / (self.ref_idx + self.hyp_idx + 3) # NOTE: +3 to avoid zero division. Root = (-1,-1). - - @property - def index(self): - """Get the current node index of the path.""" - return (self.hyp_idx, self.ref_idx) - - @property - def at_end(self): - """Check if the path has reached the terminal node.""" - return self.index == self.src.end_index - def expand(self): """Expand the path by transitioning to child nodes. 
@@ -342,102 +491,152 @@ def expand(self): if sub_or_match_path is not None: yield sub_or_match_path - def _transition_and_shallow_copy(self, ref_step: int, hyp_step: int): - """Create a shallow copy of the path.""" - new_path = Path(self.src) + def _get_delete_alignment(self, start_ref_idx: int, end_ref_idx: int) -> Alignment: + """Get a DELETE alignment for a given reference slice.""" + ref_slice = slice(start_ref_idx, end_ref_idx) + ref_slice = self._translate_slice(ref_slice, self.src.ref_index_map) + return Alignment( + op_type=OpType.DELETE, + ref_slice=ref_slice, + ref=self.src.ref_raw[ref_slice], + ) + + def _get_insert_alignment(self, start_hyp_idx: int, end_hyp_idx: int) -> Alignment: + """Get an INSERT alignment for a given hypothesis slice.""" + hyp_slice = slice(start_hyp_idx, end_hyp_idx) + hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_index_map) + return Alignment( + op_type=OpType.INSERT, + hyp_slice=hyp_slice, + hyp=self.src.hyp_raw[hyp_slice], + ) + + def _get_match_or_substitution_alignment( + self, + start_hyp_idx: int, + end_hyp_idx: int, + start_ref_idx: int, + end_ref_idx: int, + score: int, + ) -> Alignment: + """Get a MATCH or SUBSTITUTE alignment for given hypothesis and reference slices.""" + hyp_slice = slice(start_hyp_idx, end_hyp_idx) + ref_slice = slice(start_ref_idx, end_ref_idx) + hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_index_map) + ref_slice = self._translate_slice(ref_slice, self.src.ref_index_map) + is_match_segment = score == 0 + op_type = OpType.MATCH if is_match_segment else OpType.SUBSTITUTE + return Alignment( + op_type=op_type, + ref_slice=ref_slice, + hyp_slice=hyp_slice, + ref=self.src.ref_raw[ref_slice], + hyp=self.src.hyp_raw[hyp_slice], + left_compound=self.src.hyp_index_map[start_hyp_idx] >= 0, + right_compound=self.src.hyp_index_map[end_hyp_idx - 1] >= 0, + ) + + def _transition_to_child_node(self, ref_step: int, hyp_step: int): + """Transition to a child node by creating a new Path instance.""" + new_path = Path.__new__(Path) # NOTE: Bypass __init__ for shallow copy. 
+ new_path.src = self.src new_path.ref_idx = self.ref_idx + ref_step new_path.hyp_idx = self.hyp_idx + hyp_step + new_path.last_hyp_idx = self.last_hyp_idx + new_path.last_ref_idx = self.last_ref_idx new_path._closed_cost = self._closed_cost new_path._open_cost = self._open_cost new_path._at_unambiguous_match_node = False - new_path._last_end_index = self._last_end_index new_path._end_indices = self._end_indices + new_path._alignments = None + new_path._alignments_index = None return new_path - def _reset_segment_variables(self, index: tuple[int, int]) -> None: + def _reset_segment_variables(self, hyp_idx: int, ref_idx: int) -> None: """Apply updates when segment end is detected.""" self._closed_cost += self._open_cost - self._closed_cost += self._substitution_penalty(index) - self._last_end_index = index + self._closed_cost += self._substitution_penalty(hyp_idx, ref_idx) + self.last_hyp_idx = hyp_idx + self.last_ref_idx = ref_idx self._open_cost = 0 - def _end_insertion_segment(self, index: tuple[int, int]) -> None: + def _end_insertion_segment(self, hyp_idx: int, ref_idx: int) -> None: """End the current segment, if criteria for an insertion are met.""" - hyp_slice = slice(self._last_end_index[0] + 1, index[0] + 1) - hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map) - ref_is_empty = index[1] == self._last_end_index[1] + hyp_slice = slice(self.last_hyp_idx + 1, hyp_idx + 1) + hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_index_map) + ref_is_empty = ref_idx == self.last_ref_idx if hyp_slice is not None and ref_is_empty: - self._end_indices += ((index, self._open_cost),) - self._reset_segment_variables(index) + self._end_indices += ((hyp_idx, ref_idx, self._open_cost),) + self._reset_segment_variables(hyp_idx, ref_idx) def _end_segment(self) -> Union[None, "Path"]: """End the current segment, if criteria for an insertion, a substitution, or a match are met.""" - hyp_slice = slice(self._last_end_index[0] + 1, self.index[0] + 1) - hyp_slice = self._translate_slice(hyp_slice, self.src._hyp_index_map) - ref_slice = slice(self._last_end_index[1] + 1, self.index[1] + 1) - ref_slice = self._translate_slice(ref_slice, self.src._ref_index_map) + hyp_slice = slice(self.last_hyp_idx + 1, self.hyp_idx + 1) + hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_index_map) + ref_slice = slice(self.last_ref_idx + 1, self.ref_idx + 1) + ref_slice = self._translate_slice(ref_slice, self.src.ref_index_map) assert ref_slice is not None - hyp_is_empty = self.index[0] == self._last_end_index[0] + hyp_is_empty = self.hyp_idx == self.last_hyp_idx if hyp_is_empty: - self._end_indices += ((self.index, self._open_cost),) + self._end_indices += ((self.hyp_idx, self.ref_idx, self._open_cost),) else: # TODO: Handle edge case where hyp has only covered delimiters. if hyp_slice is None: return None is_match_segment = self._open_cost == 0 - self._at_unambiguous_match_node = is_match_segment and self.index in self.src._unambiguous_matches - self._end_indices += ((self.index, self._open_cost),) + self._at_unambiguous_match_node = is_match_segment and self.index in self.src.unambiguous_matches + self._end_indices += ((self.hyp_idx, self.ref_idx, self._open_cost),) # Update the path score and reset segments attributes. 
- self._reset_segment_variables(self.index) + self._reset_segment_variables(self.hyp_idx, self.ref_idx) return self def _in_backtrace_node_set(self, index) -> bool: """Check if the given operation is an optimal transition at the current index.""" - return index in self.src._backtrace_node_set + return index in self.src.backtrace_node_set def _add_delete(self) -> Union[None, "Path"]: """Expand the path by adding a delete operation.""" # Ensure we are not at the end of the hypothesis sequence. - if self.hyp_idx >= self.src._hyp_max_idx: + if self.hyp_idx >= self.src.hyp_max_idx: return None # Transition and update costs. - new_path = self._transition_and_shallow_copy(ref_step=0, hyp_step=1) + new_path = self._transition_to_child_node(ref_step=0, hyp_step=1) is_backtrace = self._in_backtrace_node_set(self.index) - is_delimiter = self.src._hyp_char_types[new_path.hyp_idx] == 0 # NOTE: 0 indicates delimiter. + is_delimiter = self.src.hyp_char_types[new_path.hyp_idx] == 0 # NOTE: 0 indicates delimiter. new_path._open_cost += 1 if is_delimiter else 2 new_path._open_cost += 0 if is_backtrace or is_delimiter else 1 # Check for end-of-segment criteria. - if self.src._hyp[new_path.hyp_idx] == END_DELIMITER: - new_path._end_insertion_segment(new_path.index) + if self.src.hyp[new_path.hyp_idx] == END_DELIMITER: + new_path._end_insertion_segment(new_path.hyp_idx, new_path.ref_idx) return new_path def _add_insert(self) -> Union[None, "Path"]: """Expand the path by adding an insert operation.""" # Ensure we are not at the end of the reference sequence. - if self.ref_idx >= self.src._ref_max_idx: + if self.ref_idx >= self.src.ref_max_idx: return None # Transition and check for end-of-segment criteria. - new_path = self._transition_and_shallow_copy(ref_step=1, hyp_step=0) - if self.src._ref[new_path.ref_idx] == START_DELIMITER: - new_path._end_insertion_segment(self.index) + new_path = self._transition_to_child_node(ref_step=1, hyp_step=0) + if self.src.ref[new_path.ref_idx] == START_DELIMITER: + new_path._end_insertion_segment(self.hyp_idx, self.ref_idx) # Update costs. is_backtrace = self._in_backtrace_node_set(self.index) - is_delimiter = self.src._ref_char_types[new_path.ref_idx] == 0 # NOTE: 0 indicates delimiter. + is_delimiter = self.src.ref_char_types[new_path.ref_idx] == 0 # NOTE: 0 indicates delimiter. new_path._open_cost += 1 if is_delimiter else 2 new_path._open_cost += 0 if is_backtrace or is_delimiter else 1 # Check for end-of-segment criteria. - if self.src._ref[new_path.ref_idx] == END_DELIMITER: + if self.src.ref[new_path.ref_idx] == END_DELIMITER: new_path = new_path._end_segment() return new_path @@ -445,33 +644,33 @@ def _add_insert(self) -> Union[None, "Path"]: def _add_substitution_or_match(self) -> Union[None, "Path"]: """Expand the given path by adding a substitution or match operation.""" # Ensure we are not at the end of either sequence. - if self.ref_idx >= self.src._ref_max_idx or self.hyp_idx >= self.src._hyp_max_idx: + if self.ref_idx >= self.src.ref_max_idx or self.hyp_idx >= self.src.hyp_max_idx: return None # Transition and ensure that the transition is allowed. 
- new_path = self._transition_and_shallow_copy(ref_step=1, hyp_step=1) - is_match = self.src._ref[new_path.ref_idx] == self.src._hyp[new_path.hyp_idx] + new_path = self._transition_to_child_node(ref_step=1, hyp_step=1) + is_match = self.src.ref[new_path.ref_idx] == self.src.hyp[new_path.hyp_idx] if not is_match: - ref_is_delimiter = self.src._ref_char_types[new_path.ref_idx] == 0 # NOTE: 0 indicates delimiter - hyp_is_delimiter = self.src._hyp_char_types[new_path.hyp_idx] == 0 # NOTE: 0 indicates delimiter + ref_is_delimiter = self.src.ref_char_types[new_path.ref_idx] == 0 # NOTE: 0 indicates delimiter + hyp_is_delimiter = self.src.hyp_char_types[new_path.hyp_idx] == 0 # NOTE: 0 indicates delimiter if ref_is_delimiter or hyp_is_delimiter: return None # Check for end-of-segment criteria. - if self.src._ref[new_path.ref_idx] == START_DELIMITER: - new_path._end_insertion_segment(self.index) + if self.src.ref[new_path.ref_idx] == START_DELIMITER: + new_path._end_insertion_segment(self.hyp_idx, self.ref_idx) # Update costs, if not a match. if not is_match: is_backtrace = self._in_backtrace_node_set(self.index) is_letter_type_match = ( - self.src._ref_char_types[new_path.ref_idx] == self.src._hyp_char_types[new_path.hyp_idx] + self.src.ref_char_types[new_path.ref_idx] == self.src.hyp_char_types[new_path.hyp_idx] ) new_path._open_cost += 2 if is_letter_type_match else 3 new_path._open_cost += 0 if is_backtrace else 1 # Check for end-of-segment criteria. - if self.src._ref[new_path.ref_idx] == END_DELIMITER: + if self.src.ref[new_path.ref_idx] == END_DELIMITER: new_path = new_path._end_segment() return new_path @@ -485,11 +684,12 @@ def _translate_slice(self, segment_slice: slice, index_map: list[int]) -> None | start, end = int(slice_indices[0]), int(slice_indices[-1] + 1) return slice(start, end) - def _substitution_penalty(self, index: tuple[int, int] | None = None) -> int: + def _substitution_penalty(self, hyp_idx: int, ref_idx: int) -> int: """Get the substitution penalty given an index.""" - index = index or self.index - ref_is_not_empty = index[1] > self._last_end_index[1] - hyp_is_not_empty = index[0] > self._last_end_index[0] - if ref_is_not_empty and hyp_is_not_empty: - return self._open_cost - return 0 + # NOTE: Since *_idx is guaranteed to be equal to or higher than last_*_idx, we only need to check for equality. + if ref_idx == self.last_ref_idx or hyp_idx == self.last_hyp_idx: + return 0 + return self._open_cost + + def __repr__(self): + return f"Path(({self.ref_idx}, {self.hyp_idx}), score={self.cost})" diff --git a/src/error_align/func.py b/src/error_align/func.py index 6ac8dfc..2279350 100644 --- a/src/error_align/func.py +++ b/src/error_align/func.py @@ -8,8 +8,7 @@ def error_align( tokenizer: callable = basic_tokenizer, normalizer: callable = basic_normalizer, beam_size: int = 100, - pbar: bool = False, - return_path: bool = False, + word_level_pass: bool = True, ) -> list[Alignment] | Path: """Perform error alignment between two sequences. @@ -18,8 +17,9 @@ def error_align( hyp (str): The hypothesis sequence/transcript. tokenizer (callable): A function to tokenize the sequences. Must be regex-based and return Match objects. normalizer (callable): A function to normalize the tokens. Defaults to basic_normalizer. - pbar (bool): Whether to display a progress bar. Defaults to False. - return_path (bool): Whether to return the path object or just the alignments. Defaults to False. + beam_size (int): The beam size for beam search alignment. Defaults to 100. 
+ word_level_pass (bool): Use an initial word-level pass to identify unambiguous matches. Defaults to True. + Note that this is not described in the original paper. Returns: list[tuple[str, str, OpType]]: A list of tuples containing aligned reference token, @@ -31,8 +31,7 @@ def error_align( hyp, tokenizer=tokenizer, normalizer=normalizer, + word_level_pass=word_level_pass, ).align( beam_size=beam_size, - pbar=pbar, - return_path=return_path, ) diff --git a/src/error_align/utils.py b/src/error_align/utils.py index f8f4567..c5f776a 100644 --- a/src/error_align/utils.py +++ b/src/error_align/utils.py @@ -25,6 +25,22 @@ class Alignment: left_compound: bool = False right_compound: bool = False + def __post_init__(self): + if self.op_type == OpType.MATCH: + if self.ref is None or self.hyp is None: + raise ValueError("MATCH operation must have non-empty ref or hyp.") + if self.left_compound or self.right_compound: + raise ValueError("MATCH operation cannot have compound markers.") + elif self.op_type == OpType.INSERT: + if self.hyp is None or self.ref is not None: + raise ValueError("INSERT operation must have non-empty hyp and empty ref.") + elif self.op_type == OpType.DELETE: + if self.hyp is not None or self.ref is None: + raise ValueError("DELETE operation must have non-empty ref and empty hyp.") + elif self.op_type == OpType.SUBSTITUTE: + if self.ref is None or self.hyp is None: + raise ValueError("SUBSTITUTE operation must have both ref and hyp.") + @property def hyp_with_compound_markers(self) -> str: """Return the hypothesis with compound markers if applicable.""" @@ -36,7 +52,7 @@ def __repr__(self) -> str: if self.op_type == OpType.DELETE: return f'Alignment({self.op_type.name}: "{self.ref}")' if self.op_type == OpType.INSERT: - return f'Alignment({self.op_type.name}: {self.hyp_with_compound_markers})' + return f"Alignment({self.op_type.name}: {self.hyp_with_compound_markers})" if self.op_type == OpType.SUBSTITUTE: return f'Alignment({self.op_type.name}: {self.hyp_with_compound_markers} -> "{self.ref}")' return f'Alignment({self.op_type.name}: "{self.hyp}" == "{self.ref}")' @@ -113,11 +129,6 @@ def categorize_char(c: str) -> int: return 3 # NOTE: Unvoiced characters (only apostrophes are expected by default). -def get_manhattan_distance(a: tuple[int, int], b: tuple[int, int]) -> int: - """Calculate the Manhattan distance between two points a and b.""" - return abs(a[0] - b[0]) + abs(a[1] - b[1]) - - def basic_tokenizer(text: str) -> list: """Default tokenizer that splits text into words based on whitespace. diff --git a/tests/test_default.py b/tests/test_default.py index 7e1b79e..8ecba22 100644 --- a/tests/test_default.py +++ b/tests/test_default.py @@ -2,6 +2,7 @@ from error_align import ErrorAlign, error_align from error_align.edit_distance import compute_levenshtein_distance_matrix +from error_align.error_align import Path from error_align.utils import OpType, categorize_char @@ -11,7 +12,7 @@ def test_error_align() -> None: ref = "This is a substitution test deleted." hyp = "Inserted this is a contribution test." 
- alignments = error_align(ref, hyp, pbar=True) + alignments = error_align(ref, hyp) expected_ops = [ OpType.INSERT, # Inserted OpType.MATCH, # This @@ -73,8 +74,8 @@ def test_representations() -> None: assert repr(ea) == 'ErrorAlign(ref="test", hyp="pest")' # Test Path class representation - path = ea.align(beam_size=10, return_path=True) - assert repr(path) == f"Path(({path.ref_idx}, {path.hyp_idx}), score={path.cost})" + path = Path(src=ea._src) + assert repr(path) == "Path((-1, -1), score=0)" @suppress_type_checks @@ -99,9 +100,9 @@ def test_backtrace_graph() -> None: hyp = "This is a pest." # Create ErrorAlign instance and generate backtrace graph. - ea = ErrorAlign(ref, hyp) + ea = ErrorAlign(ref, hyp, word_level_pass=False) ea.align(beam_size=10) - graph = ea._backtrace_graph + graph = ea._src.backtrace_graph # Check basic properties of the graph. assert isinstance(graph.get_path(), list) From 6e5efbd05c8c8c7e8c6aa5bd91f3cb8d49a123ac Mon Sep 17 00:00:00 2001 From: Lasse Borgholt Date: Tue, 21 Oct 2025 10:32:56 +0200 Subject: [PATCH 2/6] Type hinting and consistency --- src/error_align/backtrace_graph.py | 17 +++++++++++------ src/error_align/error_align.py | 30 +++++++++++++++--------------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/src/error_align/backtrace_graph.py b/src/error_align/backtrace_graph.py index c5cd26c..632fe3a 100644 --- a/src/error_align/backtrace_graph.py +++ b/src/error_align/backtrace_graph.py @@ -23,11 +23,11 @@ def __init__(self, hyp_idx, ref_idx) -> None: self._outgoing_edge_counts = {} @property - def index(self): + def index(self) -> tuple[int, int]: return (self.hyp_idx, self.ref_idx) @property - def offset_index(self): + def offset_index(self) -> tuple[int, int]: """Get the offset index of the node so indices match the hypothesis and reference strings. Root will be at (-1, -1). @@ -42,35 +42,40 @@ def offset_index(self): def number_of_paths(self): return self._bwd_node_count * self._fwd_node_count - def number_of_ingoing_paths_via(self, op_type: OpType): + def number_of_ingoing_paths_via(self, op_type: OpType) -> int: """Get the number of paths going through this node via the given operation type. Args: op_type (OpType): The operation type to check. + Returns: + int: The number of ingoing paths via the given operation type. """ if op_type not in self.parents: return 0 return self._ingoing_edge_counts[op_type] * self.parents[op_type]._outgoing_edge_counts[op_type] - def number_of_outgoing_paths_via(self, op_type: OpType): + def number_of_outgoing_paths_via(self, op_type: OpType) -> int: """Get the number of paths going through this node via the given operation type. Args: op_type (OpType): The operation type to check. + Returns: + int: The number of outgoing paths via the given operation type. + """ if op_type not in self.children: return 0 return self._outgoing_edge_counts[op_type] * self.children[op_type]._ingoing_edge_counts[op_type] @property - def is_terminal(self): + def is_terminal(self) -> bool: """Check if the node is a terminal node (i.e., it has no children).""" return len(self.children) == 0 @property - def is_root(self): + def is_root(self) -> bool: """Check if the node is a root node (i.e., it has no parents).""" return len(self.parents) == 0 diff --git a/src/error_align/error_align.py b/src/error_align/error_align.py index 375a6fe..70ee3f7 100644 --- a/src/error_align/error_align.py +++ b/src/error_align/error_align.py @@ -73,8 +73,8 @@ class SubgraphMetadata: hyp_max_idx (int): The maximum index in the hypothesis text. 
ref_char_types (list[int]): List of character types for the reference text. hyp_char_types (list[int]): List of character types for the hypothesis text. - ref_index_map (list[int]): Index map for the reference text. - hyp_index_map (list[int]): Index map for the hypothesis text. + ref_idx_map (list[int]): Index map for the reference text. + hyp_idx_map (list[int]): Index map for the hypothesis text. backtrace_graph (BacktraceGraph): The backtrace graph for the subgraph. backtrace_node_set (set[tuple[int, int]]): Set of nodes in the backtrace graph. unambiguous_matches (set[tuple[int, int]]): Set of end node indices for unambiguous token span matches. @@ -99,8 +99,8 @@ class SubgraphMetadata: hyp_max_idx: int = field(init=False) ref_char_types: list[int] = field(init=False) hyp_char_types: list[int] = field(init=False) - ref_index_map: list[int] = field(init=False) - hyp_index_map: list[int] = field(init=False) + ref_idx_map: list[int] = field(init=False) + hyp_idx_map: list[int] = field(init=False) backtrace_graph: BacktraceGraph = field(init=False) backtrace_node_set: set[tuple[int, int]] = field(init=False) unambiguous_matches: set[tuple[int, int]] = field(init=False) @@ -118,8 +118,8 @@ def __post_init__(self): self.hyp_max_idx = len(self.hyp) - 1 self.ref_char_types = _get_char_types(self.ref) self.hyp_char_types = _get_char_types(self.hyp) - self.ref_index_map = _create_index_map(self.ref_token_matches) - self.hyp_index_map = _create_index_map(self.hyp_token_matches) + self.ref_idx_map = _create_index_map(self.ref_token_matches) + self.hyp_idx_map = _create_index_map(self.hyp_token_matches) # First pass: Compute backtrace graph. _, backtrace_matrix = compute_error_align_distance_matrix(self.ref, self.hyp, backtrace=True) @@ -494,7 +494,7 @@ def expand(self): def _get_delete_alignment(self, start_ref_idx: int, end_ref_idx: int) -> Alignment: """Get a DELETE alignment for a given reference slice.""" ref_slice = slice(start_ref_idx, end_ref_idx) - ref_slice = self._translate_slice(ref_slice, self.src.ref_index_map) + ref_slice = self._translate_slice(ref_slice, self.src.ref_idx_map) return Alignment( op_type=OpType.DELETE, ref_slice=ref_slice, @@ -504,7 +504,7 @@ def _get_delete_alignment(self, start_ref_idx: int, end_ref_idx: int) -> Alignme def _get_insert_alignment(self, start_hyp_idx: int, end_hyp_idx: int) -> Alignment: """Get an INSERT alignment for a given hypothesis slice.""" hyp_slice = slice(start_hyp_idx, end_hyp_idx) - hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_index_map) + hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_idx_map) return Alignment( op_type=OpType.INSERT, hyp_slice=hyp_slice, @@ -522,8 +522,8 @@ def _get_match_or_substitution_alignment( """Get a MATCH or SUBSTITUTE alignment for given hypothesis and reference slices.""" hyp_slice = slice(start_hyp_idx, end_hyp_idx) ref_slice = slice(start_ref_idx, end_ref_idx) - hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_index_map) - ref_slice = self._translate_slice(ref_slice, self.src.ref_index_map) + hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_idx_map) + ref_slice = self._translate_slice(ref_slice, self.src.ref_idx_map) is_match_segment = score == 0 op_type = OpType.MATCH if is_match_segment else OpType.SUBSTITUTE return Alignment( @@ -532,8 +532,8 @@ def _get_match_or_substitution_alignment( hyp_slice=hyp_slice, ref=self.src.ref_raw[ref_slice], hyp=self.src.hyp_raw[hyp_slice], - left_compound=self.src.hyp_index_map[start_hyp_idx] >= 0, - 
right_compound=self.src.hyp_index_map[end_hyp_idx - 1] >= 0, + left_compound=self.src.hyp_idx_map[start_hyp_idx] >= 0, + right_compound=self.src.hyp_idx_map[end_hyp_idx - 1] >= 0, ) def _transition_to_child_node(self, ref_step: int, hyp_step: int): @@ -564,7 +564,7 @@ def _reset_segment_variables(self, hyp_idx: int, ref_idx: int) -> None: def _end_insertion_segment(self, hyp_idx: int, ref_idx: int) -> None: """End the current segment, if criteria for an insertion are met.""" hyp_slice = slice(self.last_hyp_idx + 1, hyp_idx + 1) - hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_index_map) + hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_idx_map) ref_is_empty = ref_idx == self.last_ref_idx if hyp_slice is not None and ref_is_empty: self._end_indices += ((hyp_idx, ref_idx, self._open_cost),) @@ -573,9 +573,9 @@ def _end_insertion_segment(self, hyp_idx: int, ref_idx: int) -> None: def _end_segment(self) -> Union[None, "Path"]: """End the current segment, if criteria for an insertion, a substitution, or a match are met.""" hyp_slice = slice(self.last_hyp_idx + 1, self.hyp_idx + 1) - hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_index_map) + hyp_slice = self._translate_slice(hyp_slice, self.src.hyp_idx_map) ref_slice = slice(self.last_ref_idx + 1, self.ref_idx + 1) - ref_slice = self._translate_slice(ref_slice, self.src.ref_index_map) + ref_slice = self._translate_slice(ref_slice, self.src.ref_idx_map) assert ref_slice is not None From a4cd834be1c75b5aab55115302f0b16ef2877c8f Mon Sep 17 00:00:00 2001 From: Lasse Borgholt Date: Tue, 21 Oct 2025 11:47:27 +0200 Subject: [PATCH 3/6] Type hinting alias for word-level pass --- src/error_align/error_align.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/src/error_align/error_align.py b/src/error_align/error_align.py index 70ee3f7..b037c2c 100644 --- a/src/error_align/error_align.py +++ b/src/error_align/error_align.py @@ -17,6 +17,8 @@ ensure_length_preservation, ) +SubspanDescriptor = Union["SubgraphMetadata", range, tuple[int, int]] + def _embed_tokens(text_tokens: list[str]) -> str: """Embed tokens with delimiters.""" @@ -264,11 +266,7 @@ def _align_identical_inputs(self) -> list[Alignment]: alignments.append(alignment) return alignments - def _align_post_word_level( - self, - src: list[tuple[OpType, Union[SubgraphMetadata, range, tuple[int, int]]]], - beam_size: int, - ) -> list[Alignment]: + def _align_post_word_level(self, src: list[tuple[OpType, SubspanDescriptor]], beam_size: int) -> list[Alignment]: """Perform alignment after a word-level pass.""" alignments = [] for op_type, src_ in src: @@ -326,9 +324,7 @@ def _get_insert_or_delete_alignment_from_token_index( else: raise ValueError(f"Invalid operation type for insert/delete alignment: {op_type}") - def _prepare_subspans_with_word_level_pass( - self, - ) -> list[tuple[OpType, Union[SubgraphMetadata, range, tuple[int, int]]]]: + def _prepare_subspans_with_word_level_pass(self) -> list[tuple[OpType, SubspanDescriptor]]: """Perform a word-level alignment pass to identify unambiguous matches.""" # Extract the word-level backtrace graph. 
From ecd56b93bd1fef9a680f29d978f24385bb3852ac Mon Sep 17 00:00:00 2001 From: Lasse Borgholt Date: Tue, 21 Oct 2025 11:53:59 +0200 Subject: [PATCH 4/6] Fix method ordering --- src/error_align/error_align.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/error_align/error_align.py b/src/error_align/error_align.py index b037c2c..45ed4e0 100644 --- a/src/error_align/error_align.py +++ b/src/error_align/error_align.py @@ -107,11 +107,6 @@ class SubgraphMetadata: backtrace_node_set: set[tuple[int, int]] = field(init=False) unambiguous_matches: set[tuple[int, int]] = field(init=False) - def __repr__(self): - ref_preview = self.ref if len(self.ref) < 20 else self.ref[:17] + "..." - hyp_preview = self.hyp if len(self.hyp) < 20 else self.hyp[:17] + "..." - return f'SubgraphMetadata(ref="{ref_preview}", hyp="{hyp_preview}")' - def __post_init__(self): # Process reference and hypothesis texts and compute derived attributes. self.ref = _embed_tokens(self.ref_norm) @@ -131,6 +126,11 @@ def __post_init__(self): # NOTE: Used for beam pruning during beam search. self.unambiguous_matches = self.backtrace_graph.get_unambiguous_token_span_matches(self.ref) + def __repr__(self): + ref_preview = self.ref if len(self.ref) < 20 else self.ref[:17] + "..." + hyp_preview = self.hyp if len(self.hyp) < 20 else self.hyp[:17] + "..." + return f'SubgraphMetadata(ref="{ref_preview}", hyp="{hyp_preview}")' + class ErrorAlign: """Error alignment class that performs a two-pass alignment process.""" @@ -141,7 +141,7 @@ def __init__( hyp: str, tokenizer: callable = basic_tokenizer, normalizer: callable = basic_normalizer, - word_level_pass: bool = False, + word_level_pass: bool = True, ): """Initialize the error alignment with reference and hypothesis texts. From 783957a645e48795ce30b7021717a055d4d9bc61 Mon Sep 17 00:00:00 2001 From: Lasse Borgholt Date: Tue, 21 Oct 2025 12:56:05 +0200 Subject: [PATCH 5/6] Bump beta --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1518f84..a6a6981 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "error-align" -version = "0.1.0b3" +version = "0.1.0b4" description = "Text-to-text alignment algorithm for speech recognition error analysis." authors = ["Lasse Borgholt "] license = "MIT" From ff4c499bdef19e2946e27a02a32fe019dcde2deb Mon Sep 17 00:00:00 2001 From: Lasse Borgholt Date: Tue, 21 Oct 2025 13:06:28 +0200 Subject: [PATCH 6/6] Fixed test --- tests/test_default.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_default.py b/tests/test_default.py index 8ecba22..a197b6f 100644 --- a/tests/test_default.py +++ b/tests/test_default.py @@ -70,7 +70,7 @@ def test_representations() -> None: assert repr(match_alignment) == 'Alignment(MATCH: "test" == "test")' # Test ErrorAlign class representation - ea = ErrorAlign(ref="test", hyp="pest") + ea = ErrorAlign(ref="test", hyp="pest", word_level_pass=False) assert repr(ea) == 'ErrorAlign(ref="test", hyp="pest")' # Test Path class representation
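Usage sketch for the word-level pass introduced in this series: a minimal example assuming only the public calls already exercised in tests/test_default.py and evaluation/evaluate_dataset.py (error_align, ErrorAlign.align, OpType) and reusing the reference/hypothesis strings from test_error_align.

    from error_align import ErrorAlign, error_align
    from error_align.utils import OpType

    ref = "This is a substitution test deleted."
    hyp = "Inserted this is a contribution test."

    # Default path: an initial word-level pass identifies unambiguous token matches
    # before the character-level beam search (word_level_pass defaults to True).
    alignments = error_align(ref, hyp)
    for a in alignments:
        print(a.op_type.name, a.ref, a.hyp)

    # Character-level-only alignment, i.e. without the word-level pass
    # (the behaviour described in the original paper, per the func.py docstring).
    ea = ErrorAlign(ref=ref, hyp=hyp, word_level_pass=False)
    alignments_char_only = ea.align(beam_size=100)

    # Both calls return a list of Alignment objects; the word-level pass mainly
    # narrows the search space, which should matter most on longer transcripts.
    assert all(isinstance(a.op_type, OpType) for a in alignments_char_only)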