Commit c4dfa3a

segment/recognize: spawn bg process for Kraken…
- during `setup`, instead of loading models in the processor directly, instantiate and spawn a singleton predictor subprocess with the given parameters (after resolving the model path name), communicating via shared (task and result) queues to synchronize processor and predictor processes; the predictor will then load models in its own address space
- at runtime, the processor merely calls the predictor with the respective arguments for that page, which translates into
  - putting the arguments on the task queue
  - getting the results from the result queue, blocking
- at runtime, the predictor loops over
  - receiving inputs from the task queue, blocking
  - calling `predict` on them
  - putting outputs on the result queue
- in the predictor, tasks and results are identified via page id, so results get retrieved for their respective task only, implemented via a shared dict to synchronize forked processor workers
- during `shutdown`, tell the predictor to shut down as well (terminating the subprocess); the predictor will then exit its loop and close the queues
- abstract from the differences between kraken.pageseg, kraken.blla, and kraken.rpred in the initialization and inference phases via a shared `common.KrakenPredictor` class (a simplified sketch of this pattern follows below), overriding specifics in
  - `recognize.KrakenRecognizePredictor`:
    - during `setup`, after loading the model, submit a special "task" to query the model's `one_channel_mode` attribute
    - at runtime, translate the model into a `defaultdict` for `mm_rpred`, but picklable to be compatible with mp.Queue; for the same reason, exhaust the result generator immediately
  - `segment.KrakenSegmentPredictor`: during `setup`, map the given parameters and inputs to kwargs as applicable for either `pageseg.segment` or `blla.segment`
1 parent 9f68868 commit c4dfa3a
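To make the new control flow easier to follow, here is a minimal sketch (not part of the commit; `MyPredictor` and the stand-in model are purely illustrative) of how a `KrakenPredictor` subclass and its processor-side caller split the work between the parent process and the spawned predictor subprocess:

from ocrd_kraken.common import KrakenPredictor

class MyPredictor(KrakenPredictor):
    def setup(self):
        # runs once inside the spawned subprocess: this is where a real subclass
        # would load its Kraken model into the subprocess's own address space
        self.model = lambda image: image.size  # trivial stand-in "model"
    def predict(self, *inputs):
        # runs per task inside the subprocess, fed from the task queue
        image, = inputs
        return self.model(image)

# processor side (main process or forked page workers), schematically:
#   predictor = MyPredictor(logger, dict(parameter))
#   predictor.start()                    # spawn the subprocess, which runs setup() there
#   result = predictor(page_id, image)   # put (page_id, args) on the task queue and block
#                                        # until the result for this page_id is available
#   predictor.shutdown()                 # set the terminate event; the loop exits and the
#                                        # queues get closed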

File tree

4 files changed (+183 / -55 lines)


ocrd_kraken/binarize.py

Lines changed: 2 additions & 1 deletion
@@ -2,10 +2,11 @@
 from os.path import join
 from typing import Optional

+import kraken.binarization
+
 from ocrd.processor.base import OcrdPageResult
 from ocrd.processor.ocrd_page_result import OcrdPageResultImage

-import kraken.binarization
 from ocrd import Processor
 from ocrd_utils import assert_file_grp_cardinality, getLogger, make_file_id, MIMETYPE_PAGE
 from ocrd_models.ocrd_page import AlternativeImageType, OcrdPage, to_xml

ocrd_kraken/common.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+import multiprocessing as mp
+
+from ocrd_utils import config, initLogging
+
+class KrakenPredictor(mp.context.SpawnProcess):
+    def __init__(self, logger, parameter):
+        self.logger = logger
+        self.parameter = parameter
+        ctxt = mp.get_context('spawn')
+        self.taskq = ctxt.Queue(maxsize=1 + config.OCRD_MAX_PARALLEL_PAGES)
+        self.resultq = ctxt.Queue(maxsize=1 + config.OCRD_MAX_PARALLEL_PAGES)
+        self.terminate = ctxt.Event()
+        ctxt = mp.get_context('fork') # base.Processor will fork workers
+        self.results = ctxt.Manager().dict()
+        super().__init__()
+        self.daemon = True
+    def __call__(self, page_id, *page_input):
+        self.taskq.put((page_id, page_input))
+        self.logger.debug("sent task for '%s'", page_id)
+        #return self.get(page_id)
+        result = self.get(page_id)
+        self.logger.debug("received result for '%s'", page_id)
+        return result
+    def get(self, page_id):
+        while not self.terminate.is_set():
+            if page_id in self.results:
+                result = self.results.pop(page_id)
+                if isinstance(result, Exception):
+                    raise Exception(f"predictor failed for {page_id}") from result
+                return result
+            try:
+                page_id, result = self.resultq.get(timeout=0.7)
+            except mp.queues.Empty:
+                continue
+            self.logger.debug("storing results for '%s'", page_id)
+            self.results[page_id] = result
+        raise Exception(f"predictor terminated while waiting on results for {page_id}")
+    def run(self):
+        initLogging()
+        try:
+            self.setup()
+        except Exception as e:
+            self.logger.exception("setup failed")
+            self.terminate.set()
+        while not self.terminate.is_set():
+            try:
+                page_id, page_input = self.taskq.get(timeout=1.1)
+            except mp.queues.Empty:
+                continue
+            self.logger.debug("predicting '%s'", page_id)
+            try:
+                page_output = self.predict(*page_input)
+            except Exception as e:
+                self.logger.error("prediction failed: %s", e.__class__.__name__)
+                page_output = e
+            self.resultq.put((page_id, page_output))
+            self.logger.debug("sent result for '%s'", page_id)
+        self.resultq.close()
+        self.resultq.cancel_join_thread()
+        self.logger.debug("predictor terminated")
+    def setup(self):
+        raise NotImplementedError()
+    def predict(self, *inputs):
+        raise NotImplementedError()
+    def shutdown(self):
+        # do not terminate from forked processor instances
+        if mp.parent_process() is None:
+            self.terminate.set()
+            self.taskq.close()
+            self.taskq.cancel_join_thread()
+            self.logger.debug(f"terminated {self} in {mp.current_process().name}")
+        else:
+            self.logger.debug(f"not touching {self} in {mp.current_process().name}")
+    def __del__(self):
+        self.logger.debug(f"deinit of {self} in {mp.current_process().name}")
+        self.shutdown()

ocrd_kraken/recognize.py

Lines changed: 45 additions & 19 deletions
@@ -2,6 +2,7 @@
 from ocrd.processor.base import OcrdPageResult
 import regex
 import itertools
+from collections import defaultdict
 import numpy as np
 from scipy.sparse.csgraph import minimum_spanning_tree
 from shapely.geometry import Polygon, LineString, box as Rectangle
@@ -37,6 +38,38 @@
     TextLineOrderSimpleType
 )

+from .common import KrakenPredictor
+
+class KrakenRecognizePredictor(KrakenPredictor):
+    # workaround for Kraken's unpicklable defaultdict choice
+    class DefaultDict(defaultdict):
+        def __init__(self, default=None):
+            self.default = default
+            super().__init__()
+        def default_factory(self):
+            return self.default
+    def setup(self):
+        import torch
+        from kraken.lib.models import load_any
+        model = self.parameter['model']
+        self.logger.info("loading model '%s'", model)
+        device = self.parameter['device']
+        if device != 'cpu' and not torch.cuda.is_available():
+            device = 'cpu'
+        if device == 'cpu':
+            self.logger.warning("no CUDA device available. Running without GPU will be slow")
+        self.model = load_any(model, device=device)
+    def predict(self, *inputs):
+        from kraken.rpred import mm_rpred
+        if not len(inputs):
+            return self.model.nn.input[1] == 1 and self.model.one_channel_mode == '1'
+        image, segmentation = inputs
+        nets = __class__.DefaultDict(self.model)
+        result = mm_rpred(nets, image, segmentation,
+                          self.parameter['pad'],
+                          self.parameter['bidi_reordering'])
+        # we must exhaust the generator before enqueuing
+        return list(result)

 class KrakenRecognize(Processor):

@@ -48,23 +81,17 @@ def setup(self):
         """
         Load model, set predict function
         """
+        parameter = dict(self.parameter)
+        parameter['model'] = self.resolve_resource(parameter['model'])
+        self.predictor = KrakenRecognizePredictor(self.logger, parameter)
+        self.predictor.start()
+        self.binary = self.predictor("") # blocks until model is loaded
+        self.logger.info("loaded %s model %s", "binary" if self.binary else "grayscale", self.parameter["model"])

-        import torch
-        from kraken.rpred import rpred
-        from kraken.lib.models import load_any
-        model_fname = self.resolve_resource(self.parameter['model'])
-        self.logger.info("loading model '%s'", model_fname)
-        device = self.parameter['device']
-        if device != 'cpu' and not torch.cuda.is_available():
-            device = 'cpu'
-        if device == 'cpu':
-            self.logger.warning("no CUDA device available. Running without GPU will be slow")
-        self.model = load_any(model_fname, device=device)
-        def predict(page_image, segmentation):
-            return rpred(self.model, page_image, segmentation,
-                         self.parameter['pad'],
-                         self.parameter['bidi_reordering'])
-        self.predict = predict
+    def shutdown(self):
+        if getattr(self, 'predictor', None):
+            self.predictor.shutdown()
+            del self.predictor

     def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional[str] = None) -> OcrdPageResult:
         """Recognize text on lines with Kraken.
@@ -96,8 +123,7 @@ def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional
         page_image, page_coords, _ = self.workspace.image_from_page(
             page, page_id,
             feature_selector="binarized"
-            if self.model.nn.input[1] == 1 and self.model.one_channel_mode == '1'
-            else '')
+            if self.binary else '')
         page_rect = Rectangle(0, 0, page_image.width - 1, page_image.height - 1)
         # TODO: find out whether kraken.lib.xml.XMLPage(...).to_container() is adequate

@@ -152,7 +178,7 @@ def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional
             text_direction='horizontal-lr',
             type=segtype,
             imagename=page_id)
-        for idx_line, ocr_record in enumerate(self.predict(page_image, segmentation)):
+        for idx_line, ocr_record in enumerate(self.predictor(page_id, page_image, segmentation)):
             line = all_lines[idx_line]
             id_line = line.id
             if not ocr_record.prediction and not ocr_record.cuts:

ocrd_kraken/segment.py

Lines changed: 60 additions & 35 deletions
@@ -1,6 +1,10 @@
 from typing import Optional
 from PIL import ImageOps

+import shapely.geometry as geom
+from shapely.prepared import prep as geom_prep
+import torch
+
 from ocrd import Processor
 from ocrd.processor.ocrd_page_result import OcrdPageResult
 from ocrd_utils import (
@@ -22,44 +26,66 @@
     BaselineType,
 )

-import shapely.geometry as geom
-from shapely.prepared import prep as geom_prep
-import torch
-
-class KrakenSegment(Processor):
-
-    @property
-    def executable(self):
-        return 'ocrd-kraken-segment'
+from .common import KrakenPredictor

+class KrakenSegmentPredictor(KrakenPredictor):
     def setup(self):
-        """
-        Load models
-        """
-        kwargs = {}
-        kwargs['text_direction'] = self.parameter['text_direction']
-        self.use_legacy = self.parameter['use_legacy']
+        self.use_legacy = self.parameter.pop('use_legacy')
         if self.use_legacy:
-            from kraken.pageseg import segment
-            kwargs['scale'] = self.parameter['scale']
-            kwargs['maxcolseps'] = self.parameter['maxcolseps']
-            kwargs['black_colseps'] = self.parameter['black_colseps']
             self.logger.info("Using legacy segmenter")
+            # adapt to Kraken v5 changes:
+            self.parameter['no_hlines'] = self.parameter.pop('remove_hlines')
+            self.parameter.pop('device')
         else:
             from kraken.lib.vgsl import TorchVGSLModel
-            from kraken.blla import segment
             self.logger.info("Using blla segmenter")
-            blla_model_fname = self.resolve_resource(self.parameter['blla_model'])
-            kwargs['model'] = TorchVGSLModel.load_model(blla_model_fname)
+            self.logger.info("loading model '%s'", self.parameter['model'])
+            self.parameter['model'] = TorchVGSLModel.load_model(self.parameter['model'])
             device = self.parameter['device']
             if device != 'cpu' and not torch.cuda.is_available():
                 device = 'cpu'
             if device == 'cpu':
                 self.logger.warning("no CUDA device available. Running without GPU will be slow")
-            kwargs['device'] = device
-        def segmenter(img, mask=None):
-            return segment(img, mask=mask, **kwargs)
-        self.segmenter = segmenter
+            self.parameter['device'] = device
+            # adapt to Kraken v5 changes:
+            self.parameter.pop('scale')
+            self.parameter.pop('remove_hlines')
+            self.parameter.pop('maxcolseps')
+            self.parameter.pop('black_colseps')
+    def predict(self, *inputs):
+        if self.use_legacy:
+            from kraken.pageseg import segment
+        else:
+            from kraken.blla import segment
+        image, mask = inputs
+        return segment(image, mask=mask, **self.parameter)
+
+class KrakenSegment(Processor):
+
+    @property
+    def executable(self):
+        return 'ocrd-kraken-segment'
+
+    def setup(self):
+        """
+        Load models
+        """
+        parameter = dict(self.parameter)
+        model = parameter.pop('blla_model')
+        del parameter['blla_classes']
+        del parameter['overwrite_segments']
+        del parameter['level-of-operation']
+        self.use_legacy = parameter['use_legacy']
+        if not self.use_legacy:
+            parameter['model'] = self.resolve_resource(model)
+        self.predictor = KrakenSegmentPredictor(self.logger, parameter)
+        self.predictor.start()
+
+    def shutdown(self):
+        import multiprocessing as mp
+        if getattr(self, 'predictor', None):
+            self.predictor.shutdown()
+            del self.predictor

     def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional[str] = None) -> OcrdPageResult:
         """Segment into (regions and) lines with Kraken.
@@ -109,7 +135,7 @@ def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional
                 page.TextRegion = []
             elif len(page.TextRegion or []):
                 self.logger.warning('Keeping %d text regions on page "%s"', len(page.TextRegion or []), page.id)
-            self._process_page(page_image, page_coords, page, zoom)
+            self._process_page(page_image, page_coords, page, page_id, zoom)
         elif self.parameter['level-of-operation'] == 'table':
             regions = page.get_AllRegions(classes=['Table'])
             if not regions:
@@ -120,7 +146,7 @@ def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional
                     region.TextRegion = []
                 elif len(region.TextRegion or []):
                     self.logger.warning('Keeping %d text regions in region "%s"', len(region.TextRegion or []), region.id)
-                self._process_page(page_image, page_coords, region, zoom)
+                self._process_page(page_image, page_coords, region, page_id, zoom)
         else:
             regions = page.get_AllRegions(classes=['Text'])
             if not regions:
@@ -131,11 +157,11 @@ def process_page_pcgts(self, *input_pcgts: Optional[OcrdPage], page_id: Optional
                     region.TextLine = []
                 elif len(region.TextLine or []):
                     self.logger.warning('Keeping %d lines in region "%s"', len(region.TextLine or []), region.id)
-                self._process_region(page_image, page_coords, region, zoom)
+                self._process_region(page_image, page_coords, region, page_id, zoom)

         return OcrdPageResult(pcgts)

-    def _process_page(self, page_image, page_coords, page, zoom=1.0):
+    def _process_page(self, page_image, page_coords, page, page_id, zoom=1.0):
         def getmask():
             # use mask if existing regions (any type for page, text cells for table)
             # or segment is lower than page level
@@ -173,10 +199,10 @@ def getmask():
             # poly = geom.Polygon(poly).buffer(20/zoom).exterior.coords[:-1]
             mask.paste(255, mask=polygon_mask(page_image, poly))
             return mask
-        res = self.segmenter(page_image, mask=getmask())
+        res = self.predictor(page_id, page_image, getmask())
         self.logger.debug("Finished segmentation, serializing")
+        #self.logger.debug(res)
         if self.use_legacy:
-            self.logger.debug(res)
             idx_line = 0
             for idx_line, line in enumerate(res.lines):
                 line_poly = polygon_from_x0y0x1y1(line.bbox)
@@ -191,7 +217,6 @@ def getmask():
             page.add_TextRegion(region_elem)
             self.logger.debug("Found %d lines on page %s", idx_line + 1, page.id)
         else:
-            self.logger.debug(res)
             handled_lines = {}
             regions = [(type_, region)
                        for type_ in res.regions
@@ -245,7 +270,7 @@ def getmask():
             page.add_TextRegion(region_elem)
             self.logger.debug("Found %d lines and %d regions on page %s", idx_line + 1, idx_region + 1, page.id)

-    def _process_region(self, page_image, page_coords, region, zoom=1.0):
+    def _process_region(self, page_image, page_coords, region, page_id, zoom=1.0):
         def getmask():
             poly = coordinates_of_segment(region, page_image, page_coords)
             poly = geom.Polygon(poly).buffer(20/zoom).exterior.coords[:-1]
@@ -256,7 +281,7 @@ def getmask():
             # poly = geom.Polygon(poly).buffer(20/zoom).exterior.coords[:-1]
             mask.paste(255, mask=polygon_mask(page_image, poly))
             return mask
-        res = self.segmenter(page_image, mask=getmask())
+        res = self.predictor(page_id, page_image, getmask())
         self.logger.debug("Finished segmentation, serializing")
         idx_line = 0
         if self.use_legacy:
