8 changes: 8 additions & 0 deletions .env.example
@@ -13,6 +13,14 @@ LOG_LEVEL=INFO
# Device to run inference on: cpu, cuda, mps
DEVICE=cpu

# Backend Configuration (onnx or triton)
INFERENCE_BACKEND=onnx

# Triton Configuration
TRITON_URL=localhost:8000

TRITON_MODEL_NAME=layoutlmv3-lora-invoice-number

# Model path (relative to project root)
MODEL_PATH=models/layoutlmv3-lora-invoice-number

32 changes: 31 additions & 1 deletion README.md
@@ -204,11 +204,41 @@ The easiest way to configure the application:
DEVICE=mps
```

3. Start the application (automatically loads `.env`):
```bash
docker-compose up -d
```

### Inference Backend Configuration

The application supports two inference backends: local ONNX Runtime (the default) and a remote Triton Inference Server, selected via `INFERENCE_BACKEND`.

**1. Local ONNX (Default)**

No extra configuration needed.

**2. Triton Inference Server**

First, create the model repository structure:
```bash
python scripts/setup_triton_repo.py --model_path models/layoutlmv3-lora-invoice-number
```

Then start the server:
```bash
docker run --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 \
-v $(pwd)/triton_model_repo:/models \
nvcr.io/nvidia/tritonserver:23.10-py3 \
tritonserver --model-repository=/models
```

Finally, configure `.env` to select the Triton backend and restart the app (`python app.py`):
```bash
INFERENCE_BACKEND=triton
TRITON_URL=localhost:8000
TRITON_MODEL_NAME=layoutlmv3-lora-invoice-number
```
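
Before switching backends, it is worth confirming that Triton is up and has loaded the model. These are Triton's standard HTTP health and metadata endpoints (KServe v2 protocol); the model name must match `TRITON_MODEL_NAME`:

```bash
# Server readiness (expects HTTP 200)
curl -s -o /dev/null -w "%{http_code}\n" http://localhost:8000/v2/health/ready

# Model readiness and metadata
curl -s http://localhost:8000/v2/models/layoutlmv3-lora-invoice-number/ready
curl -s http://localhost:8000/v2/models/layoutlmv3-lora-invoice-number | python -m json.tool
```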



### Available Environment Variables

Key variables (see `.env.example` for all options):
2 changes: 0 additions & 2 deletions app.py
@@ -9,8 +9,6 @@
import logging
from dotenv import load_dotenv
import gradio as gr

# Import from src modules
from src import app, create_gradio_interface, load_model, DEVICE, MODEL_PATH

load_dotenv()
4 changes: 2 additions & 2 deletions models/artifacts/layoutlmv3_invoice_ner.onnx.dvc
@@ -1,5 +1,5 @@
outs:
- md5: 27488294cecffb7e82bb4dbdba745406
size: 501667170
- md5: 1bad589d51dab7bd2efd42126881583d
size: 501607527
hash: md5
path: layoutlmv3_invoice_ner.onnx
4 changes: 2 additions & 2 deletions models/artifacts/layoutlmv3_invoice_ner_optimized.onnx.dvc
@@ -1,5 +1,5 @@
outs:
- md5: 38c9fa3d001a55a43280b1f9ce69f697
size: 501491951
- md5: 709e8042dd19087194a7c76a1b4a3f2b
size: 501492871
hash: md5
path: layoutlmv3_invoice_ner_optimized.onnx
2 changes: 1 addition & 1 deletion models/artifacts/model_metadata.json
@@ -3,7 +3,7 @@
"base_model": "microsoft/layoutlmv3-base",
"num_labels": 3,
"max_length": 512,
"onnx_opset_version": 14,
"onnx_opset_version": 17,
"input_names": [
"pixel_values",
"input_ids",
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "invoice"
version = "0.2.3"
version = "0.3.0"
description = "Finetuned LayoutLMv3 model for invoice number extraction"
authors = [{ name = "Ryan Z. Nie", email = "ryanznie@gatech.edu" }]
readme = "README.md"
@@ -35,6 +35,7 @@ dependencies = [
"onnxruntime>=1.16.0",
"onnxruntime-tools>=1.7.0",
"onnxconverter-common>=1.14.0",
"tritonclient[http]>=2.41.0",
# Testing dependencies
"pytest>=8.0.0",
"pytest-cov>=4.1.0",
5 changes: 3 additions & 2 deletions scripts/export_to_onnx.py
@@ -81,7 +81,7 @@ def export_to_onnx(
self,
model: torch.nn.Module,
processor: LayoutLMv3Processor,
opset_version: int = 14,
opset_version: int = 17,
) -> Path:
"""Export merged model to ONNX format"""
print("\n🚀 Exporting to ONNX...")
@@ -149,6 +149,7 @@ def forward(self, pixel_values, input_ids, attention_mask, bbox):
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=opset_version,
custom_opsets={"ai.onnx.ml": 3}, # CRITICAL for Triton
do_constant_folding=True,
export_params=True,
verbose=False,
@@ -405,7 +406,7 @@ def save_metadata(self, processor: LayoutLMv3Processor, onnx_path: Path) -> None
"base_model": self.base_model,
"num_labels": self.num_labels,
"max_length": 512,
"onnx_opset_version": 14,
"onnx_opset_version": 17,
"input_names": ["pixel_values", "input_ids", "attention_mask", "bbox"],
"output_names": ["logits"],
"label2id": {"O": 0, "B-INVOICE_ID": 1, "I-INVOICE_ID": 2},
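
With the opset bump and the pinned `ai.onnx.ml` domain, a quick sanity check of the exported artifact can catch a stale export before it is copied into the Triton repository. A minimal sketch using the `onnx` package; the artifact path below assumes the default location tracked by DVC in this repo:

```python
import onnx

# Assumed default artifact path (see models/artifacts/*.onnx.dvc)
onnx_path = "models/artifacts/layoutlmv3_invoice_ner.onnx"

model = onnx.load(onnx_path)

# opset_import has one entry per operator domain; "" is the default ai.onnx domain
for opset in model.opset_import:
    print(f"{opset.domain or 'ai.onnx'}: {opset.version}")
# Expected after this change: the default domain at 17 and ai.onnx.ml at 3

# Structural validation of the graph
onnx.checker.check_model(onnx_path)
```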
163 changes: 163 additions & 0 deletions scripts/setup_triton_repo.py
@@ -0,0 +1,163 @@
#!/usr/bin/env python3
"""
Setup Triton Inference Server Model Repository for LayoutLMv3

This script creates the necessary directory structure and configuration files
for serving the LayoutLMv3 ONNX model with Triton Inference Server.
"""

import argparse
import shutil
from pathlib import Path
import onnx
import sys


def create_model_repo(
model_path: str, repo_path: str, model_name: str, config_name: str = "config.pbtxt"
):
"""Create Triton model repository structure"""
print(f"🔧 Setting up Triton Model Repository for: {model_name}")

source_path = Path(model_path)
if not source_path.exists():
print(f"❌ Error: Model file not found at {source_path}")
sys.exit(1)

repo_dir = Path(repo_path)
model_dir = repo_dir / model_name
version_dir = model_dir / "1"

# Create directories
version_dir.mkdir(parents=True, exist_ok=True)
print(f" ✓ Created directory structure: {version_dir}")

# Copy model
dest_path = version_dir / "model.onnx"
shutil.copy2(source_path, dest_path)
print(f" ✓ Copied model to: {dest_path}")

# Generate config.pbtxt
generate_config(source_path, model_dir / config_name, model_name)


def generate_config(model_path: Path, config_path: Path, model_name: str):
"""Generate config.pbtxt based on ONNX model properties"""
print("📝 Generating Triton configuration...")

# Load ONNX model to inspect inputs/outputs
model = onnx.load(str(model_path))

# Basic configuration
config_lines = [
f'name: "{model_name}"',
'platform: "onnxruntime_onnx"',
"max_batch_size: 8", # Adjust as needed
"",
"dynamic_batching { }", # Enable dynamic batching
"",
]

# Inputs
for input_tensor in model.graph.input:
name = input_tensor.name
# Skip batch dim for Triton config if max_batch_size > 0
dims = [
d.dim_value if d.dim_value > 0 else -1
for d in input_tensor.type.tensor_type.shape.dim
]

# Handle dynamic batch dimension (usually the first one)
# In Triton with max_batch_size > 0, we exclude the batch dimension from the config shape
if len(dims) > 0:
dims = dims[1:]

data_type = mapping_onnx_type_to_triton(input_tensor.type.tensor_type.elem_type)

config_lines.append("input {")
config_lines.append(f' name: "{name}"')
config_lines.append(f" data_type: {data_type}")
config_lines.append(f" dims: {dims}")
config_lines.append("}")
config_lines.append("")

# Outputs
for output_tensor in model.graph.output:
name = output_tensor.name
dims = [
d.dim_value if d.dim_value > 0 else -1
for d in output_tensor.type.tensor_type.shape.dim
]

if len(dims) > 0:
dims = dims[1:]

data_type = mapping_onnx_type_to_triton(
output_tensor.type.tensor_type.elem_type
)

config_lines.append("output {")
config_lines.append(f' name: "{name}"')
config_lines.append(f" data_type: {data_type}")
config_lines.append(f" dims: {dims}")
config_lines.append("}")

# Write config
with open(config_path, "w") as f:
f.write("\n".join(config_lines))

print(f" ✓ Configuration saved to: {config_path}")


def mapping_onnx_type_to_triton(onnx_type):
"""Map ONNX data types to Triton data types"""
# https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#datatypes
type_map = {
1: "TYPE_FP32", # FLOAT
2: "TYPE_UINT8", # UINT8
3: "TYPE_INT8", # INT8
4: "TYPE_UINT16", # UINT16
5: "TYPE_INT16", # INT16
6: "TYPE_INT32", # INT32
7: "TYPE_INT64", # INT64
9: "TYPE_BOOL", # BOOL
10: "TYPE_FP16", # FLOAT16
11: "TYPE_FP64", # DOUBLE
}
return type_map.get(onnx_type, "TYPE_FP32") # Default to FP32


def main():
parser = argparse.ArgumentParser(description="Setup Triton Model Repository")
parser.add_argument("--model_path", required=True, help="Path to source ONNX model")
parser.add_argument(
"--repo_dir",
default="triton_model_repo",
help="Path to Triton model repository",
)
parser.add_argument(
"--model_name",
default="layoutlmv3-lora-invoice-number",
help="Name of the model in Triton",
)

args = parser.parse_args()

create_model_repo(args.model_path, args.repo_dir, args.model_name)

print("\n✅ Triton repository setup complete!")
print(f" Repository path: {Path(args.repo_dir).absolute()}")
print(" To start Triton:")
print(
f" docker run --rm -p 8000:8000 -p 8001:8001 -p 8002:8002 -v {Path(args.repo_dir).absolute()}:/models nvcr.io/nvidia/tritonserver:23.10-py3 tritonserver --model-repository=/models"
)


if __name__ == "__main__":
main()
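
Once the repository is set up and the server is running, the `tritonclient[http]` dependency added in `pyproject.toml` can be used for an end-to-end smoke test against the served model. This is a sketch with dummy tensors; the shapes and dtypes are assumptions based on the export metadata (`max_length` 512, 3 labels) and LayoutLMv3's 224x224 pixel input, so adjust them if the exported model differs:

```python
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

# Dummy inputs; shapes and dtypes are assumptions, not read from the deployed model
tensors = {
    "pixel_values": np.zeros((1, 3, 224, 224), dtype=np.float32),
    "input_ids": np.zeros((1, 512), dtype=np.int64),
    "attention_mask": np.ones((1, 512), dtype=np.int64),
    "bbox": np.zeros((1, 512, 4), dtype=np.int64),
}
triton_dtypes = {
    "pixel_values": "FP32",
    "input_ids": "INT64",
    "attention_mask": "INT64",
    "bbox": "INT64",
}

inputs = []
for name, arr in tensors.items():
    inp = httpclient.InferInput(name, list(arr.shape), triton_dtypes[name])
    inp.set_data_from_numpy(arr)
    inputs.append(inp)

result = client.infer("layoutlmv3-lora-invoice-number", inputs)
logits = result.as_numpy("logits")
print(logits.shape)  # expected (1, 512, 3): one score per token for the 3 NER labels
```

If the shapes are rejected, the model metadata endpoint (`/v2/models/<name>`) reports the exact input names, dtypes, and dims Triton expects.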
15 changes: 7 additions & 8 deletions src/api.py
@@ -38,8 +38,7 @@ async def lifespan(app: FastAPI):

app = FastAPI(
title="Invoice NER API",
description="LayoutLMv3 model for extracting invoice numbers",
version="1.0.0",
description="Finetuned LayoutLMv3 model for extracting invoice numbers",
lifespan=lifespan,
)

@@ -55,14 +54,14 @@ class PredictionRequest(BaseModel):
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy" if inference.model is not None else "unhealthy",
"model_loaded": inference.model is not None,
"status": "healthy" if inference.backend is not None else "unhealthy",
"model_loaded": inference.backend is not None,
"device": inference.DEVICE,
}


@app.post("/predict")
async def predict(
def predict(
image: UploadFile = File(..., description="Invoice image file (JPG, PNG, etc.)"),
ocr_file: UploadFile = File(..., description="OCR data file (TXT or JSON format)"),
):
@@ -78,12 +77,12 @@ async def predict(
Returns:
JSON with extracted invoice number, method used, and detailed predictions
"""
if inference.model is None or inference.processor is None:
if inference.backend is None or inference.processor is None:
raise HTTPException(status_code=503, detail="Model not loaded")

try:
# Read and validate image
image_bytes = await image.read()
image_bytes = image.file.read()
try:
pil_image = Image.open(io.BytesIO(image_bytes))
pil_image = pil_image.convert("RGB")
@@ -93,7 +92,7 @@ async def predict(
img_width, img_height = pil_image.size

# Read and parse OCR file
ocr_bytes = await ocr_file.read()
ocr_bytes = ocr_file.file.read()
ocr_filename = ocr_file.filename.lower()

try:
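
With `predict` now a plain `def` that reads uploads via `UploadFile.file`, the endpoint's external contract is unchanged: it still accepts multipart form fields named `image` and `ocr_file` (FastAPI runs sync endpoints in a worker threadpool, so the blocking reads do not stall the event loop). A minimal request sketch; the port and file names are placeholders for whatever the app is actually serving:

```bash
curl -X POST "http://localhost:7860/predict" \
  -F "image=@invoice.png" \
  -F "ocr_file=@invoice_ocr.json"
```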