Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions vesuvius/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ blending = [
"scipy>=1.13.1", # For scipy.ndimage.gaussian_filter in blending.py
"tqdm>=4.67.1", # Progress bars
"nest-asyncio>=1.6.0", # Required by vesuvius.utils.catalog (imported via __init__)
"scikit-image>=0.24.0", # For topo post-processing (label, euler_number, morphology)
"numba>=0.60.0", # For JIT-compiled rasterization in topo post-processing
]

tests = ["pytest>=8.4.2"]
Expand Down
105 changes: 99 additions & 6 deletions vesuvius/src/vesuvius/models/run/blending.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def generate_gaussian_map(patch_size: tuple, sigma_scale: float = 8.0, dtype=np.


def _init_worker(part_files, output_path, gaussian_map, patch_size, num_classes, is_s3,
finalize_config=None):
finalize_config=None, topo_config=None):
"""Initialize per-worker process state with cached zarr stores.

Called once when each worker process starts. Zarr stores (and their
Expand All @@ -392,6 +392,8 @@ def _init_worker(part_files, output_path, gaussian_map, patch_size, num_classes,
Args:
finalize_config: Optional FinalizeConfig. When set, chunks are finalized
(softmax + uint8) inline instead of writing float16 blended logits.
topo_config: Optional TopoPostprocessConfig. When set, topology-aware
post-processing replaces simple finalization.
"""
numcodecs.blosc.use_threads = False
storage_opts = {'anon': False} if is_s3 else None
Expand All @@ -404,6 +406,7 @@ def _init_worker(part_files, output_path, gaussian_map, patch_size, num_classes,
'logits_stores': {},
'output_store': open_zarr(output_path, mode='r+', storage_options=storage_opts),
'finalize_config': finalize_config,
'topo_config': topo_config,
})


Expand Down Expand Up @@ -523,8 +526,20 @@ def process_chunk(chunk_info, chunk_patches, epsilon=1e-8):
np.divide(chunk_logits, chunk_weights[np.newaxis, :, :, :] + epsilon,
out=normalized, where=chunk_weights[np.newaxis, :, :, :] > 0)

topo_config = _worker_state.get('topo_config')
finalize_config = _worker_state.get('finalize_config')
if finalize_config is not None:
if topo_config is not None:
from vesuvius.models.run.topo_postprocess import apply_topo_finalization
result, is_empty = apply_topo_finalization(normalized, num_classes, topo_config)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Avoid chunk-local topo filtering for boundary-spanning objects

This invokes apply_topo_finalization independently for each chunk, but the topo pipeline then applies component-level filters (remove_small_objects, Euler checks, sheet fitting) that assume whole connected objects. When a real object crosses chunk boundaries, each chunk sees only a fragment, so valid structures can be dropped at boundaries and produce visible seams/gaps in the merged mask. Topology filtering needs halo/overlap handling or a global pass before size- and topology-based rejection.

Useful? React with 👍 / 👎.

if not is_empty:
finalized_slice = (
slice(None),
slice(z_start, z_end),
slice(y_start, y_end),
slice(x_start, x_end)
)
output_store[finalized_slice] = result
elif finalize_config is not None:
from vesuvius.models.run.finalize_outputs import apply_finalization
result, is_empty = apply_finalization(normalized, num_classes, finalize_config)
if not is_empty:
Expand Down Expand Up @@ -592,7 +607,8 @@ def merge_inference_outputs(
verbose: bool = True,
num_parts: int = 1, # Number of parts to split processing into
global_part_id: int = 0, # Part ID for this process (0-indexed)
finalize_config=None): # Optional FinalizeConfig — fuse finalization when set
finalize_config=None, # Optional FinalizeConfig — fuse finalization when set
topo_config=None): # Optional TopoPostprocessConfig — topo post-processing when set
"""
Args:
parent_dir: Directory containing logits_part_X.zarr and coordinates_part_X.zarr.
Expand All @@ -607,6 +623,8 @@ def merge_inference_outputs(
num_parts: Number of parts to split the blending process into.
global_part_id: Part ID for this process (0-indexed). Used for Z-axis partitioning.
finalize_config: Optional FinalizeConfig. When provided, softmax + uint8 quantization
is applied inline after blending (fused mode), skipping the intermediate float16 array.
topo_config: Optional TopoPostprocessConfig. When provided, topology-aware
post-processing is applied per-chunk instead of simple finalization.
"""

