7 changes: 7 additions & 0 deletions benchmarks/README.md
@@ -23,6 +23,13 @@ python benchmarks/benchmark.py \
     --device mps \
     --split train
 
+# Benchmark hybrid model with ONNX fallback
+python benchmarks/benchmark.py \
+    --model hybrid \
+    --data-dir data/test \
+    --model-path models/artifacts/layoutlmv3_invoice_ner.onnx \
+    --run-name "layoutlmv3-lora-heuristics-ONNX"
+
 # Benchmark Finetuned LayoutLMv3 only
 python benchmarks/benchmark.py \
     --model layoutlmv3 \
30 changes: 25 additions & 5 deletions benchmarks/models/hybrid_model.py
@@ -61,10 +61,21 @@ def load(self) -> None:
         Note: Heuristics don't require loading, they're pure pattern matching.
         """
         if self.fallback_model is None:
-            # Default to LayoutLMv3
-            from benchmarks.models.layoutlmv3_model import LayoutLMv3Model
+            # Check for ONNX model path
+            model_path = ""
+            if self.model_config:
+                model_path = self.model_config.get("model_path", "")
 
-            self.fallback_model = LayoutLMv3Model(self.model_config)
+            if str(model_path).endswith(".onnx"):
+                logger.info("Detected ONNX model path, using OnnxModel fallback")
+                from benchmarks.models.onnx_model import OnnxModel
+
+                self.fallback_model = OnnxModel(self.model_config)
+            else:
+                # Default to LayoutLMv3
+                from benchmarks.models.layoutlmv3_model import LayoutLMv3Model
+
+                self.fallback_model = LayoutLMv3Model(self.model_config)
 
         logger.info("Loading hybrid model...")
         logger.info("✓ Heuristics ready (no loading required)")
@@ -142,6 +153,16 @@ def predict(
 
             result.metadata["fallback_used"] = True
             result.metadata["extraction_stage"] = "model_fallback"
+
+            # Get model name from config
+            fb_config = self.fallback_model.get_config()
+            model_path = (
+                fb_config.get("model_path")
+                or fb_config.get("checkpoint_path")
+                or "unknown"
+            )
+            result.metadata["model_name"] = os.path.basename(str(model_path))
+
             result.method = "model_fallback"
 
         return result
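The new "model_name" metadata entry is just the basename of whichever path the fallback config exposes; a quick illustration, using the ONNX path from the README example above:

    import os

    os.path.basename("models/artifacts/layoutlmv3_invoice_ner.onnx")
    # -> 'layoutlmv3_invoice_ner.onnx'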
@@ -167,8 +188,7 @@ def get_config(self) -> Dict[str, Any]:
         fallback_config = self.fallback_model.get_config()
 
         return {
-            "model_name": "HybridModel (Heuristics + ML)",
-            "model_version": "v1.0",
+            "model_name": "HybridModel (Heuristics + LM)",
             "architecture": "Heuristics → Model Fallback",
             "heuristic_patterns": 14,  # From heuristics.py
             "fallback_model": fallback_config,
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "invoice"
-version = "0.2.2"
+version = "0.2.3"
 description = "Finetuned LayoutLMv3 model for invoice number extraction"
 authors = [{ name = "Ryan Z. Nie", email = "ryanznie@gatech.edu" }]
 readme = "README.md"
@@ -67,4 +67,4 @@ required-environments = ["sys_platform == 'darwin'"]
 [tool.coverage.run]
 relative_files = true
 branch = true
-source = ["."]
\ No newline at end of file
+source = ["."]