Expand Down Expand Up @@ -703,9 +721,15 @@ def merge_inference_outputs(
print(f" Original Volume Shape (Z,Y,X): {original_volume_shape}")

# --- 3. Prepare Output Stores ---
# Topo post-processing: binary uint8 mask with shape (1, Z, Y, X)
if topo_config is not None:
output_shape = (1, *original_volume_shape)
output_dtype = np.uint8
print(f"Topo post-processing mode enabled")
print(f" Output shape: {output_shape}, dtype: uint8")
# When fused mode is active, populate finalize_config with multi-task metadata
# from the logits zarr attrs and use the finalized shape/dtype.
if finalize_config is not None:
elif finalize_config is not None:
if hasattr(part0_logits_store, 'attrs'):
finalize_config.is_multi_task = part0_logits_store.attrs.get('is_multi_task', False)
finalize_config.target_info = part0_logits_store.attrs.get('target_info', None)
Expand Down Expand Up @@ -860,7 +884,7 @@ def print_progress_stats():
with ProcessPoolExecutor(
max_workers=num_workers,
initializer=_init_worker,
initargs=(part_files, output_path, gaussian_map, patch_size, num_classes, is_s3, finalize_config)
initargs=(part_files, output_path, gaussian_map, patch_size, num_classes, is_s3, finalize_config, topo_config)
) as executor:
future_to_chunk = {
executor.submit(
Expand Down Expand Up @@ -930,6 +954,13 @@ def print_progress_stats():
output_zarr.attrs['original_volume_shape'] = original_volume_shape
output_zarr.attrs['sigma_scale'] = sigma_scale

# Add topo post-processing metadata
if topo_config is not None:
output_zarr.attrs['processing_mode'] = 'topo_postprocess'
output_zarr.attrs['topo_t_low'] = topo_config.topo_t_low
output_zarr.attrs['topo_t_high'] = topo_config.topo_t_high
output_zarr.attrs['fused_blend_topo'] = True

# Add finalization metadata when in fused mode
if finalize_config is not None:
output_zarr.attrs['processing_mode'] = finalize_config.mode
Expand Down Expand Up @@ -1032,6 +1063,43 @@ def blend_and_finalize_main():
parser.add_argument('--threshold', action='store_true',
help='Apply argmax and only save class predictions (no probabilities).')

# Topo post-processing args
parser.add_argument('--topo_postprocess', action='store_true',
help='Enable topology-aware post-processing (replaces simple finalization).')
parser.add_argument('--topo_t_low', type=float, default=0.2,
help='Low threshold for hysteresis. Default: 0.2')
parser.add_argument('--topo_t_high', type=float, default=0.83,
help='High threshold for hysteresis. Default: 0.83')
parser.add_argument('--topo_z_radius', type=int, default=1,
help='Z radius for anisotropic closing. Default: 1')
parser.add_argument('--topo_xy_radius', type=int, default=0,
help='XY radius for anisotropic closing. Default: 0')
parser.add_argument('--topo_dust_min_size', type=int, default=100,
help='Min size for dust removal after hysteresis. Default: 100')
parser.add_argument('--topo_min_object_size', type=int, default=1000,
help='Min object size after initial thresholding. Default: 1000')
parser.add_argument('--topo_final_min_object_size', type=int, default=2000,
help='Min object size for final cleanup. Default: 2000')
parser.add_argument('--topo_grid_resolution', type=int, default=100,
help='Base grid resolution for sheet fitting. Default: 100')
parser.add_argument('--topo_thickness', type=int, default=3,
help='Sheet thickness for dilation. Default: 3')
parser.add_argument('--topo_smoothing', type=float, default=1.0,
help='Gaussian smoothing sigma for fitted sheets. Default: 1.0')
parser.add_argument('--topo_overlap_buffer', type=int, default=0,
help='Overlap buffer for erosion. Default: 0')
parser.add_argument('--topo_min_coverage', type=float, default=0.65,
help='Min coverage score for accepting a fitted sheet. Default: 0.65')
parser.add_argument('--topo_min_dice', type=float, default=0.7,
help='Min Dice score for accepting a fitted sheet. Default: 0.7')
parser.add_argument('--topo_max_distance', type=int, default=10,
help='Max distance for sheet fitting. Default: 10')
parser.add_argument('--topo_samples_per_edge', type=int, default=8,
help='Samples per edge for surface rasterization. Default: 8')
parser.add_argument('--topo_alt_t_lows', type=str, default='0.5,0.7',
help='Comma-separated T_low values for alternative volumes. Default: 0.5,0.7')
parser.add_argument('--topo_border_crop', type=int, default=3,
help='Border crop size in voxels. Default: 3')
args = parser.parse_args()

if args.part_id < 0 or args.part_id >= args.num_parts:
Expand All @@ -1046,8 +1114,32 @@ def blend_and_finalize_main():
except ValueError:
parser.error("Invalid chunk_size format. Expected 3 comma-separated integers (Z,Y,X).")

topo_config = None
if args.topo_postprocess:
from vesuvius.models.run.topo_postprocess import TopoPostprocessConfig
alt_t_lows = tuple(float(x) for x in args.topo_alt_t_lows.split(','))
topo_config = TopoPostprocessConfig(
topo_t_low=args.topo_t_low,
topo_t_high=args.topo_t_high,
topo_z_radius=args.topo_z_radius,
topo_xy_radius=args.topo_xy_radius,
topo_dust_min_size=args.topo_dust_min_size,
topo_min_object_size=args.topo_min_object_size,
topo_final_min_object_size=args.topo_final_min_object_size,
topo_grid_resolution=args.topo_grid_resolution,
topo_thickness=args.topo_thickness,
topo_smoothing=args.topo_smoothing,
topo_overlap_buffer=args.topo_overlap_buffer,
topo_min_coverage=args.topo_min_coverage,
topo_min_dice=args.topo_min_dice,
topo_max_distance=args.topo_max_distance,
topo_samples_per_edge=args.topo_samples_per_edge,
topo_alt_t_lows=alt_t_lows,
topo_border_crop=args.topo_border_crop,
)

from vesuvius.models.run.finalize_outputs import FinalizeConfig
finalize_config = FinalizeConfig(mode=args.mode, threshold=args.threshold)
finalize_config = FinalizeConfig(mode=args.mode, threshold=args.threshold) if not args.topo_postprocess else None

try:
merge_inference_outputs(
Expand All @@ -1061,6 +1153,7 @@ def blend_and_finalize_main():
num_parts=args.num_parts,
global_part_id=args.part_id,
finalize_config=finalize_config,
topo_config=topo_config,
)
return 0
except Exception as e:
Expand Down
Loading