From 5658241b72bc5addba7ebcc379e5e8f91c2d67d9 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Mon, 16 Jun 2025 16:00:02 -0500 Subject: [PATCH 01/45] Implement a logging function and use it during inference in order to preserve stdout, move calculate_slice_boxes to utils so it can be used during inference, fix a warning. --- models/model.py | 30 ++++++++--------- predict_mass_roads_test_set.py | 46 +------------------------- utils.py | 59 +++++++++++++++++++++++++++++++++- 3 files changed, 74 insertions(+), 61 deletions(-) diff --git a/models/model.py b/models/model.py index db396cb..5233679 100644 --- a/models/model.py +++ b/models/model.py @@ -11,6 +11,7 @@ from config import CFG from utils import ( create_mask, + log ) @@ -128,7 +129,7 @@ def __init__(self, cfg, vocab_size, encoder_len, dim, num_heads, num_layers): def init_weights(self): for name, p in self.named_parameters(): if 'encoder_pos_embed' in name or 'decoder_pos_embed' in name: - print(f"Skipping initialization of pos embed layers...") + log("Skipping initialization of pos embed layers...") continue if p.dim() > 1: nn.init.xavier_uniform_(p) @@ -269,16 +270,15 @@ def predict(self, image, tgt): with torch.no_grad(): for i in range(1 + n_vertices*2): try: - print(i) + log(f"Iteration {i}") preds_p, feats_p = model.predict(image, batch_preds) - # print(preds_p.shape, feats_p.shape) if i % 2 == 0: confs_ = torch.softmax(preds_p, dim=-1).sort(axis=-1, descending=True)[0][:, 0].cpu() confs.append(confs_) preds_p = sample(preds_p) batch_preds = torch.cat([batch_preds, preds_p], dim=1) except: - print(f"Error at iteration: {i}") + log(f"Error at iteration: {i}") perm_pred = model.scorenet(feats_p) # Postprocessing. @@ -302,18 +302,18 @@ def predict(self, image, tgt): out_coords.extend(all_coords) out_confs.extend(out_confs) - print(f"preds_f shape: {preds_f.shape}") - print(f"preds_f grad: {preds_f.requires_grad}") - print(f"preds_f min: {preds_f.min()}, max: {preds_f.max()}") + log(f"preds_f shape: {preds_f.shape}") + log(f"preds_f grad: {preds_f.requires_grad}") + log(f"preds_f min: {preds_f.min()}, max: {preds_f.max()}") - print(f"perm_mat shape: {perm_mat.shape}") - print(f"perm_mat grad: {perm_mat.requires_grad}") - print(f"perm_mat min: {perm_mat.min()}, max: {preds_f.max()}") + log(f"perm_mat shape: {perm_mat.shape}") + log(f"perm_mat grad: {perm_mat.requires_grad}") + log(f"perm_mat min: {perm_mat.min()}, max: {preds_f.max()}") - print(f"batch_preds shape: {batch_preds.shape}") - print(f"batch_preds grad: {batch_preds.requires_grad}") - print(f"batch_preds min: {batch_preds.min()}, max: {batch_preds.max()}") + log(f"batch_preds shape: {batch_preds.shape}") + log(f"batch_preds grad: {batch_preds.requires_grad}") + log(f"batch_preds min: {batch_preds.min()}, max: {batch_preds.max()}") - print(f"loss : {loss}") - print(f"loss grad: {loss.requires_grad}") + log(f"loss : {loss}") + log(f"loss grad: {loss.requires_grad}") diff --git a/predict_mass_roads_test_set.py b/predict_mass_roads_test_set.py index 10b21c5..37ee5cc 100644 --- a/predict_mass_roads_test_set.py +++ b/predict_mass_roads_test_set.py @@ -23,55 +23,11 @@ test_generate, postprocess, permutations_to_polygons, + calculate_slice_bboxes, ) import time -# adapted from https://github.com/obss/sahi/blob/e798c80d6e09079ae07a672c89732dd602fe9001/sahi/slicing.py#L30, MIT License -def calculate_slice_bboxes( - image_height: int, - image_width: int, - slice_height: int = 512, - slice_width: int = 512, - overlap_height_ratio: float = 0.2, - overlap_width_ratio: float = 0.2, -) 
-> list[list[int]]: - """ - Given the height and width of an image, calculates how to divide the image into - overlapping slices according to the height and width provided. These slices are returned - as bounding boxes in xyxy format. - :param image_height: Height of the original image. - :param image_width: Width of the original image. - :param slice_height: Height of each slice - :param slice_width: Width of each slice - :param overlap_height_ratio: Fractional overlap in height of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels) - :param overlap_width_ratio: Fractional overlap in width of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels) - :return: a list of bounding boxes in xyxy format - """ - - slice_bboxes = [] - y_max = y_min = 0 - y_overlap = int(overlap_height_ratio * slice_height) - x_overlap = int(overlap_width_ratio * slice_width) - while y_max < image_height: - x_min = x_max = 0 - y_max = y_min + slice_height - while x_max < image_width: - x_max = x_min + slice_width - if y_max > image_height or x_max > image_width: - xmax = min(image_width, x_max) - ymax = min(image_height, y_max) - xmin = max(0, xmax - slice_width) - ymin = max(0, ymax - slice_height) - slice_bboxes.append([xmin, ymin, xmax, ymax]) - else: - slice_bboxes.append([x_min, y_min, x_max, y_max]) - x_min = x_max - x_overlap - y_min = y_max - y_overlap - - return slice_bboxes - - def get_rectangle_params_from_pascal_bbox(bbox): xmin_top_left, ymin_top_left, xmax_bottom_right, ymax_bottom_right = bbox diff --git a/utils.py b/utils.py index 33eab91..37c70d0 100644 --- a/utils.py +++ b/utils.py @@ -10,6 +10,9 @@ from torchmetrics.functional.classification import binary_jaccard_index, binary_accuracy from config import CFG +import sys +from datetime import datetime + def seed_everything(seed=1234): random.seed(seed) @@ -53,7 +56,7 @@ def create_mask(tgt, pad_idx): tgt_seq_len = tgt.size(1) tgt_mask = generate_square_subsequent_mask(tgt_seq_len) - tgt_padding_mask = (tgt == pad_idx) + tgt_padding_mask = (tgt == pad_idx).float().masked_fill(tgt == pad_idx, float('-inf')) return tgt_mask, tgt_padding_mask @@ -356,3 +359,57 @@ def save_single_predictions_as_images( torchvision.utils.save_image(y_perm[:, None, :, :]*255, f"{folder}/gt_perm_matrix_{idx}.png") return metrics_dict + + +def log(message: str, level: str = "INFO") -> None: + """Simple logging function that outputs to stderr.""" + timestamp = datetime.now().isoformat() + print(f"[{timestamp}] [{level}] {message}", file=sys.stderr, flush=True) + + +def calculate_slice_bboxes( + image_height: int, + image_width: int, + slice_height: int = 512, + slice_width: int = 512, + overlap_height_ratio: float = 0.2, + overlap_width_ratio: float = 0.2, +) -> list[list[int]]: + """ + Given the height and width of an image, calculates how to divide the image into + overlapping slices according to the height and width provided. These slices are returned + as bounding boxes in xyxy format. + + Args: + image_height: Height of the original image. + image_width: Width of the original image. + slice_height: Height of each slice + slice_width: Width of each slice + overlap_height_ratio: Fractional overlap in height of each slice (e.g. an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels) + overlap_width_ratio: Fractional overlap in width of each slice (e.g. 
an overlap of 0.2 for a slice of size 100 yields an overlap of 20 pixels) + + Returns: + A list of bounding boxes in xyxy format + """ + slice_bboxes = [] + y_max = y_min = 0 + y_overlap = int(overlap_height_ratio * slice_height) + x_overlap = int(overlap_width_ratio * slice_width) + + while y_max < image_height: + x_min = x_max = 0 + y_max = y_min + slice_height + while x_max < image_width: + x_max = x_min + slice_width + if y_max > image_height or x_max > image_width: + xmax = min(image_width, x_max) + ymax = min(image_height, y_max) + xmin = max(0, xmax - slice_width) + ymin = max(0, ymax - slice_height) + slice_bboxes.append([xmin, ymin, xmax, ymax]) + else: + slice_bboxes.append([x_min, y_min, x_max, y_max]) + x_min = x_max - x_overlap + y_min = y_max - y_overlap + + return slice_bboxes From 6a0091057c6a52e5709a8607c5120caf8ba28dcf Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Mon, 16 Jun 2025 16:01:23 -0500 Subject: [PATCH 02/45] Implement a command line utility to run inference on a single image. --- config.py | 20 ++- infer_single_image.py | 64 ++++++++ polygon_inference.py | 330 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 413 insertions(+), 1 deletion(-) create mode 100644 infer_single_image.py create mode 100644 polygon_inference.py diff --git a/config.py b/config.py index f63a61b..1519f40 100644 --- a/config.py +++ b/config.py @@ -56,6 +56,24 @@ class CFG: perm_loss_weight = 10.0 SHUFFLE_TOKENS = False # order gt vertex tokens randomly every time + # Tiling configuration + TILE_SIZE = IMG_SIZE + TILE_OVERLAP = 34 + + # Polygon filtering configuration + # MIN_POLYGON_AREA: Minimum area in square pixels for a polygon to be considered valid + # Conversion: 100 sq ft ≈ 9.29 sq meters + # Since 1 pixel = 0.3m, 1 sq pixel = 0.09 sq meters + # Therefore, 9.29 sq meters ≈ 103 sq pixels + MIN_POLYGON_AREA = 103 + + # POLYGON_SIMPLIFICATION_TOLERANCE: Maximum distance in pixels that a point can be moved during polygon simplification + # A pixel is 0.3m, so 2 pixels is 0.6m, or 2 feet, which should be of no consequence + POLYGON_SIMPLIFICATION_TOLERANCE = 2 + + # Prediction configuration + PREDICTION_BATCH_SIZE = 8 # Batch size for processing tiles during prediction + BATCH_SIZE = 24 # batch size per gpu; effective batch size = BATCH_SIZE * NUM_GPUs START_EPOCH = 0 NUM_EPOCHS = 500 @@ -65,7 +83,7 @@ class CFG: SAVE_EVERY = 10 VAL_EVERY = 1 - MODEL_NAME = f'vit_small_patch{PATCH_SIZE}_{INPUT_SIZE}_dino' + MODEL_NAME = f'vit_small_patch{PATCH_SIZE}_{INPUT_SIZE}.dino' NUM_PATCHES = int((INPUT_SIZE // PATCH_SIZE) ** 2) LR = 4e-4 diff --git a/infer_single_image.py b/infer_single_image.py new file mode 100644 index 0000000..d7efb48 --- /dev/null +++ b/infer_single_image.py @@ -0,0 +1,64 @@ +import os +import json +import numpy as np +import cv2 +import matplotlib.pyplot as plt +import argparse +import sys + +from polygon_inference import PolygonInference +from utils import log + +parser = argparse.ArgumentParser() +parser.add_argument("-e", "--experiment_path", help="path to experiment folder to evaluate") +args = parser.parse_args() + +def main(): + # Load image from stdin + image_data = sys.stdin.buffer.read() + + # Initialize inference + inference = PolygonInference(args.experiment_path) + + # Get inference results + polygons_list = inference.infer(image_data) + + # Decode image for visualization + nparr = np.frombuffer(image_data, np.uint8) + image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + if image is None: + log("Failed to load image from stdin", "ERROR") + return + + # 
Get image dimensions + height, width = image.shape[:2] + + # Create figure with exact image dimensions + plt.figure(figsize=(width/100, height/100), dpi=100) + + # Plot merged result + vis_image_merged = image.copy() + formatted_contours = [np.array(cnt).reshape(-1, 1, 2).astype(np.int32) for cnt in polygons_list] + cv2.drawContours(vis_image_merged, formatted_contours, -1, (0, 255, 0), 1) + + # Draw dots at vertices for merged result + for contour in formatted_contours: + for point in contour: + x, y = point[0] + # Draw 2x2 red square + vis_image_merged[y-1:y+1, x-1:x+1] = [255, 0, 0] + + plt.imshow(vis_image_merged) + plt.axis('off') + plt.subplots_adjust(left=0, right=1, top=1, bottom=0) + + # Save visualization + output_path = os.path.join(f"visualization.png") + plt.savefig(output_path, dpi=100, bbox_inches='tight', pad_inches=0) + plt.close() + + # Print polygons to stdout + print(json.dumps(polygons_list)) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/polygon_inference.py b/polygon_inference.py new file mode 100644 index 0000000..433b99a --- /dev/null +++ b/polygon_inference.py @@ -0,0 +1,330 @@ +# Standard library imports +import os +import tempfile +import hashlib +from pathlib import Path +from typing import List, Tuple, Dict, Optional, Any + +# Third-party imports +import numpy as np +import cv2 +import torch +import albumentations as A +from albumentations.pytorch import ToTensorV2 +import shapely.geometry +import shapely.ops + +# Local imports +from config import CFG +from tokenizer import Tokenizer +from utils import ( + seed_everything, + test_generate, + postprocess, + permutations_to_polygons, + log, + calculate_slice_bboxes, +) +from models.model import ( + Encoder, + Decoder, + EncoderDecoder +) + +class PolygonInference: + def __init__(self, experiment_path: str, device: Optional[str] = None) -> None: + """Initialize the polygon inference with a trained model. + + Args: + experiment_path (str): Path to the experiment folder containing the model checkpoint + device (str | None, optional): Device to run the model on. Defaults to CFG.DEVICE + """ + self.device: str = device or CFG.DEVICE + self.experiment_path: str = os.path.realpath(experiment_path) + self.model: Optional[EncoderDecoder] = None + self.tokenizer: Optional[Tokenizer] = None + self._initialize_model() + + # Create persistent temporary directory for caching + self.cache_dir: Path = Path(tempfile.gettempdir()) / "pix2poly_cache" + self.cache_dir.mkdir(parents=True, exist_ok=True) + + def _initialize_model(self) -> None: + """Initialize the model and tokenizer. + + This method: + 1. Creates a new tokenizer instance + 2. Initializes the encoder-decoder model + 3. 
Loads the latest checkpoint from the experiment directory + """ + self.tokenizer = Tokenizer( + num_classes=1, + num_bins=CFG.NUM_BINS, + width=CFG.INPUT_WIDTH, + height=CFG.INPUT_HEIGHT, + max_len=CFG.MAX_LEN + ) + CFG.PAD_IDX = self.tokenizer.PAD_code + + encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256) + decoder = Decoder( + cfg=CFG, + vocab_size=self.tokenizer.vocab_size, + encoder_len=CFG.NUM_PATCHES, + dim=256, + num_heads=8, + num_layers=6 + ) + self.model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder) + self.model.to(self.device) + self.model.eval() + + # Load latest checkpoint + latest_checkpoint = self._find_latest_checkpoint() + checkpoint_path = os.path.join(self.experiment_path, "logs", "checkpoints", latest_checkpoint) + log(f"Loading checkpoint from {checkpoint_path}") + checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) + self.model.load_state_dict(checkpoint['state_dict']) + log("Checkpoint loaded successfully") + + def _find_latest_checkpoint(self) -> str: + """Find the checkpoint with the highest epoch number. + + Returns: + str: Filename of the latest checkpoint + """ + checkpoint_dir = os.path.join(self.experiment_path, "logs", "checkpoints") + checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('epoch_') and f.endswith('.pth')] + latest_checkpoint = sorted(checkpoint_files)[-1] + return latest_checkpoint + + def _get_tile_hash(self, tile: np.ndarray) -> str: + """Generate a hash for the input tile. + + Args: + tile (np.ndarray): Input image tile + + Returns: + str: MD5 hash of the tile + """ + return hashlib.md5(tile.tobytes()).hexdigest() + + def _load_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]: + """Load cached result if it exists. + + Args: + cache_key (str): Key to look up in the cache + + Returns: + dict | None: Cached result if found, None otherwise + """ + cache_path = self.cache_dir / f"{cache_key}.npy" + if cache_path.exists(): + return np.load(cache_path, allow_pickle=True).item() + return None + + def _save_to_cache(self, cache_key: str, result: Dict[str, Any]) -> None: + """Save result to cache. + + Args: + cache_key (str): Key to store in the cache + result (dict): Result to cache + """ + cache_path = self.cache_dir / f"{cache_key}.npy" + np.save(cache_path, result, allow_pickle=True) + + def _process_tiles_batch(self, tiles: List[np.ndarray]) -> List[Dict[str, List[np.ndarray]]]: + """Process a single batch of tiles. 
+ + Args: + tiles (list[np.ndarray]): List of tile images to process + + Returns: + list[dict]: List of results for each tile, where each result contains: + - polygons: List of polygon coordinates + """ + valid_transforms = A.Compose([ + A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), + A.Normalize( + mean=[0.0, 0.0, 0.0], + std=[1.0, 1.0, 1.0], + max_pixel_value=255.0 + ), + ToTensorV2(), + ]) + + # Transform each tile individually and stack them + transformed_tiles: List[torch.Tensor] = [] + for tile in tiles: + transformed = valid_transforms(image=tile) + transformed_tiles.append(transformed['image']) + + # Stack the transformed tiles into a batch + batch_tensor = torch.stack(transformed_tiles).to(self.device) + + with torch.no_grad(): + batch_preds, batch_confs, perm_preds = test_generate( + self.model, batch_tensor, self.tokenizer, + max_len=CFG.generation_steps, + top_k=0, + top_p=1 + ) + + vertex_coords, confs = postprocess(batch_preds, batch_confs, self.tokenizer) + + results: List[Dict[str, List[np.ndarray]]] = [] + for j in range(len(tiles)): + if vertex_coords[j] is not None: + coord = torch.from_numpy(vertex_coords[j]) + else: + coord = torch.tensor([]) + + padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX) + coord = torch.cat([coord, padd], dim=0) + + batch_polygons = permutations_to_polygons(perm_preds[j:j+1], [coord], out='torch') + + valid_polygons: List[np.ndarray] = [] + for poly in batch_polygons[0]: + poly = poly[poly[:, 0] != CFG.PAD_IDX] + if len(poly) > 0: + valid_polygons.append(poly.cpu().numpy()[:, ::-1]) # Convert to [x,y] format + + result = { + 'polygons': valid_polygons + } + + # Cache the result + cache_key = self._get_tile_hash(tiles[j]) + self._save_to_cache(cache_key, result) + + results.append(result) + + return results + + def _merge_polygons(self, tile_results: List[Dict[str, List[np.ndarray]]], positions: List[Tuple[int, int, int, int]]) -> List[np.ndarray]: + """Merge polygon predictions from multiple tiles into a single set of polygons. + + Args: + tile_results (list[dict]): List of dictionaries containing 'polygons' for each tile + positions (list[tuple[int, int, int, int]]): List of (x, y, x_end, y_end) tuples for each tile's position + + Returns: + list[np.ndarray]: List of merged polygons in original image coordinates + """ + all_polygons: List[shapely.geometry.Polygon] = [] + + # Transform each polygon from tile coordinates to original image coordinates + for tile_result, (x, y, x_end, y_end) in zip(tile_results, positions): + tile_polygons = tile_result['polygons'] + height, width = CFG.INPUT_HEIGHT, CFG.INPUT_WIDTH + + # Transform each polygon from tile coordinates to original image coordinates + for poly in tile_polygons: + # Scale coordinates to match tile size + scaled_poly = poly * np.array([width/CFG.INPUT_WIDTH, height/CFG.INPUT_HEIGHT]) + + # Translate to original image position + translated_poly = scaled_poly + np.array([x, y]) + + # Convert to shapely polygon + shapely_poly = shapely.geometry.Polygon(translated_poly) + if shapely_poly.is_valid and shapely_poly.area > CFG.MIN_POLYGON_AREA: + all_polygons.append(shapely_poly) + + # Use shapely's unary_union to merge overlapping polygons + merged = shapely.ops.unary_union(all_polygons) + + simplified_polygons: List[np.ndarray] = [] + + # Merging tiles may have created redundant points in the middle of a line. We can remove them. 
+ for poly in merged.geoms: + simplified = poly.simplify(CFG.POLYGON_SIMPLIFICATION_TOLERANCE) + if simplified.is_valid and not simplified.is_empty: + simplified_polygons.append(np.array(simplified.exterior.coords)) + + return simplified_polygons + + def infer(self, image_data: bytes) -> List[List[List[float]]]: + """Infer polygons in an image. + + Args: + image_data (bytes): Raw image data + + Returns: + list[list[list[float]]]: List of polygons where each polygon is a list of [x,y] coordinates. + Each coordinate is rounded to 2 decimal places. + + Raises: + ValueError: If the image data is invalid, empty, or cannot be decoded + RuntimeError: If there are issues with model prediction or polygon processing + """ + if not image_data: + raise ValueError("Empty image data provided") + + seed_everything(42) + + # Decode image + nparr = np.frombuffer(image_data, np.uint8) + image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + if image is None: + raise ValueError("Failed to decode image data") + + if image.size == 0: + raise ValueError("Decoded image is empty") + + # Convert to RGB + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + # Split image into tiles + height, width = image.shape[:2] + if height == 0 or width == 0: + raise ValueError("Invalid image dimensions") + + overlap_ratio = CFG.TILE_OVERLAP / CFG.TILE_SIZE + + bboxes = calculate_slice_bboxes( + image_height=height, + image_width=width, + slice_height=CFG.TILE_SIZE, + slice_width=CFG.TILE_SIZE, + overlap_height_ratio=overlap_ratio, + overlap_width_ratio=overlap_ratio + ) + + tiles: List[np.ndarray] = [] + tiles_to_process: List[np.ndarray] = [] + + for bbox in bboxes: + x1, y1, x2, y2 = bbox + tile = image[y1:y2, x1:x2] + if tile.size == 0: + continue + tiles.append(tile) + + # Process tiles in batches + all_results: List[Dict[str, List[np.ndarray]]] = [] + + # First check cache for all tiles + for tile in tiles: + cache_key = self._get_tile_hash(tile) + cached_result = self._load_from_cache(cache_key) + if cached_result is not None: + all_results.append(cached_result) + else: + tiles_to_process.append(tile) + + # Process remaining tiles in batches + for i in range(0, len(tiles_to_process), CFG.PREDICTION_BATCH_SIZE): + batch_tiles = tiles_to_process[i:i + CFG.PREDICTION_BATCH_SIZE] + batch_results = self._process_tiles_batch(batch_tiles) + all_results.extend(batch_results) + + merged_polygons = self._merge_polygons(all_results, bboxes) + + # Convert to list format + polygons_list = [poly.tolist() for poly in merged_polygons] + # Round coordinates to two decimal places + polygons_list = [[[round(x, 2), round(y, 2)] for x, y in polygon] for polygon in polygons_list] + + return polygons_list \ No newline at end of file From b1dc370a95ad44a12c2e0d50d12b61f5d99cd2bd Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Mon, 16 Jun 2025 16:01:50 -0500 Subject: [PATCH 03/45] Implement an API server to run inference on a single image. 
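
As a point of reference, a minimal client sketch for exercising the new endpoint might look like the following. This is illustrative only: it assumes the server from api.py is running locally on port 8080 and that the `requests` package is available; the file name `building.jpg` and host are placeholders, while the `/invocations` route, the multipart field name `file`, and the `polygons` response key come from the code added in this patch.

    # Hypothetical client example; endpoint and response shape follow api.py,
    # but the image path and host are placeholders.
    import json
    import requests

    with open("building.jpg", "rb") as f:
        resp = requests.post(
            "http://localhost:8080/invocations",
            files={"file": ("building.jpg", f, "image/jpeg")},
        )
    resp.raise_for_status()
    polygons = resp.json()["polygons"]  # one list of [x, y] vertices per polygon
    print(json.dumps(polygons, indent=2))
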
--- api.py | 110 +++++++++++++++++++++++++++++++++++++++++++++++++++ start_api.sh | 15 +++++++ 2 files changed, 125 insertions(+) create mode 100644 api.py create mode 100755 start_api.sh diff --git a/api.py b/api.py new file mode 100644 index 0000000..7cd33d0 --- /dev/null +++ b/api.py @@ -0,0 +1,110 @@ +import os +import json +from fastapi import FastAPI, UploadFile, HTTPException, Request +from fastapi.responses import JSONResponse +from fastapi.middleware.cors import CORSMiddleware +import uvicorn +from contextlib import asynccontextmanager +import base64 + +from polygon_inference import PolygonInference +from utils import log + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Initialize the predictor on startup.""" + experiment_path = os.getenv("EXPERIMENT_PATH") + if not experiment_path: + raise ValueError("EXPERIMENT_PATH environment variable must be set") + init_predictor(experiment_path) + yield + +app = FastAPI( + title="Polygon Inference API", + description="API for inferring polygons in images using a trained model", + version="1.0.0", + lifespan=lifespan +) + +# Add CORS middleware +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], # Allows all origins + allow_credentials=True, + allow_methods=["*"], # Allows all methods + allow_headers=["*"], # Allows all headers +) + +# Global predictor instance +predictor = None + +def init_predictor(experiment_path: str): + """Initialize the global predictor instance.""" + global predictor + if predictor is None: + predictor = PolygonInference(experiment_path) + +@app.post("/invocations") +async def invoke(request: Request, file: UploadFile = None): + """Main inference endpoint for processing images. + + The endpoint accepts image data in three different ways: + 1. As a file upload using multipart/form-data (via the file parameter) + 2. As a base64-encoded image in a JSON payload with an 'image' field + 3. 
As raw image data in the request body + + Args: + request: The request containing the image data + file: Optional uploaded file (multipart/form-data) + + Returns: + JSON response containing the inferred polygons + + Raises: + HTTPException: 400 if no image data is found in the request + HTTPException: 500 if there is an error processing the image + """ + try: + if file: + # Handle file upload + image_data = await file.read() + else: + # Read request body + body = await request.body() + + # Parse the request body + try: + data = json.loads(body) + if 'image' in data: + # Handle base64 encoded image + image_data = base64.b64decode(data['image']) + else: + raise HTTPException(status_code=400, detail="No image data found in request") + except json.JSONDecodeError: + # Handle raw image data + image_data = body + + # Get inferences + polygons = predictor.infer(image_data) + + # Prepare response + response = { + "polygons": polygons, + } + + return JSONResponse(content=response) + + except Exception as e: + log(f"Error processing image: {str(e)}", "ERROR") + raise HTTPException(status_code=500, detail=str(e)) + +@app.get("/ping") +async def ping(): + """Health check endpoint to verify service status.""" + if predictor is None: + raise HTTPException(status_code=503, detail="Model not loaded") + return {"status": "healthy"} + +if __name__ == "__main__": + # Run the API + uvicorn.run(app, host="0.0.0.0", port=8080) \ No newline at end of file diff --git a/start_api.sh b/start_api.sh new file mode 100755 index 0000000..8367df8 --- /dev/null +++ b/start_api.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +# Exit on error +set -e + +# Set experiment path +export EXPERIMENT_PATH=runs/Pix2Poly_inria_coco_224 + +# Activate conda environment +source $(conda info --base)/etc/profile.d/conda.sh +conda activate pix2poly + +# Start the API server +echo "Starting API server with experiment path: $EXPERIMENT_PATH" +uvicorn api:app --reload --port 8080 From 2051bcca9d4eb8ec3f444a2c37cb4af67d337c7a Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Mon, 16 Jun 2025 16:02:09 -0500 Subject: [PATCH 04/45] Implement environment.yml to let Conda automatically install all needed dependencies. --- environment.yml | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) create mode 100644 environment.yml diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000..3fadcb3 --- /dev/null +++ b/environment.yml @@ -0,0 +1,32 @@ +name: pix2poly +channels: + - conda-forge + - pytorch + - nvidia +dependencies: + - python=3.11 + - timm=0.9.12 + - transformers=4.32.1 + - pycocotools=2.0.6 + - torchmetrics=1.2.1 + - tensorboard=2.15.1 + - pip + - pip: + - torch==2.1.2 + - torchvision==0.16.2 + - torchaudio==2.1.2 + - albumentations==1.3.1 + - imageio==2.33.1 + - matplotlib-inline==0.1.6 + - opencv-python-headless==4.8.1.78 + - scikit-image==0.22.0 + - scikit-learn==1.3.2 + - scipy==1.11.4 + - shapely==2.0.4 + - fastapi>=0.68.0 + - uvicorn>=0.15.0 + - python-multipart>=0.0.5 + - tqdm>=4.62.0 + - pulumi>=3.0.0 + - pulumi-aws>=5.0.0 + \ No newline at end of file From b8f0299b9152a704b7bd098b6ebd7fc4f3d650d4 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 17 Jun 2025 16:39:16 -0500 Subject: [PATCH 05/45] Add a simple API key system and a working Dockerfile. 
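
When an API key is configured, a request must present it either in the X-API-Key header or as an api_key query parameter. A rough sketch of an authenticated call is shown below; it assumes the server is running locally on port 8080, `requests` is installed, and "changeme" and `building.jpg` are placeholders rather than real values.

    # Illustrative only: "changeme" stands in for whatever API_KEY the server
    # was started with; either the header or the query parameter is sufficient.
    import requests

    with open("building.jpg", "rb") as f:
        resp = requests.post(
            "http://localhost:8080/invocations",
            headers={"X-API-Key": "changeme"},
            # or: params={"api_key": "changeme"},
            files={"file": f},
        )
    print(resp.status_code)  # 401 = key missing, 403 = key invalid, 200 = OK
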
--- Dockerfile | 34 ++++++++++++++++++++++++++++++++ api.py | 58 +++++++++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 89 insertions(+), 3 deletions(-) create mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..fb52659 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,34 @@ +FROM pytorch/pytorch:2.1.2-cuda11.8-cudnn8-runtime + +ENV DEBIAN_FRONTEND=noninteractive + +# Install system dependencies +RUN apt-get update && apt-get install -y \ + build-essential \ + curl \ + software-properties-common \ + git \ + wget \ + && rm -rf /var/lib/apt/lists/* + +# Set working directory +WORKDIR /opt/program + +# Copy environment file +COPY environment.yml . + +# Create conda environment +RUN conda env create -f environment.yml && \ + conda clean -afy + +# Copy the model code +COPY . . + +# Set environment variables +ENV PYTHONPATH=/opt/program +ENV EXPERIMENT_PATH=/opt/ml/model +ENV OPENBLAS_NUM_THREADS=1 + +# Activate conda environment and set the entrypoint +SHELL ["/bin/bash", "-c"] +ENTRYPOINT ["/bin/bash", "-c", "source /opt/conda/etc/profile.d/conda.sh && conda activate pix2poly && python api.py"] \ No newline at end of file diff --git a/api.py b/api.py index 7cd33d0..186d53a 100644 --- a/api.py +++ b/api.py @@ -1,15 +1,60 @@ import os import json -from fastapi import FastAPI, UploadFile, HTTPException, Request +from fastapi import FastAPI, UploadFile, HTTPException, Request, Depends, Query from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware +from fastapi.security import APIKeyHeader import uvicorn from contextlib import asynccontextmanager import base64 +from typing import Optional from polygon_inference import PolygonInference from utils import log +# API Key configuration +API_KEY_NAME = "X-API-Key" +api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) + +# Get API key from environment variable +API_KEY = os.getenv("API_KEY", "") + +async def verify_api_key( + header_key: Optional[str] = Depends(api_key_header), + query_key: Optional[str] = Query(None, alias="api_key") +) -> Optional[str]: + """Verify the API key from either header or query parameter. + + If API authentication is not enabled (no API key configured), + this function will always return None. + + Args: + header_key: API key from X-API-Key header + query_key: API key from api_key query parameter + + Returns: + The verified API key or None if authentication is disabled + + Raises: + HTTPException: 401 if API key is missing (when required) + HTTPException: 403 if API key is invalid (when required) + """ + if not API_KEY: + return None + + api_key = header_key or query_key + if not api_key: + raise HTTPException( + status_code=401, + detail="API key is missing" + ) + if api_key != API_KEY: + raise HTTPException( + status_code=403, + detail="Invalid API key" + ) + return api_key + @asynccontextmanager async def lifespan(app: FastAPI): """Initialize the predictor on startup.""" @@ -45,7 +90,7 @@ def init_predictor(experiment_path: str): predictor = PolygonInference(experiment_path) @app.post("/invocations") -async def invoke(request: Request, file: UploadFile = None): +async def invoke(request: Request, file: UploadFile = None, api_key: Optional[str] = Depends(verify_api_key)): """Main inference endpoint for processing images. The endpoint accepts image data in three different ways: @@ -53,9 +98,14 @@ async def invoke(request: Request, file: UploadFile = None): 2. 
As a base64-encoded image in a JSON payload with an 'image' field 3. As raw image data in the request body + Authentication can be provided in two ways: + 1. Via the X-API-Key header + 2. Via the api_key query parameter + Args: request: The request containing the image data file: Optional uploaded file (multipart/form-data) + api_key: Optional API key for authentication (required only if API key is configured) Returns: JSON response containing the inferred polygons @@ -63,6 +113,8 @@ async def invoke(request: Request, file: UploadFile = None): Raises: HTTPException: 400 if no image data is found in the request HTTPException: 500 if there is an error processing the image + HTTPException: 401 if API key is missing (when API key is configured) + HTTPException: 403 if API key is invalid (when API key is configured) """ try: if file: @@ -99,7 +151,7 @@ async def invoke(request: Request, file: UploadFile = None): raise HTTPException(status_code=500, detail=str(e)) @app.get("/ping") -async def ping(): +async def ping(api_key: Optional[str] = Depends(verify_api_key)): """Health check endpoint to verify service status.""" if predictor is None: raise HTTPException(status_code=503, detail="Model not loaded") From 66196da65ba1a981ac41cfa8aa0a3415dcf8e689 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Wed, 18 Jun 2025 16:47:17 -0500 Subject: [PATCH 06/45] Automatically download and decompress model files. --- api.py | 56 ++++++++++++++++++++++++++++++++++++++++++++----- environment.yml | 1 + start_api.sh | 4 ++-- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/api.py b/api.py index 186d53a..fa590b6 100644 --- a/api.py +++ b/api.py @@ -8,6 +8,9 @@ from contextlib import asynccontextmanager import base64 from typing import Optional +import gdown +import shutil +from pathlib import Path from polygon_inference import PolygonInference from utils import log @@ -17,7 +20,8 @@ api_key_header = APIKeyHeader(name=API_KEY_NAME, auto_error=False) # Get API key from environment variable -API_KEY = os.getenv("API_KEY", "") +API_KEY = os.getenv("API_KEY") +EXPERIMENT_PATH = os.getenv("EXPERIMENT_PATH") async def verify_api_key( header_key: Optional[str] = Depends(api_key_header), @@ -55,13 +59,55 @@ async def verify_api_key( ) return api_key +def download_model_files(model_url: str, target_dir: str) -> str: + """Download model files from Google Drive to the target directory. 
+ + Args: + model_url: Google Drive URL to download the model files from + target_dir: Directory to save the model files to + + Returns: + Path to the downloaded model directory + + Raises: + ValueError: If download fails or model files are invalid + """ + try: + # Create target directory if it doesn't exist + target_path = Path(target_dir) + target_path.mkdir(parents=True, exist_ok=True) + + # Check if model files already exist + if target_path.exists() and any(target_path.iterdir()): + log(f"Model files already exist in {target_dir}, skipping download", "INFO") + return str(target_path) + + # Download the model files using gdown + zip_path = target_path / "runs_share.zip" + gdown.download(model_url, str(zip_path), quiet=False) + + # Extract the zip file + shutil.unpack_archive(zip_path, target_path) + + # Remove the zip file + zip_path.unlink() + + return str(target_path) + + except Exception as e: + raise ValueError(f"Failed to download model files: {str(e)}") + @asynccontextmanager async def lifespan(app: FastAPI): """Initialize the predictor on startup.""" - experiment_path = os.getenv("EXPERIMENT_PATH") - if not experiment_path: - raise ValueError("EXPERIMENT_PATH environment variable must be set") - init_predictor(experiment_path) + # Download model files to a temporary directory + model_dir = download_model_files( + "https://drive.google.com/uc?id=1oEs2n81nMAzdY4G9bdrji13pOKk6MOET", + "/tmp/pix2poly_model" + ) + + # Initialize predictor with downloaded model + init_predictor(model_dir + "/" + EXPERIMENT_PATH) yield app = FastAPI( diff --git a/environment.yml b/environment.yml index 3fadcb3..b66fab1 100644 --- a/environment.yml +++ b/environment.yml @@ -29,4 +29,5 @@ dependencies: - tqdm>=4.62.0 - pulumi>=3.0.0 - pulumi-aws>=5.0.0 + - gdown>=4.7.1 \ No newline at end of file diff --git a/start_api.sh b/start_api.sh index 8367df8..7e2d458 100755 --- a/start_api.sh +++ b/start_api.sh @@ -4,7 +4,7 @@ set -e # Set experiment path -export EXPERIMENT_PATH=runs/Pix2Poly_inria_coco_224 +export EXPERIMENT_PATH=runs_share/Pix2Poly_inria_coco_224 # Activate conda environment source $(conda info --base)/etc/profile.d/conda.sh @@ -12,4 +12,4 @@ conda activate pix2poly # Start the API server echo "Starting API server with experiment path: $EXPERIMENT_PATH" -uvicorn api:app --reload --port 8080 +uvicorn api:app --reload --port 8080 --workers 1 --limit-concurrency 1 From 54a2bdfb67f6c8455bafde794941049e26579b27 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Fri, 20 Jun 2025 14:01:26 -0500 Subject: [PATCH 07/45] Documentation, formatting, improved caching, download models from Github. --- Dockerfile | 6 +- README.md | 77 ++++++++++++ api.py | 229 ++++++++++++++++++++++-------------- environment.yml | 4 +- infer_single_image.py | 8 +- polygon_inference.py | 264 +++++++++++++++++++----------------------- start_api.sh | 7 +- utils.py | 3 +- 8 files changed, 350 insertions(+), 248 deletions(-) diff --git a/Dockerfile b/Dockerfile index fb52659..cc7897a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -26,9 +26,7 @@ COPY . . 
# Set environment variables ENV PYTHONPATH=/opt/program -ENV EXPERIMENT_PATH=/opt/ml/model ENV OPENBLAS_NUM_THREADS=1 -# Activate conda environment and set the entrypoint -SHELL ["/bin/bash", "-c"] -ENTRYPOINT ["/bin/bash", "-c", "source /opt/conda/etc/profile.d/conda.sh && conda activate pix2poly && python api.py"] \ No newline at end of file +# Use the startup script as entrypoint +ENTRYPOINT ["./start_api.sh"] \ No newline at end of file diff --git a/README.md b/README.md index c7bf052..48ba574 100644 --- a/README.md +++ b/README.md @@ -166,3 +166,80 @@ This repository benefits from the following open-source work. We thank the autho 3. [Frame Field Learning](https://github.com/Lydorn/Polygonization-by-Frame-Field-Learning) 4. [PolyWorld](https://github.com/zorzi-s/PolyWorldPretrainedNetwork) 5. [HiSup](https://github.com/SarahwXU/HiSup) + +## Docker Usage + +Pix2Poly provides a Docker setup for easy deployment and inference. The Docker container includes a FastAPI server for REST API inference and supports command-line inference. The API request and response format are suitable for running as a AWS Sagemaker inference endpoint running on a ml.g4dn.xlarge and the inference AMI version al2-ami-sagemaker-inference-gpu-3-1, where it is able to infer at a rate of 5-10 seconds per tile. + +### Building the Docker Image + +```bash +docker build -t pix2poly . +``` + +### Running the API Server + +The Docker container automatically starts a FastAPI server on port 8080. You can run it with: + +```bash +docker run -p 8080:8080 pix2poly +``` + +The API server will automatically download the pretrained model files on first startup and provide the following endpoints: + +- `POST /invocations` - Main inference endpoint for processing images +- `GET /ping` - Health check endpoint + +#### API Usage + +The `/invocations` endpoint accepts images in multiple formats: + +1. **File Upload (multipart/form-data):** +```bash +curl -X POST "http://localhost:8080/invocations" \ + -H "Content-Type: multipart/form-data" \ + -F "file=@your_image.jpg" +``` + +2. **Base64 Encoded Image:** +```bash +curl -X POST "http://localhost:8080/invocations" \ + -H "Content-Type: application/json" \ + -d '{"image": "base64_encoded_image_data"}' +``` + +3. **Raw Image Data:** +```bash +curl -X POST "http://localhost:8080/invocations" \ + -H "Content-Type: image/jpeg" \ + --data-binary @your_image.jpg +``` + +The API returns JSON with the detected polygons: + +```text +{ + "polygons": [ + [[x1, y1], [x2, y2], ...], + ... 
+ ] +} +``` + +### Environment Variables + +You can customize the Docker container behavior with these environment variables: + +- `MODEL_URL`: URL to download the pretrained model files (default: `https://github.com/safelease/Pix2Poly/releases/download/main/runs_share.zip`) +- `EXPERIMENT_PATH`: Path to the experiment folder (default: `runs_share/Pix2Poly_inria_coco_224`) +- `API_KEY`: Optional API key for authentication (if not set, authentication is disabled) + +Example with custom configuration: +```bash +docker run -p 8080:8080 \ + -e MODEL_URL=https://github.com/safelease/Pix2Poly/releases/download/main/runs_share.zip \ + -e EXPERIMENT_PATH=runs_share/Pix2Poly_inria_coco_224 \ + -e API_KEY=your_secret_key \ + pix2poly +``` + diff --git a/api.py b/api.py index fa590b6..d9f97ba 100644 --- a/api.py +++ b/api.py @@ -1,5 +1,7 @@ import os import json +import hashlib +import tempfile from fastapi import FastAPI, UploadFile, HTTPException, Request, Depends, Query from fastapi.responses import JSONResponse from fastapi.middleware.cors import CORSMiddleware @@ -8,9 +10,10 @@ from contextlib import asynccontextmanager import base64 from typing import Optional -import gdown +import requests import shutil from pathlib import Path +from diskcache import Cache from polygon_inference import PolygonInference from utils import log @@ -21,100 +24,125 @@ # Get API key from environment variable API_KEY = os.getenv("API_KEY") -EXPERIMENT_PATH = os.getenv("EXPERIMENT_PATH") +EXPERIMENT_PATH = os.getenv("EXPERIMENT_PATH", "runs_share/Pix2Poly_inria_coco_224") +MODEL_URL = os.getenv("MODEL_URL", "https://github.com/safelease/Pix2Poly/releases/download/main/runs_share.zip") + +# Cache configuration +CACHE_TTL = int(os.getenv("CACHE_TTL", 24 * 3600)) # 24 hours + +# Global cache instance +cache = Cache( + directory=os.path.join(tempfile.gettempdir(), "pix2poly_cache"), + timeout=1, # 1 second timeout for cache operations + disk_min_file_size=0, # Store all items on disk + disk_pickle_protocol=4, # Use protocol 4 for better compatibility +) + +def get_cache_key(image_data: bytes) -> str: + """Generate a cache key from image data. + + Args: + image_data: Raw image data + + Returns: + SHA-256 hash of the image data as a string + """ + return hashlib.sha256(image_data).hexdigest() + async def verify_api_key( header_key: Optional[str] = Depends(api_key_header), - query_key: Optional[str] = Query(None, alias="api_key") + query_key: Optional[str] = Query(None, alias="api_key"), ) -> Optional[str]: """Verify the API key from either header or query parameter. - + If API authentication is not enabled (no API key configured), this function will always return None. - + Args: header_key: API key from X-API-Key header query_key: API key from api_key query parameter - + Returns: The verified API key or None if authentication is disabled - + Raises: HTTPException: 401 if API key is missing (when required) HTTPException: 403 if API key is invalid (when required) """ if not API_KEY: return None - + api_key = header_key or query_key if not api_key: - raise HTTPException( - status_code=401, - detail="API key is missing" - ) + raise HTTPException(status_code=401, detail="API key is missing") if api_key != API_KEY: - raise HTTPException( - status_code=403, - detail="Invalid API key" - ) + raise HTTPException(status_code=403, detail="Invalid API key") return api_key + def download_model_files(model_url: str, target_dir: str) -> str: - """Download model files from Google Drive to the target directory. 
- + """Download model files to the target directory. + Args: - model_url: Google Drive URL to download the model files from + model_url: URL to download the model files from target_dir: Directory to save the model files to - + Returns: Path to the downloaded model directory - + Raises: ValueError: If download fails or model files are invalid """ - try: - # Create target directory if it doesn't exist - target_path = Path(target_dir) - target_path.mkdir(parents=True, exist_ok=True) - - # Check if model files already exist - if target_path.exists() and any(target_path.iterdir()): - log(f"Model files already exist in {target_dir}, skipping download", "INFO") - return str(target_path) - - # Download the model files using gdown - zip_path = target_path / "runs_share.zip" - gdown.download(model_url, str(zip_path), quiet=False) - - # Extract the zip file - shutil.unpack_archive(zip_path, target_path) - - # Remove the zip file - zip_path.unlink() - + # Create target directory if it doesn't exist + target_path = Path(target_dir) + target_path.mkdir(parents=True, exist_ok=True) + + # Check if model files already exist + if target_path.exists() and any(target_path.iterdir()): + log(f"Model files already exist in {target_dir}, skipping download", "INFO") return str(target_path) - - except Exception as e: - raise ValueError(f"Failed to download model files: {str(e)}") + + # Download the model files using requests + zip_path = target_path / "runs_share.zip" + + log(f"Downloading model files from {model_url}", "INFO") + response = requests.get(model_url, stream=True) + response.raise_for_status() + + with open(zip_path, 'wb') as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + # Extract the zip file + log(f"Extracting model files to {target_dir}", "INFO") + shutil.unpack_archive(zip_path, target_path) + + # Remove the zip file + zip_path.unlink() + + return str(target_path) + @asynccontextmanager async def lifespan(app: FastAPI): """Initialize the predictor on startup.""" # Download model files to a temporary directory model_dir = download_model_files( - "https://drive.google.com/uc?id=1oEs2n81nMAzdY4G9bdrji13pOKk6MOET", - "/tmp/pix2poly_model" + MODEL_URL, + "/tmp/pix2poly_model", ) - + # Initialize predictor with downloaded model - init_predictor(model_dir + "/" + EXPERIMENT_PATH) + init_predictor(os.path.join(model_dir, EXPERIMENT_PATH)) yield + app = FastAPI( title="Polygon Inference API", description="API for inferring polygons in images using a trained model", version="1.0.0", - lifespan=lifespan + lifespan=lifespan, ) # Add CORS middleware @@ -129,72 +157,87 @@ async def lifespan(app: FastAPI): # Global predictor instance predictor = None + def init_predictor(experiment_path: str): """Initialize the global predictor instance.""" global predictor if predictor is None: predictor = PolygonInference(experiment_path) + @app.post("/invocations") -async def invoke(request: Request, file: UploadFile = None, api_key: Optional[str] = Depends(verify_api_key)): +async def invoke( + request: Request, + file: UploadFile = None, + api_key: Optional[str] = Depends(verify_api_key), +): """Main inference endpoint for processing images. - + The endpoint accepts image data in three different ways: 1. As a file upload using multipart/form-data (via the file parameter) 2. As a base64-encoded image in a JSON payload with an 'image' field 3. As raw image data in the request body - + Authentication can be provided in two ways: 1. Via the X-API-Key header 2. 
Via the api_key query parameter - + Args: request: The request containing the image data file: Optional uploaded file (multipart/form-data) api_key: Optional API key for authentication (required only if API key is configured) - + Returns: JSON response containing the inferred polygons - + Raises: HTTPException: 400 if no image data is found in the request HTTPException: 500 if there is an error processing the image HTTPException: 401 if API key is missing (when API key is configured) HTTPException: 403 if API key is invalid (when API key is configured) """ - try: - if file: - # Handle file upload - image_data = await file.read() - else: - # Read request body - body = await request.body() - - # Parse the request body - try: - data = json.loads(body) - if 'image' in data: - # Handle base64 encoded image - image_data = base64.b64decode(data['image']) - else: - raise HTTPException(status_code=400, detail="No image data found in request") - except json.JSONDecodeError: - # Handle raw image data - image_data = body - - # Get inferences - polygons = predictor.infer(image_data) - - # Prepare response - response = { - "polygons": polygons, - } - - return JSONResponse(content=response) - - except Exception as e: - log(f"Error processing image: {str(e)}", "ERROR") - raise HTTPException(status_code=500, detail=str(e)) + log(f"Invoking image analysis") + + if file: + # Handle file upload + image_data = await file.read() + else: + # Read request body + body = await request.body() + + # Parse the request body + try: + data = json.loads(body) + if "image" in data: + # Handle base64 encoded image + image_data = base64.b64decode(data["image"]) + else: + raise HTTPException( + status_code=400, detail="No image data found in request" + ) + except json.JSONDecodeError: + # Handle raw image data + image_data = body + + # Generate cache key and check cache + cache_key = get_cache_key(image_data) + cached_result = cache.get(cache_key) + + if cached_result is not None: + return JSONResponse(content=cached_result) + + # Get inferences + polygons = predictor.infer(image_data) + + # Prepare response + response = { + "polygons": polygons, + } + + # Store result in cache + cache.set(cache_key, response, expire=CACHE_TTL) + + return JSONResponse(content=response) @app.get("/ping") async def ping(api_key: Optional[str] = Depends(verify_api_key)): @@ -203,6 +246,14 @@ async def ping(api_key: Optional[str] = Depends(verify_api_key)): raise HTTPException(status_code=503, detail="Model not loaded") return {"status": "healthy"} + +@app.get("/clear-cache") +async def clear_cache(api_key: Optional[str] = Depends(verify_api_key)): + """Clear the cache endpoint to remove all cached results.""" + cache.clear() + return {"status": "success", "message": "Cache cleared successfully"} + + if __name__ == "__main__": # Run the API - uvicorn.run(app, host="0.0.0.0", port=8080) \ No newline at end of file + uvicorn.run(app, host="0.0.0.0", port=8080) diff --git a/environment.yml b/environment.yml index b66fab1..28ba547 100644 --- a/environment.yml +++ b/environment.yml @@ -27,7 +27,5 @@ dependencies: - uvicorn>=0.15.0 - python-multipart>=0.0.5 - tqdm>=4.62.0 - - pulumi>=3.0.0 - - pulumi-aws>=5.0.0 - - gdown>=4.7.1 + - diskcache>=5.6.0 \ No newline at end of file diff --git a/infer_single_image.py b/infer_single_image.py index d7efb48..c3e98ca 100644 --- a/infer_single_image.py +++ b/infer_single_image.py @@ -45,8 +45,12 @@ def main(): for contour in formatted_contours: for point in contour: x, y = point[0] - # Draw 2x2 red square - 
vis_image_merged[y-1:y+1, x-1:x+1] = [255, 0, 0] + # Draw 2x2 red square. + y_min = max(0, y-1) + y_max = min(height, y+1) + x_min = max(0, x-1) + x_max = min(width, x+1) + vis_image_merged[y_min:y_max, x_min:x_max] = [255, 0, 0] plt.imshow(vis_image_merged) plt.axis('off') diff --git a/polygon_inference.py b/polygon_inference.py index 433b99a..fab738f 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -1,8 +1,6 @@ # Standard library imports import os -import tempfile -import hashlib -from pathlib import Path +import time from typing import List, Tuple, Dict, Optional, Any # Third-party imports @@ -25,16 +23,13 @@ log, calculate_slice_bboxes, ) -from models.model import ( - Encoder, - Decoder, - EncoderDecoder -) +from models.model import Encoder, Decoder, EncoderDecoder + class PolygonInference: def __init__(self, experiment_path: str, device: Optional[str] = None) -> None: """Initialize the polygon inference with a trained model. - + Args: experiment_path (str): Path to the experiment folder containing the model checkpoint device (str | None, optional): Device to run the model on. Defaults to CFG.DEVICE @@ -44,14 +39,10 @@ def __init__(self, experiment_path: str, device: Optional[str] = None) -> None: self.model: Optional[EncoderDecoder] = None self.tokenizer: Optional[Tokenizer] = None self._initialize_model() - - # Create persistent temporary directory for caching - self.cache_dir: Path = Path(tempfile.gettempdir()) / "pix2poly_cache" - self.cache_dir.mkdir(parents=True, exist_ok=True) - + def _initialize_model(self) -> None: """Initialize the model and tokenizer. - + This method: 1. Creates a new tokenizer instance 2. Initializes the encoder-decoder model @@ -62,7 +53,7 @@ def _initialize_model(self) -> None: num_bins=CFG.NUM_BINS, width=CFG.INPUT_WIDTH, height=CFG.INPUT_HEIGHT, - max_len=CFG.MAX_LEN + max_len=CFG.MAX_LEN, ) CFG.PAD_IDX = self.tokenizer.PAD_code @@ -73,188 +64,179 @@ def _initialize_model(self) -> None: encoder_len=CFG.NUM_PATCHES, dim=256, num_heads=8, - num_layers=6 + num_layers=6, ) self.model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder) self.model.to(self.device) self.model.eval() - + # Load latest checkpoint - latest_checkpoint = self._find_latest_checkpoint() - checkpoint_path = os.path.join(self.experiment_path, "logs", "checkpoints", latest_checkpoint) + latest_checkpoint = self._find_single_checkpoint() + checkpoint_path = os.path.join( + self.experiment_path, "logs", "checkpoints", latest_checkpoint + ) log(f"Loading checkpoint from {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu')) - self.model.load_state_dict(checkpoint['state_dict']) + checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) + self.model.load_state_dict(checkpoint["state_dict"]) log("Checkpoint loaded successfully") - def _find_latest_checkpoint(self) -> str: - """Find the checkpoint with the highest epoch number. - + def _find_single_checkpoint(self) -> str: + """Find the single checkpoint file. Crashes if there is more than one checkpoint. 
+ Returns: - str: Filename of the latest checkpoint + str: Filename of the single checkpoint + + Raises: + FileNotFoundError: If no checkpoint directory or files are found + RuntimeError: If more than one checkpoint file is found """ checkpoint_dir = os.path.join(self.experiment_path, "logs", "checkpoints") - checkpoint_files = [f for f in os.listdir(checkpoint_dir) if f.startswith('epoch_') and f.endswith('.pth')] - latest_checkpoint = sorted(checkpoint_files)[-1] - return latest_checkpoint + if not os.path.exists(checkpoint_dir): + raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}") - def _get_tile_hash(self, tile: np.ndarray) -> str: - """Generate a hash for the input tile. - - Args: - tile (np.ndarray): Input image tile - - Returns: - str: MD5 hash of the tile - """ - return hashlib.md5(tile.tobytes()).hexdigest() + checkpoint_files = [ + f + for f in os.listdir(checkpoint_dir) + if f.startswith("epoch_") and f.endswith(".pth") + ] + if not checkpoint_files: + raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}") - def _load_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]: - """Load cached result if it exists. - - Args: - cache_key (str): Key to look up in the cache - - Returns: - dict | None: Cached result if found, None otherwise - """ - cache_path = self.cache_dir / f"{cache_key}.npy" - if cache_path.exists(): - return np.load(cache_path, allow_pickle=True).item() - return None - - def _save_to_cache(self, cache_key: str, result: Dict[str, Any]) -> None: - """Save result to cache. - - Args: - cache_key (str): Key to store in the cache - result (dict): Result to cache - """ - cache_path = self.cache_dir / f"{cache_key}.npy" - np.save(cache_path, result, allow_pickle=True) + if len(checkpoint_files) > 1: + raise RuntimeError( + f"Multiple checkpoint files found in {checkpoint_dir}: {checkpoint_files}. Expected exactly one checkpoint." + ) + + return checkpoint_files[0] - def _process_tiles_batch(self, tiles: List[np.ndarray]) -> List[Dict[str, List[np.ndarray]]]: + def _process_tiles_batch( + self, tiles: List[np.ndarray] + ) -> List[Dict[str, List[np.ndarray]]]: """Process a single batch of tiles. 
- + Args: tiles (list[np.ndarray]): List of tile images to process - + Returns: list[dict]: List of results for each tile, where each result contains: - polygons: List of polygon coordinates """ - valid_transforms = A.Compose([ - A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), - A.Normalize( - mean=[0.0, 0.0, 0.0], - std=[1.0, 1.0, 1.0], - max_pixel_value=255.0 - ), - ToTensorV2(), - ]) - + valid_transforms = A.Compose( + [ + A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), + A.Normalize( + mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0], max_pixel_value=255.0 + ), + ToTensorV2(), + ] + ) + # Transform each tile individually and stack them transformed_tiles: List[torch.Tensor] = [] for tile in tiles: transformed = valid_transforms(image=tile) - transformed_tiles.append(transformed['image']) - + transformed_tiles.append(transformed["image"]) + # Stack the transformed tiles into a batch batch_tensor = torch.stack(transformed_tiles).to(self.device) - + with torch.no_grad(): batch_preds, batch_confs, perm_preds = test_generate( - self.model, batch_tensor, self.tokenizer, + self.model, + batch_tensor, + self.tokenizer, max_len=CFG.generation_steps, top_k=0, - top_p=1 + top_p=1, ) - + vertex_coords, confs = postprocess(batch_preds, batch_confs, self.tokenizer) - + results: List[Dict[str, List[np.ndarray]]] = [] for j in range(len(tiles)): if vertex_coords[j] is not None: coord = torch.from_numpy(vertex_coords[j]) else: coord = torch.tensor([]) - + padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX) coord = torch.cat([coord, padd], dim=0) - - batch_polygons = permutations_to_polygons(perm_preds[j:j+1], [coord], out='torch') - + + batch_polygons = permutations_to_polygons( + perm_preds[j : j + 1], [coord], out="torch" + ) + valid_polygons: List[np.ndarray] = [] for poly in batch_polygons[0]: poly = poly[poly[:, 0] != CFG.PAD_IDX] if len(poly) > 0: - valid_polygons.append(poly.cpu().numpy()[:, ::-1]) # Convert to [x,y] format - - result = { - 'polygons': valid_polygons - } - - # Cache the result - cache_key = self._get_tile_hash(tiles[j]) - self._save_to_cache(cache_key, result) - + valid_polygons.append( + poly.cpu().numpy()[:, ::-1] + ) # Convert to [x,y] format + + result = {"polygons": valid_polygons} + results.append(result) - + return results - def _merge_polygons(self, tile_results: List[Dict[str, List[np.ndarray]]], positions: List[Tuple[int, int, int, int]]) -> List[np.ndarray]: + def _merge_polygons( + self, + tile_results: List[Dict[str, List[np.ndarray]]], + positions: List[Tuple[int, int, int, int]], + ) -> List[np.ndarray]: """Merge polygon predictions from multiple tiles into a single set of polygons. 
- + Args: tile_results (list[dict]): List of dictionaries containing 'polygons' for each tile positions (list[tuple[int, int, int, int]]): List of (x, y, x_end, y_end) tuples for each tile's position - + Returns: list[np.ndarray]: List of merged polygons in original image coordinates """ all_polygons: List[shapely.geometry.Polygon] = [] - + # Transform each polygon from tile coordinates to original image coordinates for tile_result, (x, y, x_end, y_end) in zip(tile_results, positions): - tile_polygons = tile_result['polygons'] - height, width = CFG.INPUT_HEIGHT, CFG.INPUT_WIDTH - + tile_polygons = tile_result["polygons"] + # Transform each polygon from tile coordinates to original image coordinates for poly in tile_polygons: - # Scale coordinates to match tile size - scaled_poly = poly * np.array([width/CFG.INPUT_WIDTH, height/CFG.INPUT_HEIGHT]) - - # Translate to original image position - translated_poly = scaled_poly + np.array([x, y]) - + # Translate to original image position (no scaling needed since TILE_SIZE = INPUT_WIDTH) + translated_poly = poly + np.array([x, y]) + # Convert to shapely polygon shapely_poly = shapely.geometry.Polygon(translated_poly) if shapely_poly.is_valid and shapely_poly.area > CFG.MIN_POLYGON_AREA: all_polygons.append(shapely_poly) - + # Use shapely's unary_union to merge overlapping polygons merged = shapely.ops.unary_union(all_polygons) - + simplified_polygons: List[np.ndarray] = [] - # Merging tiles may have created redundant points in the middle of a line. We can remove them. + # Convert single polygon to MultiPolygon for consistent processing + if isinstance(merged, shapely.geometry.Polygon): + merged = shapely.geometry.MultiPolygon([merged]) + + # Process all polygons (now guaranteed to be a MultiPolygon) for poly in merged.geoms: simplified = poly.simplify(CFG.POLYGON_SIMPLIFICATION_TOLERANCE) if simplified.is_valid and not simplified.is_empty: simplified_polygons.append(np.array(simplified.exterior.coords)) - + return simplified_polygons def infer(self, image_data: bytes) -> List[List[List[float]]]: """Infer polygons in an image. - + Args: image_data (bytes): Raw image data - + Returns: list[list[list[float]]]: List of polygons where each polygon is a list of [x,y] coordinates. Each coordinate is rounded to 2 decimal places. 
- + Raises: ValueError: If the image data is invalid, empty, or cannot be decoded RuntimeError: If there are issues with model prediction or polygon processing @@ -269,62 +251,58 @@ def infer(self, image_data: bytes) -> List[List[List[float]]]: image = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if image is None: raise ValueError("Failed to decode image data") - + if image.size == 0: raise ValueError("Decoded image is empty") # Convert to RGB image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - + # Split image into tiles height, width = image.shape[:2] if height == 0 or width == 0: raise ValueError("Invalid image dimensions") - + overlap_ratio = CFG.TILE_OVERLAP / CFG.TILE_SIZE - + bboxes = calculate_slice_bboxes( image_height=height, image_width=width, slice_height=CFG.TILE_SIZE, slice_width=CFG.TILE_SIZE, overlap_height_ratio=overlap_ratio, - overlap_width_ratio=overlap_ratio + overlap_width_ratio=overlap_ratio, ) - + tiles: List[np.ndarray] = [] - tiles_to_process: List[np.ndarray] = [] - + for bbox in bboxes: x1, y1, x2, y2 = bbox tile = image[y1:y2, x1:x2] if tile.size == 0: continue tiles.append(tile) - + # Process tiles in batches all_results: List[Dict[str, List[np.ndarray]]] = [] - - # First check cache for all tiles - for tile in tiles: - cache_key = self._get_tile_hash(tile) - cached_result = self._load_from_cache(cache_key) - if cached_result is not None: - all_results.append(cached_result) - else: - tiles_to_process.append(tile) - - # Process remaining tiles in batches - for i in range(0, len(tiles_to_process), CFG.PREDICTION_BATCH_SIZE): - batch_tiles = tiles_to_process[i:i + CFG.PREDICTION_BATCH_SIZE] + + for i in range(0, len(tiles), CFG.PREDICTION_BATCH_SIZE): + batch_start_time = time.time() + batch_tiles = tiles[i : i + CFG.PREDICTION_BATCH_SIZE] batch_results = self._process_tiles_batch(batch_tiles) all_results.extend(batch_results) + + batch_time = time.time() - batch_start_time + log(f"Processed batch of {len(batch_tiles)} tiles: {batch_time/len(batch_tiles):.3f}s per tile") merged_polygons = self._merge_polygons(all_results, bboxes) - + # Convert to list format polygons_list = [poly.tolist() for poly in merged_polygons] # Round coordinates to two decimal places - polygons_list = [[[round(x, 2), round(y, 2)] for x, y in polygon] for polygon in polygons_list] - - return polygons_list \ No newline at end of file + polygons_list = [ + [[round(x, 2), round(y, 2)] for x, y in polygon] + for polygon in polygons_list + ] + + return polygons_list diff --git a/start_api.sh b/start_api.sh index 7e2d458..7b43bac 100755 --- a/start_api.sh +++ b/start_api.sh @@ -3,13 +3,10 @@ # Exit on error set -e -# Set experiment path -export EXPERIMENT_PATH=runs_share/Pix2Poly_inria_coco_224 - # Activate conda environment source $(conda info --base)/etc/profile.d/conda.sh conda activate pix2poly # Start the API server -echo "Starting API server with experiment path: $EXPERIMENT_PATH" -uvicorn api:app --reload --port 8080 --workers 1 --limit-concurrency 1 +echo "Starting API server" +uvicorn api:app --port 8080 --workers 1 --backlog 10 diff --git a/utils.py b/utils.py index 37c70d0..49f6e2f 100644 --- a/utils.py +++ b/utils.py @@ -56,11 +56,10 @@ def create_mask(tgt, pad_idx): tgt_seq_len = tgt.size(1) tgt_mask = generate_square_subsequent_mask(tgt_seq_len) - tgt_padding_mask = (tgt == pad_idx).float().masked_fill(tgt == pad_idx, float('-inf')) + tgt_padding_mask = (tgt == pad_idx) return tgt_mask, tgt_padding_mask - class AverageMeter: def __init__(self, name="Metric"): self.name = name From 
05839c0389779caded77acc4248f4cc60144e695 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Wed, 2 Jul 2025 11:19:14 -0500 Subject: [PATCH 08/45] Accuracy improvements: switch to a bitmap-based polygon merging algorithm, boost overlap to 50%, add caching for development. --- polygon_inference.py | 296 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 260 insertions(+), 36 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index fab738f..637edbb 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -1,6 +1,8 @@ # Standard library imports import os import time +import hashlib +import pickle from typing import List, Tuple, Dict, Optional, Any # Third-party imports @@ -9,8 +11,10 @@ import torch import albumentations as A from albumentations.pytorch import ToTensorV2 -import shapely.geometry -import shapely.ops + +import matplotlib.pyplot as plt +import matplotlib.patches as patches +import math # Local imports from config import CFG @@ -38,8 +42,77 @@ def __init__(self, experiment_path: str, device: Optional[str] = None) -> None: self.experiment_path: str = os.path.realpath(experiment_path) self.model: Optional[EncoderDecoder] = None self.tokenizer: Optional[Tokenizer] = None + self.cache_dir: str = "/tmp/pix2poly_cache" + self._ensure_cache_dir() self._initialize_model() + def _ensure_cache_dir(self) -> None: + """Ensure the cache directory exists.""" + os.makedirs(self.cache_dir, exist_ok=True) + + def _generate_cache_key(self, tiles: List[np.ndarray]) -> str: + """Generate a cache key based on the input tiles. + + Args: + tiles (List[np.ndarray]): List of tile images + + Returns: + str: Hash-based cache key + """ + # Create a hash based on all tile data + hasher = hashlib.sha256() + for tile in tiles: + hasher.update(tile.tobytes()) + return hasher.hexdigest() + + def _get_cache_path(self, cache_key: str) -> str: + """Get the full path for a cache file. + + Args: + cache_key (str): The cache key + + Returns: + str: Full path to the cache file + """ + return os.path.join(self.cache_dir, f"{cache_key}.pkl") + + def _load_from_cache(self, cache_key: str) -> Optional[List[Dict[str, List[np.ndarray]]]]: + """Load results from cache if they exist. + + Args: + cache_key (str): The cache key to look for + + Returns: + Optional[List[Dict[str, List[np.ndarray]]]]: Cached results if found, None otherwise + """ + cache_path = self._get_cache_path(cache_key) + if os.path.exists(cache_path): + try: + with open(cache_path, 'rb') as f: + return pickle.load(f) + except Exception as e: + log(f"Failed to load cache from {cache_path}: {e}") + # Remove corrupted cache file + try: + os.remove(cache_path) + except: + pass + return None + + def _save_to_cache(self, cache_key: str, results: List[Dict[str, List[np.ndarray]]]) -> None: + """Save results to cache. + + Args: + cache_key (str): The cache key + results (List[Dict[str, List[np.ndarray]]]): Results to cache + """ + cache_path = self._get_cache_path(cache_key) + try: + with open(cache_path, 'wb') as f: + pickle.dump(results, f) + except Exception as e: + log(f"Failed to save cache to {cache_path}: {e}") + def _initialize_model(self) -> None: """Initialize the model and tokenizer. 
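# Illustrative sketch (not part of the diff): the development cache introduced in this
# patch, reduced to three standalone helpers. The SHA-256-over-tile-bytes key and the
# pickle-on-disk layout follow the code above; the helper names and the module-level
# constant are assumptions for illustration only.
import hashlib
import os
import pickle
from typing import Any, List, Optional

import numpy as np

CACHE_DIR = "/tmp/pix2poly_cache"  # same location the patch uses


def tiles_cache_key(tiles: List[np.ndarray]) -> str:
    # Hash the raw bytes of every tile in order, so the key changes whenever
    # any pixel (or the batch composition/order) changes.
    hasher = hashlib.sha256()
    for tile in tiles:
        hasher.update(tile.tobytes())
    return hasher.hexdigest()


def cache_load(key: str) -> Optional[Any]:
    # Return the cached object for this key if present, otherwise None.
    path = os.path.join(CACHE_DIR, f"{key}.pkl")
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        return pickle.load(f)


def cache_save(key: str, value: Any) -> None:
    # Persist batch results so repeated runs on the same tiles skip the model.
    os.makedirs(CACHE_DIR, exist_ok=True)
    with open(os.path.join(CACHE_DIR, f"{key}.pkl"), "wb") as f:
        pickle.dump(value, f)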
@@ -121,6 +194,15 @@ def _process_tiles_batch( list[dict]: List of results for each tile, where each result contains: - polygons: List of polygon coordinates """ + # Generate cache key and try to load from cache + cache_key = self._generate_cache_key(tiles) + cached_results = self._load_from_cache(cache_key) + if cached_results is not None: + log(f"Cache hit for batch of {len(tiles)} tiles") + return cached_results + + log(f"Cache miss for batch of {len(tiles)} tiles, processing...") + valid_transforms = A.Compose( [ A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), @@ -178,54 +260,193 @@ def _process_tiles_batch( results.append(result) + # Save results to cache + self._save_to_cache(cache_key, results) + return results + def _create_tile_visualization( + self, + tiles: List[np.ndarray], + tile_results: List[Dict[str, List[np.ndarray]]], + positions: List[Tuple[int, int, int, int]], + ) -> None: + """Create a tile visualization showing each tile with its detected polygons. + + Args: + tiles (List[np.ndarray]): List of tile images + tile_results (List[Dict[str, List[np.ndarray]]]): List of results for each tile + positions (List[Tuple[int, int, int, int]]): List of (x, y, x_end, y_end) tuples for each tile's position + """ + if not tiles: + return + + # Calculate grid dimensions + num_tiles = len(tiles) + cols = math.ceil(math.sqrt(num_tiles)) + rows = math.ceil(num_tiles / cols) + + # Create figure + fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4)) + if rows == 1 and cols == 1: + axes = [axes] + elif rows == 1 or cols == 1: + axes = axes.flatten() + else: + axes = axes.flatten() + + for i in range(num_tiles): + ax = axes[i] + tile = tiles[i] + tile_result = tile_results[i] + + # Convert RGB to display format + display_tile = cv2.cvtColor(tile, cv2.COLOR_RGB2BGR) if tile.shape[-1] == 3 else tile + display_tile = cv2.cvtColor(display_tile, cv2.COLOR_BGR2RGB) + + ax.imshow(display_tile) + ax.set_title(f'Tile {i+1}') + ax.axis('off') + + # Draw polygons on this tile + for poly in tile_result["polygons"]: + if len(poly) > 2: + # Close the polygon for visualization + poly_closed = np.vstack([poly, poly[0]]) + ax.plot(poly_closed[:, 0], poly_closed[:, 1], 'g-', linewidth=2) + + # Draw vertices + ax.scatter(poly[:, 0], poly[:, 1], c='red', s=20, zorder=5) + + # Hide unused subplots + for i in range(num_tiles, len(axes)): + axes[i].axis('off') + + plt.tight_layout() + plt.savefig('tile-visualization.png', dpi=150, bbox_inches='tight') + plt.close() + log(f"Saved tile visualization with {num_tiles} tiles to tile-visualization.png") + def _merge_polygons( self, tile_results: List[Dict[str, List[np.ndarray]]], positions: List[Tuple[int, int, int, int]], + image_height: int, + image_width: int, ) -> List[np.ndarray]: - """Merge polygon predictions from multiple tiles into a single set of polygons. + """Merge polygon predictions from multiple tiles using a bitmap approach. + + This method creates a bitmap where pixels inside any polygon are set to True, + then vectorizes the bitmap back to polygons. This eliminates geometric artifacts + from traditional polygon union operations. 
Args: tile_results (list[dict]): List of dictionaries containing 'polygons' for each tile positions (list[tuple[int, int, int, int]]): List of (x, y, x_end, y_end) tuples for each tile's position + image_height (int): Height of the original image + image_width (int): Width of the original image Returns: list[np.ndarray]: List of merged polygons in original image coordinates """ - all_polygons: List[shapely.geometry.Polygon] = [] - - # Transform each polygon from tile coordinates to original image coordinates - for tile_result, (x, y, x_end, y_end) in zip(tile_results, positions): + # Scale factor for subpixel precision + scale_factor = 16 + + # Create bitmap at 8x resolution for subpixel precision + bitmap = np.zeros((image_height * scale_factor, image_width * scale_factor), dtype=np.uint8) + + # Fill bitmap with polygon regions + for tile_idx, (tile_result, (x, y, x_end, y_end)) in enumerate(zip(tile_results, positions)): tile_polygons = tile_result["polygons"] - - # Transform each polygon from tile coordinates to original image coordinates - for poly in tile_polygons: - # Translate to original image position (no scaling needed since TILE_SIZE = INPUT_WIDTH) - translated_poly = poly + np.array([x, y]) - - # Convert to shapely polygon - shapely_poly = shapely.geometry.Polygon(translated_poly) - if shapely_poly.is_valid and shapely_poly.area > CFG.MIN_POLYGON_AREA: - all_polygons.append(shapely_poly) - - # Use shapely's unary_union to merge overlapping polygons - merged = shapely.ops.unary_union(all_polygons) - - simplified_polygons: List[np.ndarray] = [] - - # Convert single polygon to MultiPolygon for consistent processing - if isinstance(merged, shapely.geometry.Polygon): - merged = shapely.geometry.MultiPolygon([merged]) - - # Process all polygons (now guaranteed to be a MultiPolygon) - for poly in merged.geoms: - simplified = poly.simplify(CFG.POLYGON_SIMPLIFICATION_TOLERANCE) - if simplified.is_valid and not simplified.is_empty: - simplified_polygons.append(np.array(simplified.exterior.coords)) - - return simplified_polygons + + for poly_idx, poly in enumerate(tile_polygons): + if len(poly) < 3: # Skip invalid polygons + continue + + # Check if polygon is in a corner (should be removed) + tile_width = x_end - x + tile_height = y_end - y + edge_tolerance = 8.0 # Consider points within 8 pixels of edge as "on edge" + corner_tolerance = 2.0 # Consider points within 2 pixels of corner as "near corner" + + # Check which edges have points + on_left_edge = poly[:, 0] <= edge_tolerance + on_right_edge = poly[:, 0] >= tile_width - edge_tolerance + on_top_edge = poly[:, 1] <= edge_tolerance + on_bottom_edge = poly[:, 1] >= tile_height - edge_tolerance + + # Check if near any corner + near_top_left = (poly[:, 0] <= corner_tolerance) & (poly[:, 1] <= corner_tolerance) + near_top_right = (poly[:, 0] >= tile_width - corner_tolerance) & (poly[:, 1] <= corner_tolerance) + near_bottom_left = (poly[:, 0] <= corner_tolerance) & (poly[:, 1] >= tile_height - corner_tolerance) + near_bottom_right = (poly[:, 0] >= tile_width - corner_tolerance) & (poly[:, 1] >= tile_height - corner_tolerance) + + # Check for corner polygons (polygons that span two adjacent edges AND are near the corner) + is_corner_polygon = ( + (np.any(on_top_edge) and np.any(on_left_edge) and np.any(near_top_left)) or + (np.any(on_top_edge) and np.any(on_right_edge) and np.any(near_top_right)) or + (np.any(on_bottom_edge) and np.any(on_left_edge) and np.any(near_bottom_left)) or + (np.any(on_bottom_edge) and np.any(on_right_edge) and 
np.any(near_bottom_right)) + ) + + if is_corner_polygon: + continue + + # Transform polygon from tile coordinates to image coordinates + transformed_poly = poly + np.array([x, y]) + + # Scale up coordinates for high-resolution bitmap + scaled_poly = transformed_poly * scale_factor + + # Ensure coordinates are within scaled bitmap bounds + scaled_poly[:, 0] = np.clip(scaled_poly[:, 0], 0, image_width * scale_factor - 1) + scaled_poly[:, 1] = np.clip(scaled_poly[:, 1], 0, image_height * scale_factor - 1) + + # Convert to integer coordinates for rasterization + poly_coords = scaled_poly.astype(np.int32) + + # Fill the polygon region in the bitmap + cv2.fillPoly(bitmap, [poly_coords], 255) + + # Apply morphological closing to fill small gaps and smooth edges + # Scale kernel size proportionally to the scaled bitmap + kernel_size = max(1, min(3 * scale_factor, min(image_height, image_width) * scale_factor // 1000)) # Adaptive kernel size + if kernel_size > 1: + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + bitmap = cv2.morphologyEx(bitmap, cv2.MORPH_CLOSE, kernel) + + # Save bitmap for debugging (optional) + cv2.imwrite('debug_polygon_bitmap.png', bitmap) + log("Saved debug bitmap to debug_polygon_bitmap.png") + + # Find contours in the bitmap + contours, _ = cv2.findContours(bitmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + merged_polygons: List[np.ndarray] = [] + + for contour in contours: + # Skip very small contours (area is scaled by scale_factor^2) + area = cv2.contourArea(contour) + if area < CFG.MIN_POLYGON_AREA * (scale_factor ** 2): + continue + + # Simplify the contour to reduce jaggedness while preserving shape + perimeter = cv2.arcLength(contour, True) + epsilon = 0.01 * perimeter # 1% of perimeter + simplified_contour = cv2.approxPolyDP(contour, epsilon, True) + + # Convert from OpenCV format to our polygon format + if len(simplified_contour) >= 3: # Valid polygon needs at least 3 points + # Reshape from (n, 1, 2) to (n, 2) and convert to float + polygon_coords = simplified_contour.reshape(-1, 2).astype(np.float32) + + # Scale down coordinates back to original image coordinate system + polygon_coords = polygon_coords / scale_factor + + merged_polygons.append(polygon_coords) + + log(f"Bitmap approach: {len(merged_polygons)} polygons extracted from bitmap") + return merged_polygons def infer(self, image_data: bytes) -> List[List[List[float]]]: """Infer polygons in an image. @@ -263,7 +484,7 @@ def infer(self, image_data: bytes) -> List[List[List[float]]]: if height == 0 or width == 0: raise ValueError("Invalid image dimensions") - overlap_ratio = CFG.TILE_OVERLAP / CFG.TILE_SIZE + overlap_ratio = 0.5 bboxes = calculate_slice_bboxes( image_height=height, @@ -295,7 +516,10 @@ def infer(self, image_data: bytes) -> List[List[List[float]]]: batch_time = time.time() - batch_start_time log(f"Processed batch of {len(batch_tiles)} tiles: {batch_time/len(batch_tiles):.3f}s per tile") - merged_polygons = self._merge_polygons(all_results, bboxes) + # Create tile visualization + self._create_tile_visualization(tiles, all_results, bboxes) + + merged_polygons = self._merge_polygons(all_results, bboxes, height, width) # Convert to list format polygons_list = [poly.tolist() for poly in merged_polygons] From 767252093d04cc1f96266ef372aafc79dfeacf5d Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Wed, 2 Jul 2025 11:39:56 -0500 Subject: [PATCH 09/45] Improve ignoring heuristic, use a set kernel_size. 
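For clarity, a minimal sketch of the filtering rule this patch moves to: a tile-local polygon is kept only if at least one vertex falls inside the central 50% of the tile and the polygon does not touch two or more tile corners. The 20 px corner tolerance mirrors the diff below; the function name and signature are illustrative, not part of the change.

import numpy as np


def keep_tile_polygon(poly: np.ndarray, tile_w: float, tile_h: float,
                      corner_tol: float = 20.0) -> bool:
    """Heuristic filter in tile-local coordinates (poly is an (N, 2) array of [x, y])."""
    if len(poly) < 3:
        return False

    xs, ys = poly[:, 0], poly[:, 1]

    # Keep only polygons with at least one vertex in the central 25%-75% region.
    in_center = ((xs >= 0.25 * tile_w) & (xs <= 0.75 * tile_w) &
                 (ys >= 0.25 * tile_h) & (ys <= 0.75 * tile_h))
    if not np.any(in_center):
        return False

    # Reject polygons that occupy two or more tile corners.
    corners = [
        (xs <= corner_tol) & (ys <= corner_tol),                    # top-left
        (xs >= tile_w - corner_tol) & (ys <= corner_tol),           # top-right
        (xs <= corner_tol) & (ys >= tile_h - corner_tol),           # bottom-left
        (xs >= tile_w - corner_tol) & (ys >= tile_h - corner_tol),  # bottom-right
    ]
    occupied = sum(bool(np.any(mask)) for mask in corners)
    return occupied < 2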
--- polygon_inference.py | 62 ++++++++++++++++++++++++-------------------- 1 file changed, 34 insertions(+), 28 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 637edbb..7b8314f 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -363,34 +363,43 @@ def _merge_polygons( if len(poly) < 3: # Skip invalid polygons continue - # Check if polygon is in a corner (should be removed) + # Check if polygon has any part in the central 50% of the tile tile_width = x_end - x tile_height = y_end - y - edge_tolerance = 8.0 # Consider points within 8 pixels of edge as "on edge" - corner_tolerance = 2.0 # Consider points within 2 pixels of corner as "near corner" - # Check which edges have points - on_left_edge = poly[:, 0] <= edge_tolerance - on_right_edge = poly[:, 0] >= tile_width - edge_tolerance - on_top_edge = poly[:, 1] <= edge_tolerance - on_bottom_edge = poly[:, 1] >= tile_height - edge_tolerance + # Define central region boundaries (25% to 75% in each dimension) + central_x_min = tile_width * 0.25 + central_x_max = tile_width * 0.75 + central_y_min = tile_height * 0.25 + central_y_max = tile_height * 0.75 - # Check if near any corner - near_top_left = (poly[:, 0] <= corner_tolerance) & (poly[:, 1] <= corner_tolerance) - near_top_right = (poly[:, 0] >= tile_width - corner_tolerance) & (poly[:, 1] <= corner_tolerance) - near_bottom_left = (poly[:, 0] <= corner_tolerance) & (poly[:, 1] >= tile_height - corner_tolerance) - near_bottom_right = (poly[:, 0] >= tile_width - corner_tolerance) & (poly[:, 1] >= tile_height - corner_tolerance) - - # Check for corner polygons (polygons that span two adjacent edges AND are near the corner) - is_corner_polygon = ( - (np.any(on_top_edge) and np.any(on_left_edge) and np.any(near_top_left)) or - (np.any(on_top_edge) and np.any(on_right_edge) and np.any(near_top_right)) or - (np.any(on_bottom_edge) and np.any(on_left_edge) and np.any(near_bottom_left)) or - (np.any(on_bottom_edge) and np.any(on_right_edge) and np.any(near_bottom_right)) + # Check if any vertex of the polygon is in the central region + in_central_region = ( + (poly[:, 0] >= central_x_min) & (poly[:, 0] <= central_x_max) & + (poly[:, 1] >= central_y_min) & (poly[:, 1] <= central_y_max) ) - if is_corner_polygon: - continue + # Skip polygon if no part of it exists in the central region + if not np.any(in_central_region): + continue + + # Additional rule: Skip if polygon occupies two or more corners + corner_tolerance = 20.0 # Distance from corner to be considered "in corner" + + # Define corner regions + corners = [ + (poly[:, 0] <= corner_tolerance) & (poly[:, 1] <= corner_tolerance), # top_left + (poly[:, 0] >= tile_width - corner_tolerance) & (poly[:, 1] <= corner_tolerance), # top_right + (poly[:, 0] <= corner_tolerance) & (poly[:, 1] >= tile_height - corner_tolerance), # bottom_left + (poly[:, 0] >= tile_width - corner_tolerance) & (poly[:, 1] >= tile_height - corner_tolerance) # bottom_right + ] + + # Count how many corners have any points + occupied_corners = sum(1 for corner_mask in corners if np.any(corner_mask)) + + # Skip polygon if it occupies two or more corners + if occupied_corners >= 2: + continue # Transform polygon from tile coordinates to image coordinates transformed_poly = poly + np.array([x, y]) @@ -408,12 +417,9 @@ def _merge_polygons( # Fill the polygon region in the bitmap cv2.fillPoly(bitmap, [poly_coords], 255) - # Apply morphological closing to fill small gaps and smooth edges - # Scale kernel size proportionally to the scaled bitmap - 
kernel_size = max(1, min(3 * scale_factor, min(image_height, image_width) * scale_factor // 1000)) # Adaptive kernel size - if kernel_size > 1: - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) - bitmap = cv2.morphologyEx(bitmap, cv2.MORPH_CLOSE, kernel) + kernel_size = 32 + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + bitmap = cv2.morphologyEx(bitmap, cv2.MORPH_CLOSE, kernel) # Save bitmap for debugging (optional) cv2.imwrite('debug_polygon_bitmap.png', bitmap) From c73bdaf695dd10dafb909935f7d394d5b7c560dc Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Wed, 2 Jul 2025 13:03:56 -0500 Subject: [PATCH 10/45] Improve heuristic to look at the entire polygon, add logging. --- polygon_inference.py | 28 +++++++++++++++++++++------- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 7b8314f..af1c745 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -373,14 +373,26 @@ def _merge_polygons( central_y_min = tile_height * 0.25 central_y_max = tile_height * 0.75 - # Check if any vertex of the polygon is in the central region - in_central_region = ( - (poly[:, 0] >= central_x_min) & (poly[:, 0] <= central_x_max) & - (poly[:, 1] >= central_y_min) & (poly[:, 1] <= central_y_max) - ) + # Check if any part of the polygon (including interior) intersects with central region + # Create a mask for the polygon + poly_mask = np.zeros((tile_height, tile_width), dtype=np.uint8) + poly_coords_int = poly.astype(np.int32) + cv2.fillPoly(poly_mask, [poly_coords_int], 255) + + # Create a mask for the central region + central_mask = np.zeros((tile_height, tile_width), dtype=np.uint8) + central_x_min_int = int(central_x_min) + central_y_min_int = int(central_y_min) + central_x_max_int = int(central_x_max) + central_y_max_int = int(central_y_max) + central_mask[central_y_min_int:central_y_max_int, central_x_min_int:central_x_max_int] = 255 - # Skip polygon if no part of it exists in the central region - if not np.any(in_central_region): + # Check if there's any intersection between polygon and central region + intersection = cv2.bitwise_and(poly_mask, central_mask) + has_intersection = np.any(intersection > 0) + + # Skip polygon if no part of it intersects with the central region + if not has_intersection: continue # Additional rule: Skip if polygon occupies two or more corners @@ -510,6 +522,8 @@ def infer(self, image_data: bytes) -> List[List[List[float]]]: continue tiles.append(tile) + log(f"Total number of tiles to process: {len(tiles)}") + # Process tiles in batches all_results: List[Dict[str, List[np.ndarray]]] = [] From e3b5bfed7d28c687c368dd9b2bb959264ce0df09 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Wed, 2 Jul 2025 15:22:29 -0500 Subject: [PATCH 11/45] More efficient intersection detection. 
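A hedged sketch of the check this patch switches to: instead of rasterizing the polygon and the central region into two masks and AND-ing them, a single Shapely predicate is evaluated. shapely.geometry.box is used here as shorthand for the explicit four-corner rectangle built in the diff; the helper name is illustrative.

import numpy as np
from shapely.geometry import Polygon, box


def touches_central_region(poly: np.ndarray, tile_w: float, tile_h: float) -> bool:
    # True if the detected polygon (an (N, 2) array of tile-local [x, y] vertices)
    # intersects the central 50% of the tile; no per-tile bitmaps are allocated.
    if len(poly) < 3:
        return False
    central = box(0.25 * tile_w, 0.25 * tile_h, 0.75 * tile_w, 0.75 * tile_h)
    return Polygon(poly).intersects(central)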
--- polygon_inference.py | 61 ++++++++++++++++++++++---------------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index af1c745..7bd5755 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -11,6 +11,7 @@ import torch import albumentations as A from albumentations.pytorch import ToTensorV2 +from shapely.geometry import Polygon import matplotlib.pyplot as plt import matplotlib.patches as patches @@ -300,11 +301,8 @@ def _create_tile_visualization( tile = tiles[i] tile_result = tile_results[i] - # Convert RGB to display format - display_tile = cv2.cvtColor(tile, cv2.COLOR_RGB2BGR) if tile.shape[-1] == 3 else tile - display_tile = cv2.cvtColor(display_tile, cv2.COLOR_BGR2RGB) - - ax.imshow(display_tile) + # Tiles are already in RGB format, no conversion needed for matplotlib + ax.imshow(tile) ax.set_title(f'Tile {i+1}') ax.axis('off') @@ -359,39 +357,40 @@ def _merge_polygons( for tile_idx, (tile_result, (x, y, x_end, y_end)) in enumerate(zip(tile_results, positions)): tile_polygons = tile_result["polygons"] + # Check if polygon has any part in the central 50% of the tile + tile_width = x_end - x + tile_height = y_end - y + + # Define central region boundaries (25% to 75% in each dimension) + central_x_min = tile_width * 0.25 + central_x_max = tile_width * 0.75 + central_y_min = tile_height * 0.25 + central_y_max = tile_height * 0.75 + + # Create square polygon representing the central region + central_region_coords = [ + (central_x_min, central_y_min), # top-left + (central_x_max, central_y_min), # top-right + (central_x_max, central_y_max), # bottom-right + (central_x_min, central_y_max) # bottom-left + ] + central_polygon = Polygon(central_region_coords) + for poly_idx, poly in enumerate(tile_polygons): if len(poly) < 3: # Skip invalid polygons continue - # Check if polygon has any part in the central 50% of the tile - tile_width = x_end - x - tile_height = y_end - y - - # Define central region boundaries (25% to 75% in each dimension) - central_x_min = tile_width * 0.25 - central_x_max = tile_width * 0.75 - central_y_min = tile_height * 0.25 - central_y_max = tile_height * 0.75 - - # Check if any part of the polygon (including interior) intersects with central region - # Create a mask for the polygon - poly_mask = np.zeros((tile_height, tile_width), dtype=np.uint8) - poly_coords_int = poly.astype(np.int32) - cv2.fillPoly(poly_mask, [poly_coords_int], 255) + # Use Shapely for precise polygon intersection detection + # Convert numpy arrays to tuples for Shapely + poly_coords = [(float(x), float(y)) for x, y in poly] - # Create a mask for the central region - central_mask = np.zeros((tile_height, tile_width), dtype=np.uint8) - central_x_min_int = int(central_x_min) - central_y_min_int = int(central_y_min) - central_x_max_int = int(central_x_max) - central_y_max_int = int(central_y_max) - central_mask[central_y_min_int:central_y_max_int, central_x_min_int:central_x_max_int] = 255 + # Create Shapely polygon for detected shape + detected_polygon = Polygon(poly_coords) - # Check if there's any intersection between polygon and central region - intersection = cv2.bitwise_and(poly_mask, central_mask) - has_intersection = np.any(intersection > 0) + # Check for intersection + has_intersection = detected_polygon.intersects(central_polygon) - # Skip polygon if no part of it intersects with the central region + # Skip polygon if no intersection with central region if not has_intersection: continue From 
5d5c84c527115e3f491bea6c3f1588782dee34ed Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Mon, 7 Jul 2025 10:45:05 -0500 Subject: [PATCH 12/45] Use an improved heuristic: look at other tiles for confirmation of horizontal/vertical edges. --- polygon_inference.py | 325 +++++++++++++++++++++++++++++++++---------- 1 file changed, 248 insertions(+), 77 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 7bd5755..cc5620e 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -282,48 +282,263 @@ def _create_tile_visualization( if not tiles: return - # Calculate grid dimensions - num_tiles = len(tiles) - cols = math.ceil(math.sqrt(num_tiles)) - rows = math.ceil(num_tiles / cols) + # Calculate grid dimensions based on actual spatial arrangement + # Extract unique x and y starting positions + x_positions = sorted(set(pos[0] for pos in positions)) + y_positions = sorted(set(pos[1] for pos in positions)) + + cols = len(x_positions) + rows = len(y_positions) + + # Create mapping from (x, y) position to (row, col) index + x_to_col = {x: i for i, x in enumerate(x_positions)} + y_to_row = {y: i for i, y in enumerate(y_positions)} # Create figure fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4)) + + # Handle different subplot layouts if rows == 1 and cols == 1: + axes = [[axes]] + elif rows == 1: axes = [axes] - elif rows == 1 or cols == 1: - axes = axes.flatten() - else: - axes = axes.flatten() - - for i in range(num_tiles): - ax = axes[i] - tile = tiles[i] - tile_result = tile_results[i] + elif cols == 1: + axes = [[ax] for ax in axes] + + # Initialize all subplots as empty + for i in range(rows): + for j in range(cols): + axes[i][j].axis('off') + + # Place each tile in the correct position + for i, (tile, tile_result, pos) in enumerate(zip(tiles, tile_results, positions)): + x, y, x_end, y_end = pos + + # Get the grid position for this tile + row = y_to_row[y] + col = x_to_col[x] + + ax = axes[row][col] # Tiles are already in RGB format, no conversion needed for matplotlib ax.imshow(tile) - ax.set_title(f'Tile {i+1}') + ax.set_title(f'Tile {i}') ax.axis('off') # Draw polygons on this tile - for poly in tile_result["polygons"]: + polygons = tile_result["polygons"] + polygon_valid = tile_result["polygon_valid"] + + for poly_idx, (poly, is_valid) in enumerate(zip(polygons, polygon_valid)): if len(poly) > 2: + # Use green for valid polygons, red for invalid ones + color = 'g' if is_valid else 'r' + vertex_color = 'red' if is_valid else 'darkred' + # Close the polygon for visualization poly_closed = np.vstack([poly, poly[0]]) - ax.plot(poly_closed[:, 0], poly_closed[:, 1], 'g-', linewidth=2) + ax.plot(poly_closed[:, 0], poly_closed[:, 1], f'{color}-', linewidth=2) # Draw vertices - ax.scatter(poly[:, 0], poly[:, 1], c='red', s=20, zorder=5) - - # Hide unused subplots - for i in range(num_tiles, len(axes)): - axes[i].axis('off') + ax.scatter(poly[:, 0], poly[:, 1], c=vertex_color, s=20, zorder=5) + + # Calculate centroid and render polygon index + centroid_x = np.mean(poly[:, 0]) + centroid_y = np.mean(poly[:, 1]) + + # Use white text with black outline for visibility + text_color = 'white' + outline_color = 'black' + + # Add text with outline for better visibility + ax.text(centroid_x, centroid_y, str(poly_idx), + fontsize=12, fontweight='bold', color=text_color, + ha='center', va='center', zorder=6, + bbox=dict(boxstyle='round,pad=0.3', facecolor=outline_color, alpha=0.7)) plt.tight_layout() plt.savefig('tile-visualization.png', dpi=150, 
bbox_inches='tight') plt.close() - log(f"Saved tile visualization with {num_tiles} tiles to tile-visualization.png") + log(f"Saved tile visualization to tile-visualization.png") + + def _validate_all_polygons( + self, + tile_results: List[Dict[str, List[np.ndarray]]], + positions: List[Tuple[int, int, int, int]], + image_height: int, + image_width: int + ) -> List[Dict[str, List[np.ndarray]]]: + """Validate all polygons in the tile results and add validation attributes. + + This method implements a heuristic to validate polygons by checking if their boundary edges + have points that are contained in polygons from other tiles. + + Args: + tile_results (List[Dict[str, List[np.ndarray]]]): List of tile results containing polygons + positions (List[Tuple[int, int, int, int]]): List of (x, y, x_end, y_end) tuples for each tile's position + image_height (int): Height of the original image + image_width (int): Width of the original image + + Returns: + List[Dict[str, List[np.ndarray]]]: Updated tile results with validation attributes + """ + # Initialize polygon_valid list for each tile + for tile_result in tile_results: + tile_result["polygon_valid"] = [True] * len(tile_result["polygons"]) + + def is_edge_near_tile_boundary(p1, p2, tile_bounds, tolerance=2): + """Check if an edge is colinear with the tile boundary within tolerance.""" + x_min, y_min, x_max, y_max = tile_bounds + x1, y1 = p1 + x2, y2 = p2 + + # Check if edge is roughly horizontal and colinear with top boundary + if (abs(y1 - y_min) <= tolerance and abs(y2 - y_min) <= tolerance and + abs(y1 - y2) <= tolerance): + return True + + # Check if edge is roughly horizontal and colinear with bottom boundary + if (abs(y1 - y_max) <= tolerance and abs(y2 - y_max) <= tolerance and + abs(y1 - y2) <= tolerance): + return True + + # Check if edge is roughly vertical and colinear with left boundary + if (abs(x1 - x_min) <= tolerance and abs(x2 - x_min) <= tolerance and + abs(x1 - x2) <= tolerance): + return True + + # Check if edge is roughly vertical and colinear with right boundary + if (abs(x1 - x_max) <= tolerance and abs(x2 - x_max) <= tolerance and + abs(x1 - x2) <= tolerance): + return True + + return False + + def generate_edge_sample_points(p1, p2, num_points=10, margin_px=10): + """Generate equally spaced points along an edge, leaving a fixed margin at each end.""" + # Calculate edge length + edge_length = math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2) + + # If edge is too short to accommodate margins, return empty list + if edge_length <= 2 * margin_px: + return [] + + # Calculate the usable length (excluding margins) + usable_length = edge_length - 2 * margin_px + + # Calculate t values for the start and end of the usable region + t_start = margin_px / edge_length + t_end = 1.0 - margin_px / edge_length + + points = [] + for i in range(num_points): + # Distribute points evenly within the usable region + t_local = i / (num_points - 1) if num_points > 1 else 0.5 + t = t_start + t_local * (t_end - t_start) + + x = p1[0] + t * (p2[0] - p1[0]) + y = p1[1] + t * (p2[1] - p1[1]) + points.append((x, y)) + + return points + + def point_in_polygon(point, polygon): + """Check if a point is inside a polygon using OpenCV.""" + if len(polygon) < 3: + return False + # Convert polygon to the format expected by cv2.pointPolygonTest + poly_points = polygon.astype(np.float32).reshape((-1, 1, 2)) + return cv2.pointPolygonTest(poly_points, point, False) >= 0 + + # Process each tile + for tile_idx, (tile_result, tile_pos) in 
enumerate(zip(tile_results, positions)): + x, y, x_end, y_end = tile_pos + tile_width = x_end - x + tile_height = y_end - y + tile_bounds = (0, 0, tile_width, tile_height) # tile local coordinates + + polygons = tile_result["polygons"] + polygon_valid = tile_result["polygon_valid"] + + # Check each polygon in this tile + for poly_idx, polygon in enumerate(polygons): + if len(polygon) < 3: + polygon_valid[poly_idx] = False + continue + + # Find edges that are near tile boundaries + boundary_edges = [] + for i in range(len(polygon)): + p1 = polygon[i] + p2 = polygon[(i + 1) % len(polygon)] + + if is_edge_near_tile_boundary(p1, p2, tile_bounds): + boundary_edges.append((p1, p2)) + + # If no boundary edges, polygon is valid (not on tile boundary) + if not boundary_edges: + continue + + # Check sample points along boundary edges + polygon_is_valid = True + + for p1, p2 in boundary_edges: + sample_points = generate_edge_sample_points(p1, p2) + + # Determine if this edge is horizontal or vertical + is_horizontal_edge = abs(p1[1] - p2[1]) <= 2 # Edge is roughly horizontal + is_vertical_edge = abs(p1[0] - p2[0]) <= 2 # Edge is roughly vertical + + # Convert sample points to global image coordinates + global_sample_points = [(px + x, py + y) for px, py in sample_points] + + # Check if each sample point is contained in any polygon from other tiles + for global_point in global_sample_points: + point_found_in_other_polygon = False + + # Check all other tiles + for other_tile_idx, (other_tile_result, other_tile_pos) in enumerate(zip(tile_results, positions)): + if other_tile_idx == tile_idx: + continue + + other_x, other_y, other_x_end, other_y_end = other_tile_pos + + # Skip tiles in same row for horizontal edges + if is_horizontal_edge and other_y == y: + continue + + # Skip tiles in same column for vertical edges + if is_vertical_edge and other_x == x: + continue + + # Check if point is within other tile bounds + if (other_x <= global_point[0] < other_x_end and + other_y <= global_point[1] < other_y_end): + + # Convert global point to other tile's local coordinates + local_point = (global_point[0] - other_x, global_point[1] - other_y) + + # Check if point is inside any polygon in this other tile + for other_polygon in other_tile_result["polygons"]: + if point_in_polygon(local_point, other_polygon): + point_found_in_other_polygon = True + break + + if point_found_in_other_polygon: + break + + # If any sample point is not found in other polygons, mark as invalid + if not point_found_in_other_polygon: + polygon_is_valid = False + break + + if not polygon_is_valid: + break + + # Update polygon validity + polygon_valid[poly_idx] = polygon_is_valid + + return tile_results def _merge_polygons( self, @@ -355,63 +570,16 @@ def _merge_polygons( # Fill bitmap with polygon regions for tile_idx, (tile_result, (x, y, x_end, y_end)) in enumerate(zip(tile_results, positions)): + log(f"tile_idx: {tile_idx}") + tile_polygons = tile_result["polygons"] + polygon_valid = tile_result["polygon_valid"] - # Check if polygon has any part in the central 50% of the tile - tile_width = x_end - x - tile_height = y_end - y - - # Define central region boundaries (25% to 75% in each dimension) - central_x_min = tile_width * 0.25 - central_x_max = tile_width * 0.75 - central_y_min = tile_height * 0.25 - central_y_max = tile_height * 0.75 - - # Create square polygon representing the central region - central_region_coords = [ - (central_x_min, central_y_min), # top-left - (central_x_max, central_y_min), # top-right - (central_x_max, 
central_y_max), # bottom-right - (central_x_min, central_y_max) # bottom-left - ] - central_polygon = Polygon(central_region_coords) - - for poly_idx, poly in enumerate(tile_polygons): - if len(poly) < 3: # Skip invalid polygons - continue - - # Use Shapely for precise polygon intersection detection - # Convert numpy arrays to tuples for Shapely - poly_coords = [(float(x), float(y)) for x, y in poly] - - # Create Shapely polygon for detected shape - detected_polygon = Polygon(poly_coords) - - # Check for intersection - has_intersection = detected_polygon.intersects(central_polygon) - - # Skip polygon if no intersection with central region - if not has_intersection: + for poly_idx, (poly, is_valid) in enumerate(zip(tile_polygons, polygon_valid)): + # Skip invalid polygons + if not is_valid: continue - - # Additional rule: Skip if polygon occupies two or more corners - corner_tolerance = 20.0 # Distance from corner to be considered "in corner" - - # Define corner regions - corners = [ - (poly[:, 0] <= corner_tolerance) & (poly[:, 1] <= corner_tolerance), # top_left - (poly[:, 0] >= tile_width - corner_tolerance) & (poly[:, 1] <= corner_tolerance), # top_right - (poly[:, 0] <= corner_tolerance) & (poly[:, 1] >= tile_height - corner_tolerance), # bottom_left - (poly[:, 0] >= tile_width - corner_tolerance) & (poly[:, 1] >= tile_height - corner_tolerance) # bottom_right - ] - - # Count how many corners have any points - occupied_corners = sum(1 for corner_mask in corners if np.any(corner_mask)) - - # Skip polygon if it occupies two or more corners - if occupied_corners >= 2: - continue - + # Transform polygon from tile coordinates to image coordinates transformed_poly = poly + np.array([x, y]) @@ -424,7 +592,7 @@ def _merge_polygons( # Convert to integer coordinates for rasterization poly_coords = scaled_poly.astype(np.int32) - + # Fill the polygon region in the bitmap cv2.fillPoly(bitmap, [poly_coords], 255) @@ -535,6 +703,9 @@ def infer(self, image_data: bytes) -> List[List[List[float]]]: batch_time = time.time() - batch_start_time log(f"Processed batch of {len(batch_tiles)} tiles: {batch_time/len(batch_tiles):.3f}s per tile") + # Validate all polygons and add validation attributes + all_results = self._validate_all_polygons(all_results, bboxes, height, width) + # Create tile visualization self._create_tile_visualization(tiles, all_results, bboxes) From d08200503012038e63050ff5089c65b239008947 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Mon, 7 Jul 2025 12:40:54 -0500 Subject: [PATCH 13/45] Control debug output and development caching, reject overlapping polygons, always check at least one edge point. 
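A short usage sketch of the new debug switch, mirroring the infer_single_image.py change below: debug=True enables the development cache, the tile visualization, and the bitmap dump, while the default path writes no debug artifacts. The experiment path and input filename are placeholders.

from polygon_inference import PolygonInference

inference = PolygonInference("runs_share/Pix2Poly_inria_coco_224")  # placeholder experiment path

with open("input_image.png", "rb") as f:  # placeholder input file
    image_data = f.read()

# Caches batches and writes tile-visualization.png / debug_polygon_bitmap.png.
debug_polygons = inference.infer(image_data, debug=True)

# Production path: no caching, no debug files.
polygons = inference.infer(image_data)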
--- infer_single_image.py | 2 +- polygon_inference.py | 170 ++++++++++++++++++++++++++++++++++-------- 2 files changed, 139 insertions(+), 33 deletions(-) diff --git a/infer_single_image.py b/infer_single_image.py index c3e98ca..e24fdbf 100644 --- a/infer_single_image.py +++ b/infer_single_image.py @@ -21,7 +21,7 @@ def main(): inference = PolygonInference(args.experiment_path) # Get inference results - polygons_list = inference.infer(image_data) + polygons_list = inference.infer(image_data, debug=True) # Decode image for visualization nparr = np.frombuffer(image_data, np.uint8) diff --git a/polygon_inference.py b/polygon_inference.py index cc5620e..78e1d65 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -184,7 +184,7 @@ def _find_single_checkpoint(self) -> str: return checkpoint_files[0] def _process_tiles_batch( - self, tiles: List[np.ndarray] + self, tiles: List[np.ndarray], debug: bool = False ) -> List[Dict[str, List[np.ndarray]]]: """Process a single batch of tiles. @@ -195,14 +195,18 @@ def _process_tiles_batch( list[dict]: List of results for each tile, where each result contains: - polygons: List of polygon coordinates """ - # Generate cache key and try to load from cache - cache_key = self._generate_cache_key(tiles) - cached_results = self._load_from_cache(cache_key) - if cached_results is not None: - log(f"Cache hit for batch of {len(tiles)} tiles") - return cached_results - - log(f"Cache miss for batch of {len(tiles)} tiles, processing...") + # Generate cache key and try to load from cache (only when debug=True) + if debug: + cache_key = self._generate_cache_key(tiles) + cached_results = self._load_from_cache(cache_key) + if cached_results is not None: + log(f"Cache hit for batch of {len(tiles)} tiles") + return cached_results + + log(f"Cache miss for batch of {len(tiles)} tiles, processing...") + else: + log(f"Processing batch of {len(tiles)} tiles (caching disabled)...") + cache_key = None valid_transforms = A.Compose( [ @@ -261,8 +265,9 @@ def _process_tiles_batch( results.append(result) - # Save results to cache - self._save_to_cache(cache_key, results) + # Save results to cache (only when debug=True) + if debug and cache_key is not None: + self._save_to_cache(cache_key, results) return results @@ -386,6 +391,89 @@ def _validate_all_polygons( for tile_result in tile_results: tile_result["polygon_valid"] = [True] * len(tile_result["polygons"]) + # Remove overlapping polygons within each tile (before edge validation) + + def check_polygon_overlap(poly1, poly2): + """Check if two polygons overlap using Shapely.""" + try: + # Convert numpy arrays to Shapely polygons + if len(poly1) < 3 or len(poly2) < 3: + return False + + shapely_poly1 = Polygon(poly1) + shapely_poly2 = Polygon(poly2) + + # Check if polygons are valid + if not shapely_poly1.is_valid or not shapely_poly2.is_valid: + return False + + # Check for intersection (but not just touching) + return shapely_poly1.intersects(shapely_poly2) and not shapely_poly1.touches(shapely_poly2) + except: + return False + + def calculate_polygon_area(poly): + """Calculate the area of a polygon.""" + try: + if len(poly) < 3: + return 0 + shapely_poly = Polygon(poly) + if not shapely_poly.is_valid: + return 0 + return shapely_poly.area + except: + return 0 + + for tile_result in tile_results: + polygons = tile_result["polygons"] + polygon_valid = tile_result["polygon_valid"] + + if len(polygons) <= 1: + continue # Skip tiles with 0 or 1 polygon + + # Keep iterating until no overlaps are found + while True: + # Get 
currently valid polygons with their indices + valid_polygons = [(i, poly) for i, poly in enumerate(polygons) if polygon_valid[i]] + + if len(valid_polygons) <= 1: + break # No overlaps possible with 0 or 1 valid polygons + + # Find all overlapping pairs + overlapping_pairs = [] + for i in range(len(valid_polygons)): + for j in range(i + 1, len(valid_polygons)): + idx1, poly1 = valid_polygons[i] + idx2, poly2 = valid_polygons[j] + + if check_polygon_overlap(poly1, poly2): + overlapping_pairs.append((idx1, idx2)) + + if not overlapping_pairs: + break # No overlaps found + + # Find all polygons involved in overlaps + overlapping_indices = set() + for idx1, idx2 in overlapping_pairs: + overlapping_indices.add(idx1) + overlapping_indices.add(idx2) + + # Calculate areas for overlapping polygons + polygon_areas = [] + for idx in overlapping_indices: + area = calculate_polygon_area(polygons[idx]) + polygon_areas.append((idx, area)) + + # Find the largest polygon + largest_idx, _ = max(polygon_areas, key=lambda x: x[1]) + + # Mark the largest polygon as invalid + polygon_valid[largest_idx] = False + + # Continue to next iteration to check for remaining overlaps + + # Now perform edge validation on remaining valid polygons + def is_edge_near_tile_boundary(p1, p2, tile_bounds, tolerance=2): """Check if an edge is colinear with the tile boundary within tolerance.""" x_min, y_min, x_max, y_max = tile_bounds @@ -415,25 +503,33 @@ def is_edge_near_tile_boundary(p1, p2, tile_bounds, tolerance=2): return False def generate_edge_sample_points(p1, p2, num_points=10, margin_px=10): - """Generate equally spaced points along an edge, leaving a fixed margin at each end.""" + """Generate equally spaced points along an edge, leaving a fixed margin at each end. + Always generates at least one point in the center of the line.""" # Calculate edge length edge_length = math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2) - # If edge is too short to accommodate margins, return empty list - if edge_length <= 2 * margin_px: - return [] + # Always generate center point + center_x = p1[0] + 0.5 * (p2[0] - p1[0]) + center_y = p1[1] + 0.5 * (p2[1] - p1[1]) - # Calculate the usable length (excluding margins) - usable_length = edge_length - 2 * margin_px + # If edge is too short to accommodate margins, return just the center point + if edge_length <= 2 * margin_px: + return [(center_x, center_y)] # Calculate t values for the start and end of the usable region t_start = margin_px / edge_length t_end = 1.0 - margin_px / edge_length points = [] + + # If only one point requested, return center point + if num_points == 1: + return [(center_x, center_y)] + + # Generate points evenly spaced within the usable region for i in range(num_points): # Distribute points evenly within the usable region - t_local = i / (num_points - 1) if num_points > 1 else 0.5 + t_local = i / (num_points - 1) t = t_start + t_local * (t_end - t_start) x = p1[0] + t * (p2[0] - p1[0]) @@ -451,7 +547,7 @@ def point_in_polygon(point, polygon): return cv2.pointPolygonTest(poly_points, point, False) >= 0 # Process each tile - for tile_idx, (tile_result, tile_pos) in enumerate(zip(tile_results, positions)): + for tile_result, tile_pos in zip(tile_results, positions): x, y, x_end, y_end = tile_pos tile_width = x_end - x tile_height = y_end - y @@ -460,8 +556,12 @@ def point_in_polygon(point, polygon): polygons = tile_result["polygons"] polygon_valid = tile_result["polygon_valid"] - # Check each polygon in this tile + # Check each polygon in this tile (only those still 
valid after overlap removal) for poly_idx, polygon in enumerate(polygons): + # Skip polygons already rejected for overlap + if not polygon_valid[poly_idx]: + continue + if len(polygon) < 3: polygon_valid[poly_idx] = False continue @@ -497,8 +597,8 @@ def point_in_polygon(point, polygon): point_found_in_other_polygon = False # Check all other tiles - for other_tile_idx, (other_tile_result, other_tile_pos) in enumerate(zip(tile_results, positions)): - if other_tile_idx == tile_idx: + for other_tile_result, other_tile_pos in zip(tile_results, positions): + if other_tile_result is tile_result: continue other_x, other_y, other_x_end, other_y_end = other_tile_pos @@ -518,8 +618,12 @@ def point_in_polygon(point, polygon): # Convert global point to other tile's local coordinates local_point = (global_point[0] - other_x, global_point[1] - other_y) - # Check if point is inside any polygon in this other tile - for other_polygon in other_tile_result["polygons"]: + # Check if point is inside any valid polygon in this other tile + for other_poly_idx, other_polygon in enumerate(other_tile_result["polygons"]): + # Only consider polygons that are still valid (not rejected for overlap) + if not other_tile_result["polygon_valid"][other_poly_idx]: + continue + if point_in_polygon(local_point, other_polygon): point_found_in_other_polygon = True break @@ -546,6 +650,7 @@ def _merge_polygons( positions: List[Tuple[int, int, int, int]], image_height: int, image_width: int, + debug: bool = False, ) -> List[np.ndarray]: """Merge polygon predictions from multiple tiles using a bitmap approach. @@ -570,8 +675,6 @@ def _merge_polygons( # Fill bitmap with polygon regions for tile_idx, (tile_result, (x, y, x_end, y_end)) in enumerate(zip(tile_results, positions)): - log(f"tile_idx: {tile_idx}") - tile_polygons = tile_result["polygons"] polygon_valid = tile_result["polygon_valid"] @@ -601,8 +704,9 @@ def _merge_polygons( bitmap = cv2.morphologyEx(bitmap, cv2.MORPH_CLOSE, kernel) # Save bitmap for debugging (optional) - cv2.imwrite('debug_polygon_bitmap.png', bitmap) - log("Saved debug bitmap to debug_polygon_bitmap.png") + if debug: + cv2.imwrite('debug_polygon_bitmap.png', bitmap) + log("Saved debug bitmap to debug_polygon_bitmap.png") # Find contours in the bitmap contours, _ = cv2.findContours(bitmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) @@ -633,11 +737,12 @@ def _merge_polygons( log(f"Bitmap approach: {len(merged_polygons)} polygons extracted from bitmap") return merged_polygons - def infer(self, image_data: bytes) -> List[List[List[float]]]: + def infer(self, image_data: bytes, debug: bool = False) -> List[List[List[float]]]: """Infer polygons in an image. Args: image_data (bytes): Raw image data + debug (bool): Whether to save debug images (tile visualization and bitmap) Returns: list[list[list[float]]]: List of polygons where each polygon is a list of [x,y] coordinates. 
@@ -697,7 +802,7 @@ def infer(self, image_data: bytes) -> List[List[List[float]]]: for i in range(0, len(tiles), CFG.PREDICTION_BATCH_SIZE): batch_start_time = time.time() batch_tiles = tiles[i : i + CFG.PREDICTION_BATCH_SIZE] - batch_results = self._process_tiles_batch(batch_tiles) + batch_results = self._process_tiles_batch(batch_tiles, debug) all_results.extend(batch_results) batch_time = time.time() - batch_start_time @@ -707,9 +812,10 @@ def infer(self, image_data: bytes) -> List[List[List[float]]]: all_results = self._validate_all_polygons(all_results, bboxes, height, width) # Create tile visualization - self._create_tile_visualization(tiles, all_results, bboxes) + if debug: + self._create_tile_visualization(tiles, all_results, bboxes) - merged_polygons = self._merge_polygons(all_results, bboxes, height, width) + merged_polygons = self._merge_polygons(all_results, bboxes, height, width, debug) # Convert to list format polygons_list = [poly.tolist() for poly in merged_polygons] From f74fb14d1286b1f52c5d1ebaedd6e50cc2faf969 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Mon, 7 Jul 2025 13:35:02 -0500 Subject: [PATCH 14/45] Stop generating zero-length line segments. --- polygon_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 78e1d65..1264c65 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -568,9 +568,9 @@ def point_in_polygon(point, polygon): # Find edges that are near tile boundaries boundary_edges = [] - for i in range(len(polygon)): + for i in range(len(polygon) - 1): p1 = polygon[i] - p2 = polygon[(i + 1) % len(polygon)] + p2 = polygon[i + 1] if is_edge_near_tile_boundary(p1, p2, tile_bounds): boundary_edges.append((p1, p2)) From de4ec431a101d38e91ccb5c4f0f5d762d75dd08d Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Thu, 10 Jul 2025 12:25:48 -0500 Subject: [PATCH 15/45] Improve visualization output, improve logging, add 2 pixel merge tolerance. --- polygon_inference.py | 154 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 126 insertions(+), 28 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 1264c65..d8a997c 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -277,7 +277,7 @@ def _create_tile_visualization( tile_results: List[Dict[str, List[np.ndarray]]], positions: List[Tuple[int, int, int, int]], ) -> None: - """Create a tile visualization showing each tile with its detected polygons. + """Create a tile visualization showing each tile with its detected polygons and coordinate scales. 
Args: tiles (List[np.ndarray]): List of tile images @@ -328,7 +328,36 @@ def _create_tile_visualization( # Tiles are already in RGB format, no conversion needed for matplotlib ax.imshow(tile) ax.set_title(f'Tile {i}') - ax.axis('off') + + # Enable axis and set up coordinate scales + ax.axis('on') + + # Get tile dimensions + tile_height, tile_width = tile.shape[:2] + + # Set up x-axis ticks and labels (global coordinates) + x_range = x_end - x + x_step = max(1, x_range // 8) # Show ~8 ticks across width + x_tick_positions = range(0, tile_width, max(1, tile_width // 8)) + x_global_coords = [x + pos * x_range // tile_width for pos in x_tick_positions] + ax.set_xticks(x_tick_positions) + ax.set_xticklabels([str(coord) for coord in x_global_coords], fontsize=8) + + # Set up y-axis ticks and labels (global coordinates) + y_range = y_end - y + y_step = max(1, y_range // 8) # Show ~8 ticks across height + y_tick_positions = range(0, tile_height, max(1, tile_height // 8)) + y_global_coords = [y + pos * y_range // tile_height for pos in y_tick_positions] + ax.set_yticks(y_tick_positions) + ax.set_yticklabels([str(coord) for coord in y_global_coords], fontsize=8) + + # Set axis limits to match tile dimensions + ax.set_xlim(0, tile_width) + ax.set_ylim(tile_height, 0) # Invert y-axis for image coordinates + + # Style the grid and ticks + ax.grid(True, alpha=0.3, linewidth=0.5) + ax.tick_params(axis='both', which='major', labelsize=8, length=3) # Draw polygons on this tile polygons = tile_result["polygons"] @@ -387,9 +416,16 @@ def _validate_all_polygons( Returns: List[Dict[str, List[np.ndarray]]]: Updated tile results with validation attributes """ + log("Starting polygon validation process...") + # Initialize polygon_valid list for each tile - for tile_result in tile_results: + total_polygons = 0 + for tile_idx, tile_result in enumerate(tile_results): tile_result["polygon_valid"] = [True] * len(tile_result["polygons"]) + total_polygons += len(tile_result["polygons"]) + log(f"Tile {tile_idx}: {len(tile_result['polygons'])} polygons to validate") + + log(f"Total polygons to validate: {total_polygons}") # Remove overlapping polygons within each tile (before edge validation) @@ -424,11 +460,15 @@ def calculate_polygon_area(poly): except: return 0 - for tile_result in tile_results: + log("Phase 1: Removing overlapping polygons within each tile...") + overlap_removed_count = 0 + + for tile_idx, tile_result in enumerate(tile_results): polygons = tile_result["polygons"] polygon_valid = tile_result["polygon_valid"] if len(polygons) <= 1: + log(f"Tile {tile_idx}: Skipping overlap check (≤1 polygon)") continue # Skip tiles with 0 or 1 polygon # Keep iterating until no overlaps are found @@ -448,8 +488,10 @@ def calculate_polygon_area(poly): if check_polygon_overlap(poly1, poly2): overlapping_pairs.append((idx1, idx2)) + log(f"Tile {tile_idx}: Overlap detected between polygon {idx1} and polygon {idx2}") if not overlapping_pairs: + log(f"Tile {tile_idx}: No overlaps found, overlap validation complete") break # No overlaps found # Find all polygons involved in overlaps @@ -465,14 +507,25 @@ def calculate_polygon_area(poly): polygon_areas.append((idx, area)) # Find the largest polygon - largest_idx, _ = max(polygon_areas, key=lambda x: x[1]) + largest_idx, largest_area = max(polygon_areas, key=lambda x: x[1]) # Mark the largest polygon as invalid polygon_valid[largest_idx] = False + overlap_removed_count += 1 + log(f"Tile {tile_idx}: Marking polygon {largest_idx} as INVALID (largest overlapping polygon, 
area={largest_area:.2f})") + + # Log areas of all overlapping polygons for context + for idx, area in polygon_areas: + if idx != largest_idx: + log(f"Tile {tile_idx}: Polygon {idx} (area={area:.2f}) remains valid") # Continue to next iteration to check for remaining overlaps + log(f"Phase 1 complete: {overlap_removed_count} polygons removed due to overlaps") + # Now perform edge validation on remaining valid polygons + log("Phase 2: Validating polygons with boundary edges...") + edge_rejected_count = 0 def is_edge_near_tile_boundary(p1, p2, tile_bounds, tolerance=2): """Check if an edge is colinear with the tile boundary within tolerance.""" @@ -544,10 +597,10 @@ def point_in_polygon(point, polygon): return False # Convert polygon to the format expected by cv2.pointPolygonTest poly_points = polygon.astype(np.float32).reshape((-1, 1, 2)) - return cv2.pointPolygonTest(poly_points, point, False) >= 0 + return cv2.pointPolygonTest(poly_points, point, True) >= -2 # Process each tile - for tile_result, tile_pos in zip(tile_results, positions): + for tile_idx, (tile_result, tile_pos) in enumerate(zip(tile_results, positions)): x, y, x_end, y_end = tile_pos tile_width = x_end - x tile_height = y_end - y @@ -556,14 +609,19 @@ def point_in_polygon(point, polygon): polygons = tile_result["polygons"] polygon_valid = tile_result["polygon_valid"] + valid_polygons_count = sum(polygon_valid) + log(f"Tile {tile_idx}: Processing {valid_polygons_count} valid polygons for edge validation") + # Check each polygon in this tile (only those still valid after overlap removal) for poly_idx, polygon in enumerate(polygons): # Skip polygons already rejected for overlap if not polygon_valid[poly_idx]: + log(f"Tile {tile_idx}, Polygon {poly_idx}: SKIPPED (already invalid from overlap check)") continue if len(polygon) < 3: polygon_valid[poly_idx] = False + log(f"Tile {tile_idx}, Polygon {poly_idx}: INVALID (< 3 vertices)") continue # Find edges that are near tile boundaries @@ -577,24 +635,31 @@ def point_in_polygon(point, polygon): # If no boundary edges, polygon is valid (not on tile boundary) if not boundary_edges: + log(f"Tile {tile_idx}, Polygon {poly_idx}: VALID (no boundary edges)") continue + log(f"Tile {tile_idx}, Polygon {poly_idx}: Found {len(boundary_edges)} boundary edges, validating...") + # Check sample points along boundary edges polygon_is_valid = True - for p1, p2 in boundary_edges: + for edge_idx, (p1, p2) in enumerate(boundary_edges): sample_points = generate_edge_sample_points(p1, p2) # Determine if this edge is horizontal or vertical is_horizontal_edge = abs(p1[1] - p2[1]) <= 2 # Edge is roughly horizontal is_vertical_edge = abs(p1[0] - p2[0]) <= 2 # Edge is roughly vertical + edge_type = "horizontal" if is_horizontal_edge else "vertical" if is_vertical_edge else "diagonal" + log(f"Tile {tile_idx}, Polygon {poly_idx}, Edge {edge_idx}: {edge_type} edge with {len(sample_points)} sample points") + # Convert sample points to global image coordinates global_sample_points = [(px + x, py + y) for px, py in sample_points] # Check if each sample point is contained in any polygon from other tiles - for global_point in global_sample_points: + for point_idx, global_point in enumerate(global_sample_points): point_found_in_other_polygon = False + match_info = None # Check all other tiles for other_tile_result, other_tile_pos in zip(tile_results, positions): @@ -611,36 +676,70 @@ def point_in_polygon(point, polygon): if is_vertical_edge and other_x == x: continue - # Check if point is within other tile bounds - if 
(other_x <= global_point[0] < other_x_end and - other_y <= global_point[1] < other_y_end): - - # Convert global point to other tile's local coordinates - local_point = (global_point[0] - other_x, global_point[1] - other_y) + # Convert global point to other tile's local coordinates + local_point = (global_point[0] - other_x, global_point[1] - other_y) + + # Check if point is inside any valid polygon in this other tile + for other_poly_idx, other_polygon in enumerate(other_tile_result["polygons"]): + # Only consider polygons that are still valid (not rejected for overlap) + if not other_tile_result["polygon_valid"][other_poly_idx]: + continue - # Check if point is inside any valid polygon in this other tile - for other_poly_idx, other_polygon in enumerate(other_tile_result["polygons"]): - # Only consider polygons that are still valid (not rejected for overlap) - if not other_tile_result["polygon_valid"][other_poly_idx]: - continue + if point_in_polygon(local_point, other_polygon): + point_found_in_other_polygon = True + # Store match location information for detailed logging + match_tile_idx = None + for search_tile_idx, (search_tile_result, search_tile_pos) in enumerate(zip(tile_results, positions)): + if search_tile_result is other_tile_result: + match_tile_idx = search_tile_idx + break - if point_in_polygon(local_point, other_polygon): - point_found_in_other_polygon = True - break - - if point_found_in_other_polygon: + match_info = { + 'tile_idx': match_tile_idx, + 'polygon_idx': other_poly_idx, + 'tile_bounds': (other_x, other_y, other_x_end, other_y_end), + 'local_point': local_point, + 'global_point': global_point + } break - + + if point_found_in_other_polygon: + break + # If any sample point is not found in other polygons, mark as invalid if not point_found_in_other_polygon: polygon_is_valid = False + log(f"Tile {tile_idx}, Polygon {poly_idx}: Sample point {point_idx} at {global_point} NOT found in other polygons") break + else: + # Get the matched polygon points and convert to global coordinates + matched_polygon = tile_results[match_info['tile_idx']]["polygons"][match_info['polygon_idx']] + match_tile_bounds = match_info['tile_bounds'] + # Convert from local tile coordinates to global coordinates + global_polygon_points = matched_polygon + np.array([match_tile_bounds[0], match_tile_bounds[1]]) + # Format points as list of [x, y] coordinates for logging + polygon_points_str = '[' + ', '.join(f'[{x:.1f}, {y:.1f}]' for x, y in global_polygon_points) + ']' + + log(f"Tile {tile_idx}, Polygon {poly_idx}: Sample point {point_idx} at {global_point} found in " + f"Tile {match_info['tile_idx']}, Polygon {match_info['polygon_idx']} " + f"(polygon points: {polygon_points_str})") if not polygon_is_valid: break # Update polygon validity polygon_valid[poly_idx] = polygon_is_valid + + if polygon_is_valid: + log(f"Tile {tile_idx}, Polygon {poly_idx}: VALID (all boundary edge sample points found in other polygons)") + else: + edge_rejected_count += 1 + log(f"Tile {tile_idx}, Polygon {poly_idx}: INVALID (boundary edge validation failed)") + + # Log final validation summary + final_valid_count = sum(sum(tile_result["polygon_valid"]) for tile_result in tile_results) + log(f"Phase 2 complete: {edge_rejected_count} polygons removed due to edge validation") + log(f"Validation summary: {final_valid_count}/{total_polygons} polygons remain valid") return tile_results @@ -720,8 +819,7 @@ def _merge_polygons( continue # Simplify the contour to reduce jaggedness while preserving shape - perimeter = 
cv2.arcLength(contour, True) - epsilon = 0.01 * perimeter # 1% of perimeter + epsilon = 10 simplified_contour = cv2.approxPolyDP(contour, epsilon, True) # Convert from OpenCV format to our polygon format From b4f13b8fc404257497389f2374f5fb84d7d6484c Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Thu, 10 Jul 2025 13:02:49 -0500 Subject: [PATCH 16/45] Use a specialized building footprint simplifier instead of the generic line simplification algorithm. --- environment.yml | 4 +++- polygon_inference.py | 41 +++++++++++++++++++++++++++++------------ 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/environment.yml b/environment.yml index 28ba547..2499e4d 100644 --- a/environment.yml +++ b/environment.yml @@ -7,9 +7,11 @@ dependencies: - python=3.11 - timm=0.9.12 - transformers=4.32.1 - - pycocotools=2.0.6 + - pycocotools>=2.0.6 - torchmetrics=1.2.1 - tensorboard=2.15.1 + - buildingregulariser>=0.2.2 + - geopandas=1.1.1 - pip - pip: - torch==2.1.2 diff --git a/polygon_inference.py b/polygon_inference.py index d8a997c..7ccc064 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -12,6 +12,8 @@ import albumentations as A from albumentations.pytorch import ToTensorV2 from shapely.geometry import Polygon +from buildingregulariser import regularize_geodataframe +import geopandas as gpd import matplotlib.pyplot as plt import matplotlib.patches as patches @@ -810,7 +812,8 @@ def _merge_polygons( # Find contours in the bitmap contours, _ = cv2.findContours(bitmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - merged_polygons: List[np.ndarray] = [] + # Collect all valid contours into shapely polygons + shapely_polygons = [] for contour in contours: # Skip very small contours (area is scaled by scale_factor^2) @@ -818,19 +821,33 @@ def _merge_polygons( if area < CFG.MIN_POLYGON_AREA * (scale_factor ** 2): continue - # Simplify the contour to reduce jaggedness while preserving shape - epsilon = 10 - simplified_contour = cv2.approxPolyDP(contour, epsilon, True) + # Convert contour to Shapely Polygon + contour_points = contour.reshape(-1, 2).astype(np.float64) + shapely_polygon = Polygon(contour_points) + shapely_polygons.append(shapely_polygon) + + # Create single GeoDataFrame with all polygons and regularize them all at once + if shapely_polygons: + gdf = gpd.GeoDataFrame({'geometry': shapely_polygons}) + regularized_gdf = regularize_geodataframe(gdf, simplify_tolerance=20, parallel_threshold=100) - # Convert from OpenCV format to our polygon format - if len(simplified_contour) >= 3: # Valid polygon needs at least 3 points - # Reshape from (n, 1, 2) to (n, 2) and convert to float - polygon_coords = simplified_contour.reshape(-1, 2).astype(np.float32) - - # Scale down coordinates back to original image coordinate system - polygon_coords = polygon_coords / scale_factor + # Process the regularized polygons + merged_polygons: List[np.ndarray] = [] + + for regularized_polygon in regularized_gdf.geometry: + # Convert back to numpy array for OpenCV format + coords = np.array(regularized_polygon.exterior.coords[:-1]) # Remove duplicate last point + simplified_contour = coords.reshape(-1, 1, 2).astype(np.int32) - merged_polygons.append(polygon_coords) + # Convert from OpenCV format to our polygon format + if len(simplified_contour) >= 3: # Valid polygon needs at least 3 points + # Reshape from (n, 1, 2) to (n, 2) and convert to float + polygon_coords = simplified_contour.reshape(-1, 2).astype(np.float32) + + # Scale down coordinates back to original image coordinate system + 
polygon_coords = polygon_coords / scale_factor + + merged_polygons.append(polygon_coords) log(f"Bitmap approach: {len(merged_polygons)} polygons extracted from bitmap") return merged_polygons From abd598a446b3081050aa44c9b53a3129c7698876 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Thu, 10 Jul 2025 16:33:41 -0500 Subject: [PATCH 17/45] Simplify and improve code. --- polygon_inference.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 7ccc064..9aa310a 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -826,26 +826,23 @@ def _merge_polygons( shapely_polygon = Polygon(contour_points) shapely_polygons.append(shapely_polygon) + # Initialize merged_polygons to avoid NameError if no valid contours are found + merged_polygons: List[np.ndarray] = [] + # Create single GeoDataFrame with all polygons and regularize them all at once if shapely_polygons: gdf = gpd.GeoDataFrame({'geometry': shapely_polygons}) regularized_gdf = regularize_geodataframe(gdf, simplify_tolerance=20, parallel_threshold=100) # Process the regularized polygons - merged_polygons: List[np.ndarray] = [] - for regularized_polygon in regularized_gdf.geometry: # Convert back to numpy array for OpenCV format coords = np.array(regularized_polygon.exterior.coords[:-1]) # Remove duplicate last point - simplified_contour = coords.reshape(-1, 1, 2).astype(np.int32) # Convert from OpenCV format to our polygon format - if len(simplified_contour) >= 3: # Valid polygon needs at least 3 points - # Reshape from (n, 1, 2) to (n, 2) and convert to float - polygon_coords = simplified_contour.reshape(-1, 2).astype(np.float32) - + if len(coords) >= 3: # Valid polygon needs at least 3 points # Scale down coordinates back to original image coordinate system - polygon_coords = polygon_coords / scale_factor + polygon_coords = coords.astype(np.float32) / scale_factor merged_polygons.append(polygon_coords) From de7e6b586bb014a96cd4fbd33533ec0ffa7dda5d Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Fri, 11 Jul 2025 08:52:00 -0500 Subject: [PATCH 18/45] Make key parameters configurable through the API request. --- api.py | 37 ++++++++++++++++++++++++++++++++++--- config.py | 4 ++++ polygon_inference.py | 18 +++++++++++++----- 3 files changed, 51 insertions(+), 8 deletions(-) diff --git a/api.py b/api.py index d9f97ba..305ccff 100644 --- a/api.py +++ b/api.py @@ -170,6 +170,8 @@ async def invoke( request: Request, file: UploadFile = None, api_key: Optional[str] = Depends(verify_api_key), + merge_tolerance: Optional[float] = Query(None, description="Tolerance for point-in-polygon tests during validation (in pixels, allows points to be slightly outside)"), + tile_overlap_ratio: Optional[float] = Query(None, description="Overlap ratio between tiles (0.0 = no overlap, 1.0 = complete overlap)"), ): """Main inference endpoint for processing images. @@ -182,10 +184,16 @@ async def invoke( 1. Via the X-API-Key header 2. Via the api_key query parameter + Configuration parameters can be provided in two ways: + 1. Via query parameters (merge_tolerance, tile_overlap_ratio) + 2. 
Via the JSON payload fields (merge_tolerance, tile_overlap_ratio) + Args: request: The request containing the image data file: Optional uploaded file (multipart/form-data) api_key: Optional API key for authentication (required only if API key is configured) + merge_tolerance: Optional tolerance for point-in-polygon tests during validation (in pixels, allows points to be slightly outside) + tile_overlap_ratio: Optional overlap ratio between tiles (0.0 = no overlap, 1.0 = complete overlap) Returns: JSON response containing the inferred polygons @@ -198,6 +206,18 @@ async def invoke( """ log(f"Invoking image analysis") + # Initialize configuration parameters and validate ranges + effective_merge_tolerance = merge_tolerance + effective_tile_overlap_ratio = tile_overlap_ratio + + # Validate merge_tolerance (should be positive) + if effective_merge_tolerance is not None and effective_merge_tolerance < 0: + raise HTTPException(status_code=400, detail="merge_tolerance must be non-negative") + + # Validate tile_overlap_ratio (should be between 0.0 and 1.0) + if effective_tile_overlap_ratio is not None and (effective_tile_overlap_ratio < 0.0 or effective_tile_overlap_ratio > 1.0): + raise HTTPException(status_code=400, detail="tile_overlap_ratio must be between 0.0 and 1.0") + if file: # Handle file upload image_data = await file.read() @@ -211,6 +231,16 @@ async def invoke( if "image" in data: # Handle base64 encoded image image_data = base64.b64decode(data["image"]) + + # Extract configuration parameters from JSON (if query params not provided) + if effective_merge_tolerance is None and "merge_tolerance" in data: + effective_merge_tolerance = float(data["merge_tolerance"]) + if effective_merge_tolerance < 0: + raise HTTPException(status_code=400, detail="merge_tolerance must be non-negative") + if effective_tile_overlap_ratio is None and "tile_overlap_ratio" in data: + effective_tile_overlap_ratio = float(data["tile_overlap_ratio"]) + if effective_tile_overlap_ratio < 0.0 or effective_tile_overlap_ratio > 1.0: + raise HTTPException(status_code=400, detail="tile_overlap_ratio must be between 0.0 and 1.0") else: raise HTTPException( status_code=400, detail="No image data found in request" @@ -219,15 +249,16 @@ async def invoke( # Handle raw image data image_data = body - # Generate cache key and check cache - cache_key = get_cache_key(image_data) + # Generate cache key including configuration parameters + cache_key_base = get_cache_key(image_data) + cache_key = f"{cache_key_base}_{effective_merge_tolerance}_{effective_tile_overlap_ratio}" cached_result = cache.get(cache_key) if cached_result is not None: return JSONResponse(content=cached_result) # Get inferences - polygons = predictor.infer(image_data) + polygons = predictor.infer(image_data, merge_tolerance=effective_merge_tolerance, tile_overlap_ratio=effective_tile_overlap_ratio) # Prepare response response = { diff --git a/config.py b/config.py index 1519f40..0c2ddf5 100644 --- a/config.py +++ b/config.py @@ -74,6 +74,10 @@ class CFG: # Prediction configuration PREDICTION_BATCH_SIZE = 8 # Batch size for processing tiles during prediction + # Polygon validation configuration + MERGE_TOLERANCE = 2 # Tolerance for point-in-polygon tests during validation (in pixels, allows points to be slightly outside) + TILE_OVERLAP_RATIO = 0.5 # Overlap ratio between tiles (0.0 = no overlap, 1.0 = complete overlap) + BATCH_SIZE = 24 # batch size per gpu; effective batch size = BATCH_SIZE * NUM_GPUs START_EPOCH = 0 NUM_EPOCHS = 500 diff --git 
a/polygon_inference.py b/polygon_inference.py index 9aa310a..5e76699 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -402,7 +402,8 @@ def _validate_all_polygons( tile_results: List[Dict[str, List[np.ndarray]]], positions: List[Tuple[int, int, int, int]], image_height: int, - image_width: int + image_width: int, + merge_tolerance: float ) -> List[Dict[str, List[np.ndarray]]]: """Validate all polygons in the tile results and add validation attributes. @@ -414,6 +415,7 @@ def _validate_all_polygons( positions (List[Tuple[int, int, int, int]]): List of (x, y, x_end, y_end) tuples for each tile's position image_height (int): Height of the original image image_width (int): Width of the original image + merge_tolerance (float): Tolerance for point-in-polygon tests during validation (in pixels) Returns: List[Dict[str, List[np.ndarray]]]: Updated tile results with validation attributes @@ -599,7 +601,7 @@ def point_in_polygon(point, polygon): return False # Convert polygon to the format expected by cv2.pointPolygonTest poly_points = polygon.astype(np.float32).reshape((-1, 1, 2)) - return cv2.pointPolygonTest(poly_points, point, True) >= -2 + return cv2.pointPolygonTest(poly_points, point, True) >= -merge_tolerance # Process each tile for tile_idx, (tile_result, tile_pos) in enumerate(zip(tile_results, positions)): @@ -849,12 +851,14 @@ def _merge_polygons( log(f"Bitmap approach: {len(merged_polygons)} polygons extracted from bitmap") return merged_polygons - def infer(self, image_data: bytes, debug: bool = False) -> List[List[List[float]]]: + def infer(self, image_data: bytes, debug: bool = False, merge_tolerance: Optional[float] = None, tile_overlap_ratio: Optional[float] = None) -> List[List[List[float]]]: """Infer polygons in an image. Args: image_data (bytes): Raw image data debug (bool): Whether to save debug images (tile visualization and bitmap) + merge_tolerance (Optional[float]): Tolerance for point-in-polygon tests during validation (in pixels, allows points to be slightly outside). If None, uses CFG.MERGE_TOLERANCE + tile_overlap_ratio (Optional[float]): Overlap ratio between tiles (0.0 = no overlap, 1.0 = complete overlap). If None, uses CFG.TILE_OVERLAP_RATIO Returns: list[list[list[float]]]: List of polygons where each polygon is a list of [x,y] coordinates. 
@@ -886,7 +890,11 @@ def infer(self, image_data: bytes, debug: bool = False) -> List[List[List[float] if height == 0 or width == 0: raise ValueError("Invalid image dimensions") - overlap_ratio = 0.5 + # Use provided parameters or fall back to config defaults + effective_merge_tolerance = merge_tolerance if merge_tolerance is not None else CFG.MERGE_TOLERANCE + effective_tile_overlap_ratio = tile_overlap_ratio if tile_overlap_ratio is not None else CFG.TILE_OVERLAP_RATIO + + overlap_ratio = effective_tile_overlap_ratio bboxes = calculate_slice_bboxes( image_height=height, @@ -921,7 +929,7 @@ def infer(self, image_data: bytes, debug: bool = False) -> List[List[List[float] log(f"Processed batch of {len(batch_tiles)} tiles: {batch_time/len(batch_tiles):.3f}s per tile") # Validate all polygons and add validation attributes - all_results = self._validate_all_polygons(all_results, bboxes, height, width) + all_results = self._validate_all_polygons(all_results, bboxes, height, width, effective_merge_tolerance) # Create tile visualization if debug: From b6822fddf517594d3dedb2a409f80c619ade30e6 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Mon, 14 Jul 2025 10:59:56 -0500 Subject: [PATCH 19/45] Handle whatever comes out of findContours. --- polygon_inference.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/polygon_inference.py b/polygon_inference.py index 5e76699..a7453d8 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -12,6 +12,7 @@ import albumentations as A from albumentations.pytorch import ToTensorV2 from shapely.geometry import Polygon +from shapely.validation import make_valid from buildingregulariser import regularize_geodataframe import geopandas as gpd @@ -826,7 +827,12 @@ def _merge_polygons( # Convert contour to Shapely Polygon contour_points = contour.reshape(-1, 2).astype(np.float64) shapely_polygon = Polygon(contour_points) - shapely_polygons.append(shapely_polygon) + + shapely_polygon = make_valid(shapely_polygon) + if shapely_polygon.is_valid: + shapely_polygons.append(shapely_polygon) + else: + log(f"Skipping invalid polygon") # Initialize merged_polygons to avoid NameError if no valid contours are found merged_polygons: List[np.ndarray] = [] From 9c3ee2ed26b338b3a8180062de295ddc77c261d7 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Mon, 14 Jul 2025 14:27:34 -0500 Subject: [PATCH 20/45] Further improvements to normalization of findContours results. 
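Background for this commit: Shapely's make_valid() does not always return a single Polygon. A self-intersecting contour ring from findContours typically comes back as a MultiPolygon (several disjoint pieces), so the result is unwrapped and only hole-free exterior rings with positive area are kept. A minimal sketch of that normalization; the helper name normalize_contour_ring is chosen only for illustration and is not part of the patch:

    from shapely.geometry import Polygon
    from shapely.validation import make_valid

    def normalize_contour_ring(ring_coords):
        """Turn one raw findContours ring into zero or more simple polygons."""
        fixed = make_valid(Polygon(ring_coords))
        if fixed.geom_type == "MultiPolygon":
            candidates = list(fixed.geoms)
        elif fixed.geom_type == "Polygon":
            candidates = [fixed]
        else:
            return []  # e.g. a degenerate ring collapsed to a line
        simple = [Polygon(g.exterior.coords) for g in candidates]  # exterior ring only
        return [p for p in simple if p.is_valid and p.area > 0]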
--- polygon_inference.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/polygon_inference.py b/polygon_inference.py index a7453d8..8f41792 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -829,13 +829,27 @@ def _merge_polygons( shapely_polygon = Polygon(contour_points) shapely_polygon = make_valid(shapely_polygon) + + # Handle case where make_valid returns a MultiPolygon if shapely_polygon.is_valid: - shapely_polygons.append(shapely_polygon) + if shapely_polygon.geom_type == 'MultiPolygon': + # Extract individual polygons from MultiPolygon + for individual_poly in shapely_polygon.geoms: + # Only keep exterior ring (no holes) + simple_poly = Polygon(individual_poly.exterior.coords) + if simple_poly.is_valid and simple_poly.area > 0: + shapely_polygons.append(simple_poly) + elif shapely_polygon.geom_type == 'Polygon': + # Only keep exterior ring (no holes) + simple_poly = Polygon(shapely_polygon.exterior.coords) + if simple_poly.is_valid and simple_poly.area > 0: + shapely_polygons.append(simple_poly) else: log(f"Skipping invalid polygon") # Initialize merged_polygons to avoid NameError if no valid contours are found merged_polygons: List[np.ndarray] = [] + log(f"Shapely polygons: {shapely_polygons}") # Create single GeoDataFrame with all polygons and regularize them all at once if shapely_polygons: From 6049010956a014b3ed48f2b501ffb381eaa135e4 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Mon, 14 Jul 2025 14:27:58 -0500 Subject: [PATCH 21/45] Further improvements to normalization of findContours results. --- polygon_inference.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 8f41792..1b71df2 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -835,12 +835,10 @@ def _merge_polygons( if shapely_polygon.geom_type == 'MultiPolygon': # Extract individual polygons from MultiPolygon for individual_poly in shapely_polygon.geoms: - # Only keep exterior ring (no holes) simple_poly = Polygon(individual_poly.exterior.coords) if simple_poly.is_valid and simple_poly.area > 0: shapely_polygons.append(simple_poly) elif shapely_polygon.geom_type == 'Polygon': - # Only keep exterior ring (no holes) simple_poly = Polygon(shapely_polygon.exterior.coords) if simple_poly.is_valid and simple_poly.area > 0: shapely_polygons.append(simple_poly) @@ -849,7 +847,6 @@ def _merge_polygons( # Initialize merged_polygons to avoid NameError if no valid contours are found merged_polygons: List[np.ndarray] = [] - log(f"Shapely polygons: {shapely_polygons}") # Create single GeoDataFrame with all polygons and regularize them all at once if shapely_polygons: From 5d0e308b6c78d7aa4982f304e1fdfe1c2d1d7972 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 16:24:14 -0500 Subject: [PATCH 22/45] Load checkpoint weights dynamically to get all pretrained models working without crashing. 
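Instead of trusting CFG to match whichever checkpoint happens to be on disk, the architecture is now read off the checkpoint itself: the saved positional-embedding tensors determine MAX_LEN / N_VERTICES and NUM_PATCHES. A condensed sketch of the shape arithmetic (the checkpoint path below is illustrative; the state-dict keys are the ones this repo's EncoderDecoder produces):

    import torch

    ckpt = torch.load("logs/checkpoints/latest.pth", map_location="cpu")
    state = ckpt["state_dict"]

    # decoder.decoder_pos_embed is stored as [1, MAX_LEN - 1, embed_dim]
    max_len = state["decoder.decoder_pos_embed"].shape[1] + 1
    n_vertices = (max_len - 2) // 2          # since MAX_LEN = N_VERTICES * 2 + 2

    # decoder.encoder_pos_embed is stored as [1, NUM_PATCHES, embed_dim]
    num_patches = state["decoder.encoder_pos_embed"].shape[1]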
--- polygon_inference.py | 78 ++++++++++++++++++++++++++++++++------------ 1 file changed, 57 insertions(+), 21 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 1b71df2..303882d 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -3,12 +3,14 @@ import time import hashlib import pickle +import copy from typing import List, Tuple, Dict, Optional, Any # Third-party imports import numpy as np import cv2 import torch +import torch.nn.functional as F import albumentations as A from albumentations.pytorch import ToTensorV2 from shapely.geometry import Polygon @@ -121,41 +123,73 @@ def _initialize_model(self) -> None: """Initialize the model and tokenizer. This method: - 1. Creates a new tokenizer instance - 2. Initializes the encoder-decoder model - 3. Loads the latest checkpoint from the experiment directory + 1. Loads the checkpoint to inspect the saved model configuration + 2. Dynamically adapts the configuration to match the checkpoint + 3. Creates a new tokenizer instance + 4. Initializes the encoder-decoder model with the correct architecture + 5. Loads the checkpoint weights """ + # Load checkpoint first to inspect saved model configuration + latest_checkpoint = self._find_single_checkpoint() + checkpoint_path = os.path.join( + self.experiment_path, "logs", "checkpoints", latest_checkpoint + ) + checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) + + # Create a copy of CFG for model creation + model_cfg = copy.deepcopy(CFG) + + # Dynamically determine configuration from the saved positional embeddings + decoder_pos_embed_key = "decoder.decoder_pos_embed" + encoder_pos_embed_key = "decoder.encoder_pos_embed" + + if decoder_pos_embed_key in checkpoint["state_dict"]: + saved_decoder_pos_embed_shape = checkpoint["state_dict"][decoder_pos_embed_key].shape + checkpoint_max_len_minus_1 = saved_decoder_pos_embed_shape[1] # Shape is [1, MAX_LEN-1, embed_dim] + checkpoint_max_len = checkpoint_max_len_minus_1 + 1 + checkpoint_n_vertices = (checkpoint_max_len - 2) // 2 # Reverse: MAX_LEN = (N_VERTICES*2) + 2 + + if checkpoint_n_vertices != CFG.N_VERTICES: + model_cfg.N_VERTICES = checkpoint_n_vertices + model_cfg.MAX_LEN = checkpoint_max_len + + if encoder_pos_embed_key in checkpoint["state_dict"]: + saved_encoder_pos_embed_shape = checkpoint["state_dict"][encoder_pos_embed_key].shape + checkpoint_num_patches = saved_encoder_pos_embed_shape[1] # Shape is [1, num_patches, embed_dim] + + if checkpoint_num_patches != CFG.NUM_PATCHES: + model_cfg.NUM_PATCHES = checkpoint_num_patches + + # Create tokenizer with the adapted configuration self.tokenizer = Tokenizer( num_classes=1, - num_bins=CFG.NUM_BINS, - width=CFG.INPUT_WIDTH, - height=CFG.INPUT_HEIGHT, - max_len=CFG.MAX_LEN, + num_bins=model_cfg.NUM_BINS, + width=model_cfg.INPUT_WIDTH, + height=model_cfg.INPUT_HEIGHT, + max_len=model_cfg.MAX_LEN, ) + # Use the original CFG for PAD_IDX to maintain compatibility CFG.PAD_IDX = self.tokenizer.PAD_code - encoder = Encoder(model_name=CFG.MODEL_NAME, pretrained=True, out_dim=256) + # Create model with the adapted configuration + encoder = Encoder(model_name=model_cfg.MODEL_NAME, pretrained=True, out_dim=256) decoder = Decoder( - cfg=CFG, + cfg=model_cfg, # Use adapted configuration vocab_size=self.tokenizer.vocab_size, - encoder_len=CFG.NUM_PATCHES, + encoder_len=model_cfg.NUM_PATCHES, dim=256, num_heads=8, num_layers=6, ) - self.model = EncoderDecoder(cfg=CFG, encoder=encoder, decoder=decoder) + self.model = EncoderDecoder(cfg=model_cfg, 
encoder=encoder, decoder=decoder) self.model.to(self.device) self.model.eval() + + # Store the adapted configuration for inference + self.model_cfg = model_cfg - # Load latest checkpoint - latest_checkpoint = self._find_single_checkpoint() - checkpoint_path = os.path.join( - self.experiment_path, "logs", "checkpoints", latest_checkpoint - ) - log(f"Loading checkpoint from {checkpoint_path}") - checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) + # Load checkpoint weights - should now match perfectly self.model.load_state_dict(checkpoint["state_dict"]) - log("Checkpoint loaded successfully") def _find_single_checkpoint(self) -> str: """Find the single checkpoint file. Crashes if there is more than one checkpoint. @@ -231,11 +265,13 @@ def _process_tiles_batch( batch_tensor = torch.stack(transformed_tiles).to(self.device) with torch.no_grad(): + # Use adapted configuration for generation + adapted_generation_steps = (self.model_cfg.N_VERTICES * 2) + 1 batch_preds, batch_confs, perm_preds = test_generate( self.model, batch_tensor, self.tokenizer, - max_len=CFG.generation_steps, + max_len=adapted_generation_steps, top_k=0, top_p=1, ) @@ -249,7 +285,7 @@ def _process_tiles_batch( else: coord = torch.tensor([]) - padd = torch.ones((CFG.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX) + padd = torch.ones((self.model_cfg.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX) coord = torch.cat([coord, padd], dim=0) batch_polygons = permutations_to_polygons( From e19b3642bd9512ab6dda4e141c2e938a9e8e2950 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 16:33:54 -0500 Subject: [PATCH 23/45] Make it possible to control the model through the API request. --- api.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 81 insertions(+), 16 deletions(-) diff --git a/api.py b/api.py index 305ccff..652860a 100644 --- a/api.py +++ b/api.py @@ -14,6 +14,7 @@ import shutil from pathlib import Path from diskcache import Cache +import re from polygon_inference import PolygonInference from utils import log @@ -27,6 +28,9 @@ EXPERIMENT_PATH = os.getenv("EXPERIMENT_PATH", "runs_share/Pix2Poly_inria_coco_224") MODEL_URL = os.getenv("MODEL_URL", "https://github.com/safelease/Pix2Poly/releases/download/main/runs_share.zip") +# Default model name extracted from EXPERIMENT_PATH +DEFAULT_MODEL_NAME = os.path.basename(EXPERIMENT_PATH) + # Cache configuration CACHE_TTL = int(os.getenv("CACHE_TTL", 24 * 3600)) # 24 hours @@ -38,16 +42,33 @@ disk_pickle_protocol=4, # Use protocol 4 for better compatibility ) -def get_cache_key(image_data: bytes) -> str: - """Generate a cache key from image data. +def get_cache_key(image_data: bytes, model_name: str = None, merge_tolerance: float = None, tile_overlap_ratio: float = None) -> str: + """Generate a cache key from image data and parameters. Args: image_data: Raw image data + model_name: Model name being used + merge_tolerance: Merge tolerance parameter + tile_overlap_ratio: Tile overlap ratio parameter + + Returns: + SHA-256 hash of the image data combined with parameters as a string + """ + image_hash = hashlib.sha256(image_data).hexdigest() + return f"{image_hash}_{model_name}_{merge_tolerance}_{tile_overlap_ratio}" + + +def validate_model_name(model_name: str) -> bool: + """Validate that the model name contains only safe characters. 
+ + Args: + model_name: The model name to validate Returns: - SHA-256 hash of the image data as a string + True if the model name is valid, False otherwise """ - return hashlib.sha256(image_data).hexdigest() + # Allow alphanumeric characters, underscores, and hyphens + return bool(re.match(r'^[a-zA-Z0-9_-]+$', model_name)) async def verify_api_key( @@ -127,14 +148,16 @@ def download_model_files(model_url: str, target_dir: str) -> str: @asynccontextmanager async def lifespan(app: FastAPI): """Initialize the predictor on startup.""" + global model_dir + # Download model files to a temporary directory model_dir = download_model_files( MODEL_URL, "/tmp/pix2poly_model", ) - # Initialize predictor with downloaded model - init_predictor(os.path.join(model_dir, EXPERIMENT_PATH)) + # Initialize predictor with downloaded model using the default model name + init_predictor(os.path.join(model_dir, EXPERIMENT_PATH), DEFAULT_MODEL_NAME) yield @@ -154,15 +177,48 @@ async def lifespan(app: FastAPI): allow_headers=["*"], # Allows all headers ) -# Global predictor instance +# Global predictor instance and current model tracking predictor = None +current_model_name = None +model_dir = None -def init_predictor(experiment_path: str): +def init_predictor(experiment_path: str, model_name: str = None): """Initialize the global predictor instance.""" - global predictor - if predictor is None: + global predictor, current_model_name + if predictor is None or current_model_name != model_name: predictor = PolygonInference(experiment_path) + current_model_name = model_name + log(f"Loaded model: {model_name}", "INFO") + + +def load_model(model_name: str): + """Load a specific model by name. + + Args: + model_name: The name of the model to load (e.g., "Pix2Poly_inria_coco_224") + + Raises: + HTTPException: If model name is invalid or model files don't exist + """ + global predictor, current_model_name, model_dir + + if not validate_model_name(model_name): + raise HTTPException(status_code=400, detail="Invalid model name. Only alphanumeric characters, underscores, and hyphens are allowed.") + + # Skip reloading if it's the same model + if current_model_name == model_name and predictor is not None: + return + + # Construct the full experiment path + experiment_path = os.path.join(model_dir, "runs_share", model_name) + + # Check if the model directory exists + if not os.path.exists(experiment_path): + raise HTTPException(status_code=404, detail=f"Model '{model_name}' not found in downloaded model files") + + # Initialize the predictor with the new model + init_predictor(experiment_path, model_name) @app.post("/invocations") @@ -172,6 +228,7 @@ async def invoke( api_key: Optional[str] = Depends(verify_api_key), merge_tolerance: Optional[float] = Query(None, description="Tolerance for point-in-polygon tests during validation (in pixels, allows points to be slightly outside)"), tile_overlap_ratio: Optional[float] = Query(None, description="Overlap ratio between tiles (0.0 = no overlap, 1.0 = complete overlap)"), + model_name: Optional[str] = Query(None, description="Name of the model to use (e.g., 'Pix2Poly_inria_coco_224')"), ): """Main inference endpoint for processing images. @@ -185,8 +242,8 @@ async def invoke( 2. Via the api_key query parameter Configuration parameters can be provided in two ways: - 1. Via query parameters (merge_tolerance, tile_overlap_ratio) - 2. Via the JSON payload fields (merge_tolerance, tile_overlap_ratio) + 1. Via query parameters (merge_tolerance, tile_overlap_ratio, model_name) + 2. 
Via the JSON payload fields (merge_tolerance, tile_overlap_ratio, model_name) Args: request: The request containing the image data @@ -194,12 +251,14 @@ async def invoke( api_key: Optional API key for authentication (required only if API key is configured) merge_tolerance: Optional tolerance for point-in-polygon tests during validation (in pixels, allows points to be slightly outside) tile_overlap_ratio: Optional overlap ratio between tiles (0.0 = no overlap, 1.0 = complete overlap) + model_name: Optional name of the model to use (e.g., 'Pix2Poly_inria_coco_224') Returns: JSON response containing the inferred polygons Raises: HTTPException: 400 if no image data is found in the request + HTTPException: 404 if the specified model is not found HTTPException: 500 if there is an error processing the image HTTPException: 401 if API key is missing (when API key is configured) HTTPException: 403 if API key is invalid (when API key is configured) @@ -209,6 +268,7 @@ async def invoke( # Initialize configuration parameters and validate ranges effective_merge_tolerance = merge_tolerance effective_tile_overlap_ratio = tile_overlap_ratio + effective_model_name = model_name or DEFAULT_MODEL_NAME # Validate merge_tolerance (should be positive) if effective_merge_tolerance is not None and effective_merge_tolerance < 0: @@ -241,6 +301,8 @@ async def invoke( effective_tile_overlap_ratio = float(data["tile_overlap_ratio"]) if effective_tile_overlap_ratio < 0.0 or effective_tile_overlap_ratio > 1.0: raise HTTPException(status_code=400, detail="tile_overlap_ratio must be between 0.0 and 1.0") + if model_name is None and "model_name" in data: + effective_model_name = str(data["model_name"]) else: raise HTTPException( status_code=400, detail="No image data found in request" @@ -249,9 +311,11 @@ async def invoke( # Handle raw image data image_data = body - # Generate cache key including configuration parameters - cache_key_base = get_cache_key(image_data) - cache_key = f"{cache_key_base}_{effective_merge_tolerance}_{effective_tile_overlap_ratio}" + # Load the requested model (this will only reload if it's different from the current model) + load_model(effective_model_name) + + # Generate cache key including all configuration parameters + cache_key = get_cache_key(image_data, effective_model_name, effective_merge_tolerance, effective_tile_overlap_ratio) cached_result = cache.get(cache_key) if cached_result is not None: @@ -263,6 +327,7 @@ async def invoke( # Prepare response response = { "polygons": polygons, + "model_name": effective_model_name, } # Store result in cache @@ -275,7 +340,7 @@ async def ping(api_key: Optional[str] = Depends(verify_api_key)): """Health check endpoint to verify service status.""" if predictor is None: raise HTTPException(status_code=503, detail="Model not loaded") - return {"status": "healthy"} + return {"status": "healthy", "current_model": current_model_name} @app.get("/clear-cache") From 03ee2e71d661cd9ba948a0e853ad458a0a1c353d Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 21:23:16 -0500 Subject: [PATCH 24/45] Fix bad colors in visualization, remove redundant EXPERIMENT_PATH configuration. 
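The bad colors came from a channel-order mismatch: OpenCV decodes images in BGR order while matplotlib assumes RGB, so blue and red were swapped in the merged visualization. The fix is a single conversion before plotting; a minimal sketch (file name illustrative):

    import cv2
    import matplotlib.pyplot as plt

    bgr = cv2.imread("input.png")                # OpenCV gives BGR channel order
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)   # matplotlib expects RGB
    plt.imshow(rgb)
    plt.axis("off")
    plt.show()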
--- api.py | 9 ++++----- infer_single_image.py | 4 +++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/api.py b/api.py index 652860a..ff8c6da 100644 --- a/api.py +++ b/api.py @@ -25,11 +25,10 @@ # Get API key from environment variable API_KEY = os.getenv("API_KEY") -EXPERIMENT_PATH = os.getenv("EXPERIMENT_PATH", "runs_share/Pix2Poly_inria_coco_224") MODEL_URL = os.getenv("MODEL_URL", "https://github.com/safelease/Pix2Poly/releases/download/main/runs_share.zip") -# Default model name extracted from EXPERIMENT_PATH -DEFAULT_MODEL_NAME = os.path.basename(EXPERIMENT_PATH) +# Default model name (inria dataset) +DEFAULT_MODEL_NAME = "Pix2Poly_inria_coco_224" # Cache configuration CACHE_TTL = int(os.getenv("CACHE_TTL", 24 * 3600)) # 24 hours @@ -157,7 +156,7 @@ async def lifespan(app: FastAPI): ) # Initialize predictor with downloaded model using the default model name - init_predictor(os.path.join(model_dir, EXPERIMENT_PATH), DEFAULT_MODEL_NAME) + init_predictor(os.path.join(model_dir, "runs_share", DEFAULT_MODEL_NAME), DEFAULT_MODEL_NAME) yield @@ -340,7 +339,7 @@ async def ping(api_key: Optional[str] = Depends(verify_api_key)): """Health check endpoint to verify service status.""" if predictor is None: raise HTTPException(status_code=503, detail="Model not loaded") - return {"status": "healthy", "current_model": current_model_name} + return {"status": "healthy"} @app.get("/clear-cache") diff --git a/infer_single_image.py b/infer_single_image.py index e24fdbf..49eddaa 100644 --- a/infer_single_image.py +++ b/infer_single_image.py @@ -52,7 +52,9 @@ def main(): x_max = min(width, x+1) vis_image_merged[y_min:y_max, x_min:x_max] = [255, 0, 0] - plt.imshow(vis_image_merged) + # Convert BGR to RGB for correct display in matplotlib + vis_image_merged_rgb = cv2.cvtColor(vis_image_merged, cv2.COLOR_BGR2RGB) + plt.imshow(vis_image_merged_rgb) plt.axis('off') plt.subplots_adjust(left=0, right=1, top=1, bottom=0) From 433c49ce3f78b35cdf1c9a5e8b7414d64db95f14 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 21:26:14 -0500 Subject: [PATCH 25/45] Improve visualization. 
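Tick placement on the per-tile axes previously used a fixed stride, so the far edge of a tile often had no labelled tick. The new logic always pins the first and last pixel as ticks and maps them to global image coordinates, forcing the final label onto the tile's far edge. A minimal sketch of that idea (helper names are illustrative, not part of the patch):

    def tick_positions(length_px, num_ticks=8):
        """Pixel positions that always include both ends of the axis."""
        if length_px <= 1:
            return [0]
        step = length_px / (num_ticks - 1)
        inner = [int(i * step) for i in range(1, num_ticks - 1)]
        return [0] + inner + [length_px - 1]

    def to_global(positions, origin, length_px, span):
        """Map tile-local pixels to global coordinates, pinning the last label."""
        coords = [origin + p * span // length_px for p in positions]
        if len(coords) > 1:
            coords[-1] = origin + span   # exactly the tile's far edge
        return coords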
--- polygon_inference.py | 40 ++++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 303882d..e7025ce 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -376,17 +376,49 @@ def _create_tile_visualization( # Set up x-axis ticks and labels (global coordinates) x_range = x_end - x - x_step = max(1, x_range // 8) # Show ~8 ticks across width - x_tick_positions = range(0, tile_width, max(1, tile_width // 8)) + # Generate tick positions ensuring min and max are included + num_x_ticks = 8 + if tile_width > 1: + x_tick_positions = [0] # Always include minimum + if num_x_ticks > 2: + # Add intermediate positions + step = tile_width / (num_x_ticks - 1) + for i in range(1, num_x_ticks - 1): + x_tick_positions.append(int(i * step)) + x_tick_positions.append(tile_width - 1) # Always include maximum + else: + x_tick_positions = [0] + + # Calculate corresponding global coordinates x_global_coords = [x + pos * x_range // tile_width for pos in x_tick_positions] + # Ensure the last coordinate is exactly x_end + if len(x_global_coords) > 1: + x_global_coords[-1] = x_end + ax.set_xticks(x_tick_positions) ax.set_xticklabels([str(coord) for coord in x_global_coords], fontsize=8) # Set up y-axis ticks and labels (global coordinates) y_range = y_end - y - y_step = max(1, y_range // 8) # Show ~8 ticks across height - y_tick_positions = range(0, tile_height, max(1, tile_height // 8)) + # Generate tick positions ensuring min and max are included + num_y_ticks = 8 + if tile_height > 1: + y_tick_positions = [0] # Always include minimum + if num_y_ticks > 2: + # Add intermediate positions + step = tile_height / (num_y_ticks - 1) + for i in range(1, num_y_ticks - 1): + y_tick_positions.append(int(i * step)) + y_tick_positions.append(tile_height - 1) # Always include maximum + else: + y_tick_positions = [0] + + # Calculate corresponding global coordinates y_global_coords = [y + pos * y_range // tile_height for pos in y_tick_positions] + # Ensure the last coordinate is exactly y_end + if len(y_global_coords) > 1: + y_global_coords[-1] = y_end + ax.set_yticks(y_tick_positions) ax.set_yticklabels([str(coord) for coord in y_global_coords], fontsize=8) From 1f2f71bbf219e8aebb8bf4c83970ecac9dcc2f71 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 21:29:44 -0500 Subject: [PATCH 26/45] Remove logging. 
--- polygon_inference.py | 80 ++------------------------------------------ 1 file changed, 3 insertions(+), 77 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index e7025ce..666616d 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -489,16 +489,9 @@ def _validate_all_polygons( Returns: List[Dict[str, List[np.ndarray]]]: Updated tile results with validation attributes """ - log("Starting polygon validation process...") - # Initialize polygon_valid list for each tile - total_polygons = 0 - for tile_idx, tile_result in enumerate(tile_results): + for tile_result in tile_results: tile_result["polygon_valid"] = [True] * len(tile_result["polygons"]) - total_polygons += len(tile_result["polygons"]) - log(f"Tile {tile_idx}: {len(tile_result['polygons'])} polygons to validate") - - log(f"Total polygons to validate: {total_polygons}") # Remove overlapping polygons within each tile (before edge validation) @@ -533,15 +526,11 @@ def calculate_polygon_area(poly): except: return 0 - log("Phase 1: Removing overlapping polygons within each tile...") - overlap_removed_count = 0 - - for tile_idx, tile_result in enumerate(tile_results): + for tile_result in tile_results: polygons = tile_result["polygons"] polygon_valid = tile_result["polygon_valid"] if len(polygons) <= 1: - log(f"Tile {tile_idx}: Skipping overlap check (≤1 polygon)") continue # Skip tiles with 0 or 1 polygon # Keep iterating until no overlaps are found @@ -561,10 +550,8 @@ def calculate_polygon_area(poly): if check_polygon_overlap(poly1, poly2): overlapping_pairs.append((idx1, idx2)) - log(f"Tile {tile_idx}: Overlap detected between polygon {idx1} and polygon {idx2}") if not overlapping_pairs: - log(f"Tile {tile_idx}: No overlaps found, overlap validation complete") break # No overlaps found # Find all polygons involved in overlaps @@ -584,21 +571,10 @@ def calculate_polygon_area(poly): # Mark the largest polygon as invalid polygon_valid[largest_idx] = False - overlap_removed_count += 1 - log(f"Tile {tile_idx}: Marking polygon {largest_idx} as INVALID (largest overlapping polygon, area={largest_area:.2f})") - - # Log areas of all overlapping polygons for context - for idx, area in polygon_areas: - if idx != largest_idx: - log(f"Tile {tile_idx}: Polygon {idx} (area={area:.2f}) remains valid") # Continue to next iteration to check for remaining overlaps - log(f"Phase 1 complete: {overlap_removed_count} polygons removed due to overlaps") - # Now perform edge validation on remaining valid polygons - log("Phase 2: Validating polygons with boundary edges...") - edge_rejected_count = 0 def is_edge_near_tile_boundary(p1, p2, tile_bounds, tolerance=2): """Check if an edge is colinear with the tile boundary within tolerance.""" @@ -673,7 +649,7 @@ def point_in_polygon(point, polygon): return cv2.pointPolygonTest(poly_points, point, True) >= -merge_tolerance # Process each tile - for tile_idx, (tile_result, tile_pos) in enumerate(zip(tile_results, positions)): + for tile_result, tile_pos in zip(tile_results, positions): x, y, x_end, y_end = tile_pos tile_width = x_end - x tile_height = y_end - y @@ -682,19 +658,14 @@ def point_in_polygon(point, polygon): polygons = tile_result["polygons"] polygon_valid = tile_result["polygon_valid"] - valid_polygons_count = sum(polygon_valid) - log(f"Tile {tile_idx}: Processing {valid_polygons_count} valid polygons for edge validation") - # Check each polygon in this tile (only those still valid after overlap removal) for poly_idx, polygon in enumerate(polygons): # Skip polygons 
already rejected for overlap if not polygon_valid[poly_idx]: - log(f"Tile {tile_idx}, Polygon {poly_idx}: SKIPPED (already invalid from overlap check)") continue if len(polygon) < 3: polygon_valid[poly_idx] = False - log(f"Tile {tile_idx}, Polygon {poly_idx}: INVALID (< 3 vertices)") continue # Find edges that are near tile boundaries @@ -708,11 +679,8 @@ def point_in_polygon(point, polygon): # If no boundary edges, polygon is valid (not on tile boundary) if not boundary_edges: - log(f"Tile {tile_idx}, Polygon {poly_idx}: VALID (no boundary edges)") continue - log(f"Tile {tile_idx}, Polygon {poly_idx}: Found {len(boundary_edges)} boundary edges, validating...") - # Check sample points along boundary edges polygon_is_valid = True @@ -723,16 +691,12 @@ def point_in_polygon(point, polygon): is_horizontal_edge = abs(p1[1] - p2[1]) <= 2 # Edge is roughly horizontal is_vertical_edge = abs(p1[0] - p2[0]) <= 2 # Edge is roughly vertical - edge_type = "horizontal" if is_horizontal_edge else "vertical" if is_vertical_edge else "diagonal" - log(f"Tile {tile_idx}, Polygon {poly_idx}, Edge {edge_idx}: {edge_type} edge with {len(sample_points)} sample points") - # Convert sample points to global image coordinates global_sample_points = [(px + x, py + y) for px, py in sample_points] # Check if each sample point is contained in any polygon from other tiles for point_idx, global_point in enumerate(global_sample_points): point_found_in_other_polygon = False - match_info = None # Check all other tiles for other_tile_result, other_tile_pos in zip(tile_results, positions): @@ -760,20 +724,6 @@ def point_in_polygon(point, polygon): if point_in_polygon(local_point, other_polygon): point_found_in_other_polygon = True - # Store match location information for detailed logging - match_tile_idx = None - for search_tile_idx, (search_tile_result, search_tile_pos) in enumerate(zip(tile_results, positions)): - if search_tile_result is other_tile_result: - match_tile_idx = search_tile_idx - break - - match_info = { - 'tile_idx': match_tile_idx, - 'polygon_idx': other_poly_idx, - 'tile_bounds': (other_x, other_y, other_x_end, other_y_end), - 'local_point': local_point, - 'global_point': global_point - } break if point_found_in_other_polygon: @@ -782,37 +732,13 @@ def point_in_polygon(point, polygon): # If any sample point is not found in other polygons, mark as invalid if not point_found_in_other_polygon: polygon_is_valid = False - log(f"Tile {tile_idx}, Polygon {poly_idx}: Sample point {point_idx} at {global_point} NOT found in other polygons") break - else: - # Get the matched polygon points and convert to global coordinates - matched_polygon = tile_results[match_info['tile_idx']]["polygons"][match_info['polygon_idx']] - match_tile_bounds = match_info['tile_bounds'] - # Convert from local tile coordinates to global coordinates - global_polygon_points = matched_polygon + np.array([match_tile_bounds[0], match_tile_bounds[1]]) - # Format points as list of [x, y] coordinates for logging - polygon_points_str = '[' + ', '.join(f'[{x:.1f}, {y:.1f}]' for x, y in global_polygon_points) + ']' - - log(f"Tile {tile_idx}, Polygon {poly_idx}: Sample point {point_idx} at {global_point} found in " - f"Tile {match_info['tile_idx']}, Polygon {match_info['polygon_idx']} " - f"(polygon points: {polygon_points_str})") if not polygon_is_valid: break # Update polygon validity polygon_valid[poly_idx] = polygon_is_valid - - if polygon_is_valid: - log(f"Tile {tile_idx}, Polygon {poly_idx}: VALID (all boundary edge sample points found in other 
polygons)") - else: - edge_rejected_count += 1 - log(f"Tile {tile_idx}, Polygon {poly_idx}: INVALID (boundary edge validation failed)") - - # Log final validation summary - final_valid_count = sum(sum(tile_result["polygon_valid"]) for tile_result in tile_results) - log(f"Phase 2 complete: {edge_rejected_count} polygons removed due to edge validation") - log(f"Validation summary: {final_valid_count}/{total_polygons} polygons remain valid") return tile_results From b19347630762db76576ce41b28fbd9353b5bc0ea Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 21:33:50 -0500 Subject: [PATCH 27/45] Restructure code. --- polygon_inference.py | 213 ++++++++++++++++++++++--------------------- 1 file changed, 109 insertions(+), 104 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 666616d..427b201 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -36,6 +36,114 @@ from models.model import Encoder, Decoder, EncoderDecoder +def check_polygon_overlap(poly1, poly2): + """Check if two polygons overlap using Shapely.""" + try: + # Convert numpy arrays to Shapely polygons + if len(poly1) < 3 or len(poly2) < 3: + return False + + shapely_poly1 = Polygon(poly1) + shapely_poly2 = Polygon(poly2) + + # Check if polygons are valid + if not shapely_poly1.is_valid or not shapely_poly2.is_valid: + return False + + # Check for intersection (but not just touching) + return shapely_poly1.intersects(shapely_poly2) and not shapely_poly1.touches(shapely_poly2) + except: + return False + + +def calculate_polygon_area(poly): + """Calculate the area of a polygon.""" + try: + if len(poly) < 3: + return 0 + shapely_poly = Polygon(poly) + if not shapely_poly.is_valid: + return 0 + return shapely_poly.area + except: + return 0 + + +def is_edge_near_tile_boundary(p1, p2, tile_bounds, tolerance=2): + """Check if an edge is colinear with the tile boundary within tolerance.""" + x_min, y_min, x_max, y_max = tile_bounds + x1, y1 = p1 + x2, y2 = p2 + + # Check if edge is roughly horizontal and colinear with top boundary + if (abs(y1 - y_min) <= tolerance and abs(y2 - y_min) <= tolerance and + abs(y1 - y2) <= tolerance): + return True + + # Check if edge is roughly horizontal and colinear with bottom boundary + if (abs(y1 - y_max) <= tolerance and abs(y2 - y_max) <= tolerance and + abs(y1 - y2) <= tolerance): + return True + + # Check if edge is roughly vertical and colinear with left boundary + if (abs(x1 - x_min) <= tolerance and abs(x2 - x_min) <= tolerance and + abs(x1 - x2) <= tolerance): + return True + + # Check if edge is roughly vertical and colinear with right boundary + if (abs(x1 - x_max) <= tolerance and abs(x2 - x_max) <= tolerance and + abs(x1 - x2) <= tolerance): + return True + + return False + + +def generate_edge_sample_points(p1, p2, num_points=10, margin_px=10): + """Generate equally spaced points along an edge, leaving a fixed margin at each end. 
+ Always generates at least one point in the center of the line.""" + # Calculate edge length + edge_length = math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2) + + # Always generate center point + center_x = p1[0] + 0.5 * (p2[0] - p1[0]) + center_y = p1[1] + 0.5 * (p2[1] - p1[1]) + + # If edge is too short to accommodate margins, return just the center point + if edge_length <= 2 * margin_px: + return [(center_x, center_y)] + + # Calculate t values for the start and end of the usable region + t_start = margin_px / edge_length + t_end = 1.0 - margin_px / edge_length + + points = [] + + # If only one point requested, return center point + if num_points == 1: + return [(center_x, center_y)] + + # Generate points evenly spaced within the usable region + for i in range(num_points): + # Distribute points evenly within the usable region + t_local = i / (num_points - 1) + t = t_start + t_local * (t_end - t_start) + + x = p1[0] + t * (p2[0] - p1[0]) + y = p1[1] + t * (p2[1] - p1[1]) + points.append((x, y)) + + return points + + +def point_in_polygon(point, polygon, merge_tolerance): + """Check if a point is inside a polygon using OpenCV.""" + if len(polygon) < 3: + return False + # Convert polygon to the format expected by cv2.pointPolygonTest + poly_points = polygon.astype(np.float32).reshape((-1, 1, 2)) + return cv2.pointPolygonTest(poly_points, point, True) >= -merge_tolerance + + class PolygonInference: def __init__(self, experiment_path: str, device: Optional[str] = None) -> None: """Initialize the polygon inference with a trained model. @@ -495,37 +603,6 @@ def _validate_all_polygons( # Remove overlapping polygons within each tile (before edge validation) - def check_polygon_overlap(poly1, poly2): - """Check if two polygons overlap using Shapely.""" - try: - # Convert numpy arrays to Shapely polygons - if len(poly1) < 3 or len(poly2) < 3: - return False - - shapely_poly1 = Polygon(poly1) - shapely_poly2 = Polygon(poly2) - - # Check if polygons are valid - if not shapely_poly1.is_valid or not shapely_poly2.is_valid: - return False - - # Check for intersection (but not just touching) - return shapely_poly1.intersects(shapely_poly2) and not shapely_poly1.touches(shapely_poly2) - except: - return False - - def calculate_polygon_area(poly): - """Calculate the area of a polygon.""" - try: - if len(poly) < 3: - return 0 - shapely_poly = Polygon(poly) - if not shapely_poly.is_valid: - return 0 - return shapely_poly.area - except: - return 0 - for tile_result in tile_results: polygons = tile_result["polygons"] polygon_valid = tile_result["polygon_valid"] @@ -576,78 +653,6 @@ def calculate_polygon_area(poly): # Now perform edge validation on remaining valid polygons - def is_edge_near_tile_boundary(p1, p2, tile_bounds, tolerance=2): - """Check if an edge is colinear with the tile boundary within tolerance.""" - x_min, y_min, x_max, y_max = tile_bounds - x1, y1 = p1 - x2, y2 = p2 - - # Check if edge is roughly horizontal and colinear with top boundary - if (abs(y1 - y_min) <= tolerance and abs(y2 - y_min) <= tolerance and - abs(y1 - y2) <= tolerance): - return True - - # Check if edge is roughly horizontal and colinear with bottom boundary - if (abs(y1 - y_max) <= tolerance and abs(y2 - y_max) <= tolerance and - abs(y1 - y2) <= tolerance): - return True - - # Check if edge is roughly vertical and colinear with left boundary - if (abs(x1 - x_min) <= tolerance and abs(x2 - x_min) <= tolerance and - abs(x1 - x2) <= tolerance): - return True - - # Check if edge is roughly vertical and colinear with right 
boundary - if (abs(x1 - x_max) <= tolerance and abs(x2 - x_max) <= tolerance and - abs(x1 - x2) <= tolerance): - return True - - return False - - def generate_edge_sample_points(p1, p2, num_points=10, margin_px=10): - """Generate equally spaced points along an edge, leaving a fixed margin at each end. - Always generates at least one point in the center of the line.""" - # Calculate edge length - edge_length = math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2) - - # Always generate center point - center_x = p1[0] + 0.5 * (p2[0] - p1[0]) - center_y = p1[1] + 0.5 * (p2[1] - p1[1]) - - # If edge is too short to accommodate margins, return just the center point - if edge_length <= 2 * margin_px: - return [(center_x, center_y)] - - # Calculate t values for the start and end of the usable region - t_start = margin_px / edge_length - t_end = 1.0 - margin_px / edge_length - - points = [] - - # If only one point requested, return center point - if num_points == 1: - return [(center_x, center_y)] - - # Generate points evenly spaced within the usable region - for i in range(num_points): - # Distribute points evenly within the usable region - t_local = i / (num_points - 1) - t = t_start + t_local * (t_end - t_start) - - x = p1[0] + t * (p2[0] - p1[0]) - y = p1[1] + t * (p2[1] - p1[1]) - points.append((x, y)) - - return points - - def point_in_polygon(point, polygon): - """Check if a point is inside a polygon using OpenCV.""" - if len(polygon) < 3: - return False - # Convert polygon to the format expected by cv2.pointPolygonTest - poly_points = polygon.astype(np.float32).reshape((-1, 1, 2)) - return cv2.pointPolygonTest(poly_points, point, True) >= -merge_tolerance - # Process each tile for tile_result, tile_pos in zip(tile_results, positions): x, y, x_end, y_end = tile_pos @@ -722,7 +727,7 @@ def point_in_polygon(point, polygon): if not other_tile_result["polygon_valid"][other_poly_idx]: continue - if point_in_polygon(local_point, other_polygon): + if point_in_polygon(local_point, other_polygon, merge_tolerance): point_found_in_other_polygon = True break From 5ab737d4c3c2751554e8b6241b9bbb82ecee7f81 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 21:37:36 -0500 Subject: [PATCH 28/45] Remove unused variables. 
--- polygon_inference.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 427b201..abacadc 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -4,13 +4,12 @@ import hashlib import pickle import copy -from typing import List, Tuple, Dict, Optional, Any +from typing import List, Tuple, Dict, Optional # Third-party imports import numpy as np import cv2 import torch -import torch.nn.functional as F import albumentations as A from albumentations.pytorch import ToTensorV2 from shapely.geometry import Polygon @@ -19,7 +18,6 @@ import geopandas as gpd import matplotlib.pyplot as plt -import matplotlib.patches as patches import math # Local imports @@ -689,7 +687,7 @@ def _validate_all_polygons( # Check sample points along boundary edges polygon_is_valid = True - for edge_idx, (p1, p2) in enumerate(boundary_edges): + for p1, p2 in boundary_edges: sample_points = generate_edge_sample_points(p1, p2) # Determine if this edge is horizontal or vertical @@ -700,7 +698,7 @@ def _validate_all_polygons( global_sample_points = [(px + x, py + y) for px, py in sample_points] # Check if each sample point is contained in any polygon from other tiles - for point_idx, global_point in enumerate(global_sample_points): + for global_point in global_sample_points: point_found_in_other_polygon = False # Check all other tiles @@ -777,11 +775,11 @@ def _merge_polygons( bitmap = np.zeros((image_height * scale_factor, image_width * scale_factor), dtype=np.uint8) # Fill bitmap with polygon regions - for tile_idx, (tile_result, (x, y, x_end, y_end)) in enumerate(zip(tile_results, positions)): + for tile_result, (x, y, x_end, y_end) in zip(tile_results, positions): tile_polygons = tile_result["polygons"] polygon_valid = tile_result["polygon_valid"] - for poly_idx, (poly, is_valid) in enumerate(zip(tile_polygons, polygon_valid)): + for poly, is_valid in zip(tile_polygons, polygon_valid): # Skip invalid polygons if not is_valid: continue From 0cd1b2481cc56b7b2fd695917f51b54d07f91443 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 21:40:21 -0500 Subject: [PATCH 29/45] Make output filenames consistent. --- polygon_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index abacadc..112cc17 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -806,8 +806,8 @@ def _merge_polygons( # Save bitmap for debugging (optional) if debug: - cv2.imwrite('debug_polygon_bitmap.png', bitmap) - log("Saved debug bitmap to debug_polygon_bitmap.png") + cv2.imwrite('visualization-bitmap.png', bitmap) + log("Saved debug bitmap to visualization-bitmap.png") # Find contours in the bitmap contours, _ = cv2.findContours(bitmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) From ed5c3aa6c0cb54c7112f3cec13d7592550d2d3ea Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 21:59:58 -0500 Subject: [PATCH 30/45] Restructuring. 
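Besides moving the polygon helpers onto the class, this change makes the tile cache model-aware: the experiment path is hashed together with the raw tile bytes, so two models never share cached results for the same tiles. A minimal sketch of the key derivation (standalone function instead of the class method; the paths and tile shape are illustrative):

    import hashlib
    import numpy as np

    def cache_key_for(experiment_path: str, tiles: list) -> str:
        """Hash the model identifier plus every tile's raw bytes."""
        hasher = hashlib.sha256()
        hasher.update(experiment_path.encode("utf-8"))  # ties the key to the model
        for tile in tiles:
            hasher.update(tile.tobytes())
        return hasher.hexdigest()

    tiles = [np.zeros((224, 224, 3), dtype=np.uint8)]
    print(cache_key_for("runs_share/Pix2Poly_inria_coco_224", tiles))
    print(cache_key_for("runs_share/some_other_model", tiles))  # same tiles, different key
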
--- polygon_inference.py | 229 ++++++++++++++++++++++--------------------- 1 file changed, 119 insertions(+), 110 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 112cc17..656db92 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -34,105 +34,6 @@ from models.model import Encoder, Decoder, EncoderDecoder -def check_polygon_overlap(poly1, poly2): - """Check if two polygons overlap using Shapely.""" - try: - # Convert numpy arrays to Shapely polygons - if len(poly1) < 3 or len(poly2) < 3: - return False - - shapely_poly1 = Polygon(poly1) - shapely_poly2 = Polygon(poly2) - - # Check if polygons are valid - if not shapely_poly1.is_valid or not shapely_poly2.is_valid: - return False - - # Check for intersection (but not just touching) - return shapely_poly1.intersects(shapely_poly2) and not shapely_poly1.touches(shapely_poly2) - except: - return False - - -def calculate_polygon_area(poly): - """Calculate the area of a polygon.""" - try: - if len(poly) < 3: - return 0 - shapely_poly = Polygon(poly) - if not shapely_poly.is_valid: - return 0 - return shapely_poly.area - except: - return 0 - - -def is_edge_near_tile_boundary(p1, p2, tile_bounds, tolerance=2): - """Check if an edge is colinear with the tile boundary within tolerance.""" - x_min, y_min, x_max, y_max = tile_bounds - x1, y1 = p1 - x2, y2 = p2 - - # Check if edge is roughly horizontal and colinear with top boundary - if (abs(y1 - y_min) <= tolerance and abs(y2 - y_min) <= tolerance and - abs(y1 - y2) <= tolerance): - return True - - # Check if edge is roughly horizontal and colinear with bottom boundary - if (abs(y1 - y_max) <= tolerance and abs(y2 - y_max) <= tolerance and - abs(y1 - y2) <= tolerance): - return True - - # Check if edge is roughly vertical and colinear with left boundary - if (abs(x1 - x_min) <= tolerance and abs(x2 - x_min) <= tolerance and - abs(x1 - x2) <= tolerance): - return True - - # Check if edge is roughly vertical and colinear with right boundary - if (abs(x1 - x_max) <= tolerance and abs(x2 - x_max) <= tolerance and - abs(x1 - x2) <= tolerance): - return True - - return False - - -def generate_edge_sample_points(p1, p2, num_points=10, margin_px=10): - """Generate equally spaced points along an edge, leaving a fixed margin at each end. 
- Always generates at least one point in the center of the line.""" - # Calculate edge length - edge_length = math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2) - - # Always generate center point - center_x = p1[0] + 0.5 * (p2[0] - p1[0]) - center_y = p1[1] + 0.5 * (p2[1] - p1[1]) - - # If edge is too short to accommodate margins, return just the center point - if edge_length <= 2 * margin_px: - return [(center_x, center_y)] - - # Calculate t values for the start and end of the usable region - t_start = margin_px / edge_length - t_end = 1.0 - margin_px / edge_length - - points = [] - - # If only one point requested, return center point - if num_points == 1: - return [(center_x, center_y)] - - # Generate points evenly spaced within the usable region - for i in range(num_points): - # Distribute points evenly within the usable region - t_local = i / (num_points - 1) - t = t_start + t_local * (t_end - t_start) - - x = p1[0] + t * (p2[0] - p1[0]) - y = p1[1] + t * (p2[1] - p1[1]) - points.append((x, y)) - - return points - - def point_in_polygon(point, polygon, merge_tolerance): """Check if a point is inside a polygon using OpenCV.""" if len(polygon) < 3: @@ -155,6 +56,8 @@ def __init__(self, experiment_path: str, device: Optional[str] = None) -> None: self.model: Optional[EncoderDecoder] = None self.tokenizer: Optional[Tokenizer] = None self.cache_dir: str = "/tmp/pix2poly_cache" + # Extract descriptive model name from experiment path (e.g., "Pix2Poly_inria_coco_224") + self.model_display_name: str = os.path.basename(self.experiment_path) self._ensure_cache_dir() self._initialize_model() @@ -163,7 +66,7 @@ def _ensure_cache_dir(self) -> None: os.makedirs(self.cache_dir, exist_ok=True) def _generate_cache_key(self, tiles: List[np.ndarray]) -> str: - """Generate a cache key based on the input tiles. + """Generate a cache key based on the input tiles and model. 
Args: tiles (List[np.ndarray]): List of tile images @@ -171,8 +74,10 @@ def _generate_cache_key(self, tiles: List[np.ndarray]) -> str: Returns: str: Hash-based cache key """ - # Create a hash based on all tile data + # Create a hash based on all tile data and model identifier hasher = hashlib.sha256() + # Include model experiment path to make cache model-specific + hasher.update(self.experiment_path.encode('utf-8')) for tile in tiles: hasher.update(tile.tobytes()) return hasher.hexdigest() @@ -225,6 +130,101 @@ def _save_to_cache(self, cache_key: str, results: List[Dict[str, List[np.ndarray except Exception as e: log(f"Failed to save cache to {cache_path}: {e}") + def _check_polygon_overlap(self, poly1, poly2): + """Check if two polygons overlap using Shapely.""" + try: + # Convert numpy arrays to Shapely polygons + if len(poly1) < 3 or len(poly2) < 3: + return False + + shapely_poly1 = Polygon(poly1) + shapely_poly2 = Polygon(poly2) + + # Check if polygons are valid + if not shapely_poly1.is_valid or not shapely_poly2.is_valid: + return False + + # Check for intersection (but not just touching) + return shapely_poly1.intersects(shapely_poly2) and not shapely_poly1.touches(shapely_poly2) + except: + return False + + def _calculate_polygon_area(self, poly): + """Calculate the area of a polygon.""" + try: + if len(poly) < 3: + return 0 + shapely_poly = Polygon(poly) + if not shapely_poly.is_valid: + return 0 + return shapely_poly.area + except: + return 0 + + def _is_edge_near_tile_boundary(self, p1, p2, tile_bounds, tolerance=2): + """Check if an edge is colinear with the tile boundary within tolerance.""" + x_min, y_min, x_max, y_max = tile_bounds + x1, y1 = p1 + x2, y2 = p2 + + # Check if edge is roughly horizontal and colinear with top boundary + if (abs(y1 - y_min) <= tolerance and abs(y2 - y_min) <= tolerance and + abs(y1 - y2) <= tolerance): + return True + + # Check if edge is roughly horizontal and colinear with bottom boundary + if (abs(y1 - y_max) <= tolerance and abs(y2 - y_max) <= tolerance and + abs(y1 - y2) <= tolerance): + return True + + # Check if edge is roughly vertical and colinear with left boundary + if (abs(x1 - x_min) <= tolerance and abs(x2 - x_min) <= tolerance and + abs(x1 - x2) <= tolerance): + return True + + # Check if edge is roughly vertical and colinear with right boundary + if (abs(x1 - x_max) <= tolerance and abs(x2 - x_max) <= tolerance and + abs(x1 - x2) <= tolerance): + return True + + return False + + def _generate_edge_sample_points(self, p1, p2, num_points=10, margin_px=10): + """Generate equally spaced points along an edge, leaving a fixed margin at each end. 
+ Always generates at least one point in the center of the line.""" + # Calculate edge length + edge_length = math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2) + + # Always generate center point + center_x = p1[0] + 0.5 * (p2[0] - p1[0]) + center_y = p1[1] + 0.5 * (p2[1] - p1[1]) + + # If edge is too short to accommodate margins, return just the center point + if edge_length <= 2 * margin_px: + return [(center_x, center_y)] + + # Calculate t values for the start and end of the usable region + t_start = margin_px / edge_length + t_end = 1.0 - margin_px / edge_length + + points = [] + + # If only one point requested, return center point + if num_points == 1: + return [(center_x, center_y)] + + # Generate points evenly spaced within the usable region + for i in range(num_points): + # Distribute points evenly within the usable region + t_local = i / (num_points - 1) + t = t_start + t_local * (t_end - t_start) + + x = p1[0] + t * (p2[0] - p1[0]) + y = p1[1] + t * (p2[1] - p1[1]) + points.append((x, y)) + + return points + def _initialize_model(self) -> None: """Initialize the model and tokenizer. @@ -567,7 +567,13 @@ def _create_tile_visualization( ha='center', va='center', zorder=6, bbox=dict(boxstyle='round,pad=0.3', facecolor=outline_color, alpha=0.7)) - plt.tight_layout() + # Leave space at the bottom for the model name + plt.tight_layout(rect=[0, 0.05, 1, 1]) + + # Add model name at the bottom of the visualization + plt.figtext(0.5, 0.01, f'Model: {self.model_display_name}', + ha='center', va='bottom') + plt.savefig('tile-visualization.png', dpi=150, bbox_inches='tight') plt.close() log(f"Saved tile visualization to tile-visualization.png") @@ -623,7 +629,7 @@ def _validate_all_polygons( idx1, poly1 = valid_polygons[i] idx2, poly2 = valid_polygons[j] - if check_polygon_overlap(poly1, poly2): + if self._check_polygon_overlap(poly1, poly2): overlapping_pairs.append((idx1, idx2)) if not overlapping_pairs: @@ -638,7 +644,7 @@ def _validate_all_polygons( # Calculate areas for overlapping polygons polygon_areas = [] for idx in overlapping_indices: - area = calculate_polygon_area(polygons[idx]) + area = self._calculate_polygon_area(polygons[idx]) polygon_areas.append((idx, area)) # Find the largest polygon @@ -677,7 +683,7 @@ def _validate_all_polygons( p1 = polygon[i] p2 = polygon[i + 1] - if is_edge_near_tile_boundary(p1, p2, tile_bounds): + if self._is_edge_near_tile_boundary(p1, p2, tile_bounds): boundary_edges.append((p1, p2)) # If no boundary edges, polygon is valid (not on tile boundary) @@ -688,7 +694,7 @@ def _validate_all_polygons( polygon_is_valid = True for p1, p2 in boundary_edges: - sample_points = generate_edge_sample_points(p1, p2) + sample_points = self._generate_edge_sample_points(p1, p2) # Determine if this edge is horizontal or vertical is_horizontal_edge = abs(p1[1] - p2[1]) <= 2 # Edge is roughly horizontal @@ -806,8 +812,8 @@ def _merge_polygons( # Save bitmap for debugging (optional) if debug: - cv2.imwrite('visualization-bitmap.png', bitmap) - log("Saved debug bitmap to visualization-bitmap.png") + cv2.imwrite('bitmap-visualization.png', bitmap) + log("Saved debug bitmap to bitmap-visualization.png") # Find contours in the bitmap contours, _ = cv2.findContours(bitmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) @@ -847,6 +853,7 @@ def _merge_polygons( # Create single GeoDataFrame with all polygons and regularize them all at once if shapely_polygons: + log(f"Regularizing {len(shapely_polygons)} polygons") gdf = gpd.GeoDataFrame({'geometry': shapely_polygons}) 
regularized_gdf = regularize_geodataframe(gdf, simplify_tolerance=20, parallel_threshold=100) @@ -862,7 +869,7 @@ def _merge_polygons( merged_polygons.append(polygon_coords) - log(f"Bitmap approach: {len(merged_polygons)} polygons extracted from bitmap") + log(f"Polygons extracted: {len(merged_polygons)}") return merged_polygons def infer(self, image_data: bytes, debug: bool = False, merge_tolerance: Optional[float] = None, tile_overlap_ratio: Optional[float] = None) -> List[List[List[float]]]: @@ -940,7 +947,9 @@ def infer(self, image_data: bytes, debug: bool = False, merge_tolerance: Optiona all_results.extend(batch_results) batch_time = time.time() - batch_start_time - log(f"Processed batch of {len(batch_tiles)} tiles: {batch_time/len(batch_tiles):.3f}s per tile") + tiles_processed_so_far = i + len(batch_tiles) + total_tiles = len(tiles) + log(f"Processed batch of {len(batch_tiles)} tiles ({tiles_processed_so_far}/{total_tiles}): {batch_time/len(batch_tiles):.3f}s per tile") # Validate all polygons and add validation attributes all_results = self._validate_all_polygons(all_results, bboxes, height, width, effective_merge_tolerance) From 1c789a4eec2c31210ce063e27d3513f37de88183 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 22:05:39 -0500 Subject: [PATCH 31/45] Improve logging. --- polygon_inference.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 656db92..4573be2 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -345,12 +345,10 @@ def _process_tiles_batch( if cached_results is not None: log(f"Cache hit for batch of {len(tiles)} tiles") return cached_results - - log(f"Cache miss for batch of {len(tiles)} tiles, processing...") else: - log(f"Processing batch of {len(tiles)} tiles (caching disabled)...") cache_key = None + log(f"Processing batch of {len(tiles)} tiles...") valid_transforms = A.Compose( [ A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), @@ -848,7 +846,6 @@ def _merge_polygons( else: log(f"Skipping invalid polygon") - # Initialize merged_polygons to avoid NameError if no valid contours are found merged_polygons: List[np.ndarray] = [] # Create single GeoDataFrame with all polygons and regularize them all at once @@ -954,6 +951,8 @@ def infer(self, image_data: bytes, debug: bool = False, merge_tolerance: Optiona # Validate all polygons and add validation attributes all_results = self._validate_all_polygons(all_results, bboxes, height, width, effective_merge_tolerance) + log(f"Validated {sum(sum(tile_result['polygon_valid']) for tile_result in all_results)} out of {sum(len(tile_result['polygons']) for tile_result in all_results)} polygons") + # Create tile visualization if debug: self._create_tile_visualization(tiles, all_results, bboxes) From 1850c85e635386b38a6a54be36af683315dff1fc Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 22:09:54 -0500 Subject: [PATCH 32/45] Improve logging. 
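The timing change here starts the clock inside the batch processor, after the cache lookup, so cache hits no longer skew the reported per-tile time. The pattern, reduced to a sketch (the `cached` and `process` arguments stand in for the real cache lookup and model call; print stands in for the project's log helper):

    import time

    def process_batch(tiles, process, cached=None):
        if cached is not None:
            return cached                     # cache hit: returned before the timer starts
        batch_start_time = time.time()        # timing covers actual processing only
        results = [process(tile) for tile in tiles]
        batch_time = time.time() - batch_start_time
        print(f"Batch processing time: {batch_time / len(tiles):.3f}s per tile")
        return results

    process_batch([1, 2, 3], lambda tile: tile * 2)
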
--- polygon_inference.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 4573be2..7c24e13 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -348,6 +348,8 @@ def _process_tiles_batch( else: cache_key = None + # Start timing for actual processing + batch_start_time = time.time() log(f"Processing batch of {len(tiles)} tiles...") valid_transforms = A.Compose( [ @@ -412,6 +414,10 @@ def _process_tiles_batch( if debug and cache_key is not None: self._save_to_cache(cache_key, results) + # Log processing time per tile + batch_time = time.time() - batch_start_time + log(f"Batch processing time: {batch_time/len(tiles):.3f}s per tile") + return results def _create_tile_visualization( @@ -811,7 +817,7 @@ def _merge_polygons( # Save bitmap for debugging (optional) if debug: cv2.imwrite('bitmap-visualization.png', bitmap) - log("Saved debug bitmap to bitmap-visualization.png") + log("Saved bitmap visualization to bitmap-visualization.png") # Find contours in the bitmap contours, _ = cv2.findContours(bitmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) @@ -938,15 +944,13 @@ def infer(self, image_data: bytes, debug: bool = False, merge_tolerance: Optiona all_results: List[Dict[str, List[np.ndarray]]] = [] for i in range(0, len(tiles), CFG.PREDICTION_BATCH_SIZE): - batch_start_time = time.time() batch_tiles = tiles[i : i + CFG.PREDICTION_BATCH_SIZE] batch_results = self._process_tiles_batch(batch_tiles, debug) all_results.extend(batch_results) - batch_time = time.time() - batch_start_time tiles_processed_so_far = i + len(batch_tiles) total_tiles = len(tiles) - log(f"Processed batch of {len(batch_tiles)} tiles ({tiles_processed_so_far}/{total_tiles}): {batch_time/len(batch_tiles):.3f}s per tile") + log(f"Processed batch of {len(batch_tiles)} tiles ({tiles_processed_so_far}/{total_tiles})") # Validate all polygons and add validation attributes all_results = self._validate_all_polygons(all_results, bboxes, height, width, effective_merge_tolerance) From 753c1f5d7f2834112e8b0380306cbf6225ab900b Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 22:17:10 -0500 Subject: [PATCH 33/45] Improve logging. --- infer_single_image.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/infer_single_image.py b/infer_single_image.py index 49eddaa..5454619 100644 --- a/infer_single_image.py +++ b/infer_single_image.py @@ -63,6 +63,8 @@ def main(): plt.savefig(output_path, dpi=100, bbox_inches='tight', pad_inches=0) plt.close() + log(f"Saved main visualization to {output_path}") + # Print polygons to stdout print(json.dumps(polygons_list)) From e22b6aceb6aca6701e0baa8b5ab7368368af5e75 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 22:23:55 -0500 Subject: [PATCH 34/45] Performance improvement. 
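This change batches rasterization: instead of calling cv2.fillPoly once per polygon, the valid polygons are collected and filled with a single call. A small sketch of the two variants on toy data (two non-overlapping triangles on a 100x100 bitmap):

    import numpy as np
    import cv2

    polys = [
        np.array([[10, 10], [40, 10], [25, 40]], dtype=np.int32),
        np.array([[60, 60], [90, 60], [75, 90]], dtype=np.int32),
    ]

    # One call per polygon (previous behavior).
    per_poly = np.zeros((100, 100), dtype=np.uint8)
    for poly in polys:
        cv2.fillPoly(per_poly, [poly], 255)

    # Single batched call (this change).
    batched = np.zeros((100, 100), dtype=np.uint8)
    cv2.fillPoly(batched, polys, 255)

    print(np.array_equal(per_poly, batched))  # True while the polygons don't overlap

As PATCH 37 below notes, overlapping polygons are a different story: a single batched call does not simply union them.
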
--- polygon_inference.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 7c24e13..099f82e 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -784,18 +784,23 @@ def _merge_polygons( # Create bitmap at 8x resolution for subpixel precision bitmap = np.zeros((image_height * scale_factor, image_width * scale_factor), dtype=np.uint8) - # Fill bitmap with polygon regions + # Collect all valid polygons for batch processing + all_polygon_coords = [] + for tile_result, (x, y, x_end, y_end) in zip(tile_results, positions): tile_polygons = tile_result["polygons"] polygon_valid = tile_result["polygon_valid"] + # Pre-allocate translation vector for this tile + translation_vector = np.array([x, y]) + for poly, is_valid in zip(tile_polygons, polygon_valid): # Skip invalid polygons if not is_valid: continue # Transform polygon from tile coordinates to image coordinates - transformed_poly = poly + np.array([x, y]) + transformed_poly = poly + translation_vector # Scale up coordinates for high-resolution bitmap scaled_poly = transformed_poly * scale_factor @@ -806,9 +811,11 @@ def _merge_polygons( # Convert to integer coordinates for rasterization poly_coords = scaled_poly.astype(np.int32) - - # Fill the polygon region in the bitmap - cv2.fillPoly(bitmap, [poly_coords], 255) + all_polygon_coords.append(poly_coords) + + # Fill all polygons at once - much more efficient than individual calls + if all_polygon_coords: + cv2.fillPoly(bitmap, all_polygon_coords, 255) kernel_size = 32 kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) From c0ddf259e78d058569e32fd9588a0fd7fe3e2ffc Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 22:28:29 -0500 Subject: [PATCH 35/45] Documentation. --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index 48ba574..deeb3f9 100644 --- a/README.md +++ b/README.md @@ -231,7 +231,6 @@ The API returns JSON with the detected polygons: You can customize the Docker container behavior with these environment variables: - `MODEL_URL`: URL to download the pretrained model files (default: `https://github.com/safelease/Pix2Poly/releases/download/main/runs_share.zip`) -- `EXPERIMENT_PATH`: Path to the experiment folder (default: `runs_share/Pix2Poly_inria_coco_224`) - `API_KEY`: Optional API key for authentication (if not set, authentication is disabled) Example with custom configuration: From cc9de47a9223d9753d5f258eb54ce313e3d06212 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 22:28:45 -0500 Subject: [PATCH 36/45] Documentation. --- README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/README.md b/README.md index deeb3f9..8f71e20 100644 --- a/README.md +++ b/README.md @@ -237,7 +237,6 @@ Example with custom configuration: ```bash docker run -p 8080:8080 \ -e MODEL_URL=https://github.com/safelease/Pix2Poly/releases/download/main/runs_share.zip \ - -e EXPERIMENT_PATH=runs_share/Pix2Poly_inria_coco_224 \ -e API_KEY=your_secret_key \ pix2poly ``` From 4449c025df0a889208f343ced4f95bb7c2370cbd Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 22:32:35 -0500 Subject: [PATCH 37/45] Regression. 
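Context for the revert below: when several contours are passed to one cv2.fillPoly call, the function fills the area bounded by all of them together, so regions covered by more than one polygon can be carved out as holes rather than unioned. Overlap between tile polygons is exactly the normal case here, hence the return to one fill per polygon. A sketch of the difference (two overlapping squares; the exact fill rule OpenCV applies to combined contours is an assumption, so treat the printed values as indicative):

    import numpy as np
    import cv2

    a = np.array([[10, 10], [60, 10], [60, 60], [10, 60]], dtype=np.int32)
    b = np.array([[40, 40], [90, 40], [90, 90], [40, 90]], dtype=np.int32)

    batched = np.zeros((100, 100), dtype=np.uint8)
    cv2.fillPoly(batched, [a, b], 255)          # overlap may come out unfilled

    separate = np.zeros((100, 100), dtype=np.uint8)
    for poly in (a, b):
        cv2.fillPoly(separate, [poly], 255)     # plain union of the filled regions

    # Pixel (50, 50) lies inside both squares.
    print(int(batched[50, 50]), int(separate[50, 50]))
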
--- polygon_inference.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 099f82e..d7f153a 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -784,9 +784,7 @@ def _merge_polygons( # Create bitmap at 8x resolution for subpixel precision bitmap = np.zeros((image_height * scale_factor, image_width * scale_factor), dtype=np.uint8) - # Collect all valid polygons for batch processing - all_polygon_coords = [] - + # Process all valid polygons and fill them immediately for tile_result, (x, y, x_end, y_end) in zip(tile_results, positions): tile_polygons = tile_result["polygons"] polygon_valid = tile_result["polygon_valid"] @@ -811,11 +809,9 @@ def _merge_polygons( # Convert to integer coordinates for rasterization poly_coords = scaled_poly.astype(np.int32) - all_polygon_coords.append(poly_coords) - - # Fill all polygons at once - much more efficient than individual calls - if all_polygon_coords: - cv2.fillPoly(bitmap, all_polygon_coords, 255) + + # Fill polygon immediately to avoid winding order issues + cv2.fillPoly(bitmap, [poly_coords], 255) kernel_size = 32 kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) From 8ebcd2c2b10d593b1c41831698f22ba876bc1975 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 22:46:51 -0500 Subject: [PATCH 38/45] Restructuring. --- polygon_inference.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/polygon_inference.py b/polygon_inference.py index d7f153a..a760170 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -225,6 +225,14 @@ def _generate_edge_sample_points(self, p1, p2, num_points=10, margin_px=10): return points + def _point_in_polygon(self, point, polygon, merge_tolerance): + """Check if a point is inside a polygon using OpenCV.""" + if len(polygon) < 3: + return False + # Convert polygon to the format expected by cv2.pointPolygonTest + poly_points = polygon.astype(np.float32).reshape((-1, 1, 2)) + return cv2.pointPolygonTest(poly_points, point, True) >= -merge_tolerance + def _initialize_model(self) -> None: """Initialize the model and tokenizer. @@ -735,7 +743,7 @@ def _validate_all_polygons( if not other_tile_result["polygon_valid"][other_poly_idx]: continue - if point_in_polygon(local_point, other_polygon, merge_tolerance): + if self._point_in_polygon(local_point, other_polygon, merge_tolerance): point_found_in_other_polygon = True break From 4ba35a17e3125d3918cfc0d9a43b224a410d83e4 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 22:47:26 -0500 Subject: [PATCH 39/45] Restructuring. --- polygon_inference.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index a760170..183809b 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -34,15 +34,6 @@ from models.model import Encoder, Decoder, EncoderDecoder -def point_in_polygon(point, polygon, merge_tolerance): - """Check if a point is inside a polygon using OpenCV.""" - if len(polygon) < 3: - return False - # Convert polygon to the format expected by cv2.pointPolygonTest - poly_points = polygon.astype(np.float32).reshape((-1, 1, 2)) - return cv2.pointPolygonTest(poly_points, point, True) >= -merge_tolerance - - class PolygonInference: def __init__(self, experiment_path: str, device: Optional[str] = None) -> None: """Initialize the polygon inference with a trained model. 
From 9574de9066c13818d291432121b66221d61f9237 Mon Sep 17 00:00:00 2001 From: Scott Martin Date: Tue, 15 Jul 2025 22:53:04 -0500 Subject: [PATCH 40/45] Add types. --- polygon_inference.py | 398 ++++++++++++++++++++++++++----------------- 1 file changed, 244 insertions(+), 154 deletions(-) diff --git a/polygon_inference.py b/polygon_inference.py index 183809b..169b922 100644 --- a/polygon_inference.py +++ b/polygon_inference.py @@ -4,10 +4,11 @@ import hashlib import pickle import copy -from typing import List, Tuple, Dict, Optional +from typing import List, Tuple, Dict, Optional, Union, Any # Third-party imports import numpy as np +import numpy.typing as npt import cv2 import torch import albumentations as A @@ -33,6 +34,13 @@ ) from models.model import Encoder, Decoder, EncoderDecoder +# Type aliases for better readability +PolygonArray = npt.NDArray[np.floating[Any]] +TilePosition = Tuple[int, int, int, int] # (x, y, x_end, y_end) +TileResult = Dict[str, Union[List[PolygonArray], List[bool]]] +Point2D = Tuple[float, float] +BoundingBox = Tuple[float, float, float, float] # (x_min, y_min, x_max, y_max) + class PolygonInference: def __init__(self, experiment_path: str, device: Optional[str] = None) -> None: @@ -49,6 +57,7 @@ def __init__(self, experiment_path: str, device: Optional[str] = None) -> None: self.cache_dir: str = "/tmp/pix2poly_cache" # Extract descriptive model name from experiment path (e.g., "Pix2Poly_inria_coco_224") self.model_display_name: str = os.path.basename(self.experiment_path) + self.model_cfg: Optional[Any] = None # Store adapted configuration self._ensure_cache_dir() self._initialize_model() @@ -56,17 +65,17 @@ def _ensure_cache_dir(self) -> None: """Ensure the cache directory exists.""" os.makedirs(self.cache_dir, exist_ok=True) - def _generate_cache_key(self, tiles: List[np.ndarray]) -> str: + def _generate_cache_key(self, tiles: List[npt.NDArray[np.uint8]]) -> str: """Generate a cache key based on the input tiles and model. Args: - tiles (List[np.ndarray]): List of tile images + tiles (List[npt.NDArray[np.uint8]]): List of tile images Returns: str: Hash-based cache key """ # Create a hash based on all tile data and model identifier - hasher = hashlib.sha256() + hasher: hashlib.sha256 = hashlib.sha256() # Include model experiment path to make cache model-specific hasher.update(self.experiment_path.encode('utf-8')) for tile in tiles: @@ -84,20 +93,21 @@ def _get_cache_path(self, cache_key: str) -> str: """ return os.path.join(self.cache_dir, f"{cache_key}.pkl") - def _load_from_cache(self, cache_key: str) -> Optional[List[Dict[str, List[np.ndarray]]]]: + def _load_from_cache(self, cache_key: str) -> Optional[List[TileResult]]: """Load results from cache if they exist. 
Args: cache_key (str): The cache key to look for Returns: - Optional[List[Dict[str, List[np.ndarray]]]]: Cached results if found, None otherwise + Optional[List[TileResult]]: Cached results if found, None otherwise """ - cache_path = self._get_cache_path(cache_key) + cache_path: str = self._get_cache_path(cache_key) if os.path.exists(cache_path): try: with open(cache_path, 'rb') as f: - return pickle.load(f) + cached_data: Any = pickle.load(f) + return cached_data except Exception as e: log(f"Failed to load cache from {cache_path}: {e}") # Remove corrupted cache file @@ -107,29 +117,37 @@ def _load_from_cache(self, cache_key: str) -> Optional[List[Dict[str, List[np.nd pass return None - def _save_to_cache(self, cache_key: str, results: List[Dict[str, List[np.ndarray]]]) -> None: + def _save_to_cache(self, cache_key: str, results: List[TileResult]) -> None: """Save results to cache. Args: cache_key (str): The cache key - results (List[Dict[str, List[np.ndarray]]]): Results to cache + results (List[TileResult]): Results to cache """ - cache_path = self._get_cache_path(cache_key) + cache_path: str = self._get_cache_path(cache_key) try: with open(cache_path, 'wb') as f: pickle.dump(results, f) except Exception as e: log(f"Failed to save cache to {cache_path}: {e}") - def _check_polygon_overlap(self, poly1, poly2): - """Check if two polygons overlap using Shapely.""" + def _check_polygon_overlap(self, poly1: PolygonArray, poly2: PolygonArray) -> bool: + """Check if two polygons overlap using Shapely. + + Args: + poly1 (PolygonArray): First polygon as array of [x, y] coordinates + poly2 (PolygonArray): Second polygon as array of [x, y] coordinates + + Returns: + bool: True if polygons overlap (intersect but don't just touch) + """ try: # Convert numpy arrays to Shapely polygons if len(poly1) < 3 or len(poly2) < 3: return False - shapely_poly1 = Polygon(poly1) - shapely_poly2 = Polygon(poly2) + shapely_poly1: Polygon = Polygon(poly1) + shapely_poly2: Polygon = Polygon(poly2) # Check if polygons are valid if not shapely_poly1.is_valid or not shapely_poly2.is_valid: @@ -140,20 +158,43 @@ def _check_polygon_overlap(self, poly1, poly2): except: return False - def _calculate_polygon_area(self, poly): - """Calculate the area of a polygon.""" + def _calculate_polygon_area(self, poly: PolygonArray) -> float: + """Calculate the area of a polygon. + + Args: + poly (PolygonArray): Polygon as array of [x, y] coordinates + + Returns: + float: Area of the polygon, 0 if invalid + """ try: if len(poly) < 3: - return 0 - shapely_poly = Polygon(poly) + return 0.0 + shapely_poly: Polygon = Polygon(poly) if not shapely_poly.is_valid: - return 0 - return shapely_poly.area + return 0.0 + return float(shapely_poly.area) except: - return 0 + return 0.0 - def _is_edge_near_tile_boundary(self, p1, p2, tile_bounds, tolerance=2): - """Check if an edge is colinear with the tile boundary within tolerance.""" + def _is_edge_near_tile_boundary( + self, + p1: Point2D, + p2: Point2D, + tile_bounds: BoundingBox, + tolerance: float = 2.0 + ) -> bool: + """Check if an edge is colinear with the tile boundary within tolerance. 
+ + Args: + p1 (Point2D): First point of the edge + p2 (Point2D): Second point of the edge + tile_bounds (BoundingBox): Tile boundaries as (x_min, y_min, x_max, y_max) + tolerance (float): Tolerance for boundary detection in pixels + + Returns: + bool: True if edge is near a tile boundary + """ x_min, y_min, x_max, y_max = tile_bounds x1, y1 = p1 x2, y2 = p2 @@ -180,25 +221,41 @@ def _is_edge_near_tile_boundary(self, p1, p2, tile_bounds, tolerance=2): return False - def _generate_edge_sample_points(self, p1, p2, num_points=10, margin_px=10): + def _generate_edge_sample_points( + self, + p1: Point2D, + p2: Point2D, + num_points: int = 10, + margin_px: float = 10.0 + ) -> List[Point2D]: """Generate equally spaced points along an edge, leaving a fixed margin at each end. - Always generates at least one point in the center of the line.""" + Always generates at least one point in the center of the line. + + Args: + p1 (Point2D): Start point of the edge + p2 (Point2D): End point of the edge + num_points (int): Number of sample points to generate + margin_px (float): Margin in pixels to leave at each end + + Returns: + List[Point2D]: List of sample points along the edge + """ # Calculate edge length - edge_length = math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2) + edge_length: float = math.sqrt((p2[0] - p1[0])**2 + (p2[1] - p1[1])**2) # Always generate center point - center_x = p1[0] + 0.5 * (p2[0] - p1[0]) - center_y = p1[1] + 0.5 * (p2[1] - p1[1]) + center_x: float = p1[0] + 0.5 * (p2[0] - p1[0]) + center_y: float = p1[1] + 0.5 * (p2[1] - p1[1]) # If edge is too short to accommodate margins, return just the center point if edge_length <= 2 * margin_px: return [(center_x, center_y)] # Calculate t values for the start and end of the usable region - t_start = margin_px / edge_length - t_end = 1.0 - margin_px / edge_length + t_start: float = margin_px / edge_length + t_end: float = 1.0 - margin_px / edge_length - points = [] + points: List[Point2D] = [] # If only one point requested, return center point if num_points == 1: @@ -207,22 +264,37 @@ def _generate_edge_sample_points(self, p1, p2, num_points=10, margin_px=10): # Generate points evenly spaced within the usable region for i in range(num_points): # Distribute points evenly within the usable region - t_local = i / (num_points - 1) - t = t_start + t_local * (t_end - t_start) + t_local: float = i / (num_points - 1) + t: float = t_start + t_local * (t_end - t_start) - x = p1[0] + t * (p2[0] - p1[0]) - y = p1[1] + t * (p2[1] - p1[1]) + x: float = p1[0] + t * (p2[0] - p1[0]) + y: float = p1[1] + t * (p2[1] - p1[1]) points.append((x, y)) return points - def _point_in_polygon(self, point, polygon, merge_tolerance): - """Check if a point is inside a polygon using OpenCV.""" + def _point_in_polygon( + self, + point: Point2D, + polygon: PolygonArray, + merge_tolerance: float + ) -> bool: + """Check if a point is inside a polygon using OpenCV. 
+ + Args: + point (Point2D): Point to test + polygon (PolygonArray): Polygon as array of [x, y] coordinates + merge_tolerance (float): Tolerance for the point-in-polygon test + + Returns: + bool: True if point is inside the polygon within tolerance + """ if len(polygon) < 3: return False # Convert polygon to the format expected by cv2.pointPolygonTest - poly_points = polygon.astype(np.float32).reshape((-1, 1, 2)) - return cv2.pointPolygonTest(poly_points, point, True) >= -merge_tolerance + poly_points: npt.NDArray[np.float32] = polygon.astype(np.float32).reshape((-1, 1, 2)) + distance: float = cv2.pointPolygonTest(poly_points, point, True) + return distance >= -merge_tolerance def _initialize_model(self) -> None: """Initialize the model and tokenizer. @@ -235,32 +307,32 @@ def _initialize_model(self) -> None: 5. Loads the checkpoint weights """ # Load checkpoint first to inspect saved model configuration - latest_checkpoint = self._find_single_checkpoint() - checkpoint_path = os.path.join( + latest_checkpoint: str = self._find_single_checkpoint() + checkpoint_path: str = os.path.join( self.experiment_path, "logs", "checkpoints", latest_checkpoint ) - checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu")) + checkpoint: Dict[str, Any] = torch.load(checkpoint_path, map_location=torch.device("cpu")) # Create a copy of CFG for model creation - model_cfg = copy.deepcopy(CFG) + model_cfg: Any = copy.deepcopy(CFG) # Dynamically determine configuration from the saved positional embeddings - decoder_pos_embed_key = "decoder.decoder_pos_embed" - encoder_pos_embed_key = "decoder.encoder_pos_embed" + decoder_pos_embed_key: str = "decoder.decoder_pos_embed" + encoder_pos_embed_key: str = "decoder.encoder_pos_embed" if decoder_pos_embed_key in checkpoint["state_dict"]: - saved_decoder_pos_embed_shape = checkpoint["state_dict"][decoder_pos_embed_key].shape - checkpoint_max_len_minus_1 = saved_decoder_pos_embed_shape[1] # Shape is [1, MAX_LEN-1, embed_dim] - checkpoint_max_len = checkpoint_max_len_minus_1 + 1 - checkpoint_n_vertices = (checkpoint_max_len - 2) // 2 # Reverse: MAX_LEN = (N_VERTICES*2) + 2 + saved_decoder_pos_embed_shape: Tuple[int, ...] = checkpoint["state_dict"][decoder_pos_embed_key].shape + checkpoint_max_len_minus_1: int = saved_decoder_pos_embed_shape[1] # Shape is [1, MAX_LEN-1, embed_dim] + checkpoint_max_len: int = checkpoint_max_len_minus_1 + 1 + checkpoint_n_vertices: int = (checkpoint_max_len - 2) // 2 # Reverse: MAX_LEN = (N_VERTICES*2) + 2 if checkpoint_n_vertices != CFG.N_VERTICES: model_cfg.N_VERTICES = checkpoint_n_vertices model_cfg.MAX_LEN = checkpoint_max_len if encoder_pos_embed_key in checkpoint["state_dict"]: - saved_encoder_pos_embed_shape = checkpoint["state_dict"][encoder_pos_embed_key].shape - checkpoint_num_patches = saved_encoder_pos_embed_shape[1] # Shape is [1, num_patches, embed_dim] + saved_encoder_pos_embed_shape: Tuple[int, ...] 
= checkpoint["state_dict"][encoder_pos_embed_key].shape + checkpoint_num_patches: int = saved_encoder_pos_embed_shape[1] # Shape is [1, num_patches, embed_dim] if checkpoint_num_patches != CFG.NUM_PATCHES: model_cfg.NUM_PATCHES = checkpoint_num_patches @@ -277,8 +349,8 @@ def _initialize_model(self) -> None: CFG.PAD_IDX = self.tokenizer.PAD_code # Create model with the adapted configuration - encoder = Encoder(model_name=model_cfg.MODEL_NAME, pretrained=True, out_dim=256) - decoder = Decoder( + encoder: Encoder = Encoder(model_name=model_cfg.MODEL_NAME, pretrained=True, out_dim=256) + decoder: Decoder = Decoder( cfg=model_cfg, # Use adapted configuration vocab_size=self.tokenizer.vocab_size, encoder_len=model_cfg.NUM_PATCHES, @@ -306,11 +378,11 @@ def _find_single_checkpoint(self) -> str: FileNotFoundError: If no checkpoint directory or files are found RuntimeError: If more than one checkpoint file is found """ - checkpoint_dir = os.path.join(self.experiment_path, "logs", "checkpoints") + checkpoint_dir: str = os.path.join(self.experiment_path, "logs", "checkpoints") if not os.path.exists(checkpoint_dir): raise FileNotFoundError(f"Checkpoint directory not found: {checkpoint_dir}") - checkpoint_files = [ + checkpoint_files: List[str] = [ f for f in os.listdir(checkpoint_dir) if f.startswith("epoch_") and f.endswith(".pth") @@ -326,21 +398,21 @@ def _find_single_checkpoint(self) -> str: return checkpoint_files[0] def _process_tiles_batch( - self, tiles: List[np.ndarray], debug: bool = False - ) -> List[Dict[str, List[np.ndarray]]]: + self, tiles: List[npt.NDArray[np.uint8]], debug: bool = False + ) -> List[TileResult]: """Process a single batch of tiles. Args: - tiles (list[np.ndarray]): List of tile images to process + tiles (list[npt.NDArray[np.uint8]]): List of tile images to process Returns: - list[dict]: List of results for each tile, where each result contains: + list[TileResult]: List of results for each tile, where each result contains: - polygons: List of polygon coordinates """ # Generate cache key and try to load from cache (only when debug=True) if debug: - cache_key = self._generate_cache_key(tiles) - cached_results = self._load_from_cache(cache_key) + cache_key: str = self._generate_cache_key(tiles) + cached_results: Optional[List[TileResult]] = self._load_from_cache(cache_key) if cached_results is not None: log(f"Cache hit for batch of {len(tiles)} tiles") return cached_results @@ -348,9 +420,9 @@ def _process_tiles_batch( cache_key = None # Start timing for actual processing - batch_start_time = time.time() + batch_start_time: float = time.time() log(f"Processing batch of {len(tiles)} tiles...") - valid_transforms = A.Compose( + valid_transforms: A.Compose = A.Compose( [ A.Resize(height=CFG.INPUT_HEIGHT, width=CFG.INPUT_WIDTH), A.Normalize( @@ -363,15 +435,19 @@ def _process_tiles_batch( # Transform each tile individually and stack them transformed_tiles: List[torch.Tensor] = [] for tile in tiles: - transformed = valid_transforms(image=tile) + transformed: Dict[str, torch.Tensor] = valid_transforms(image=tile) transformed_tiles.append(transformed["image"]) # Stack the transformed tiles into a batch - batch_tensor = torch.stack(transformed_tiles).to(self.device) + batch_tensor: torch.Tensor = torch.stack(transformed_tiles).to(self.device) with torch.no_grad(): # Use adapted configuration for generation - adapted_generation_steps = (self.model_cfg.N_VERTICES * 2) + 1 + assert self.model_cfg is not None, "Model configuration not initialized" + adapted_generation_steps: int = 
(self.model_cfg.N_VERTICES * 2) + 1 + batch_preds: torch.Tensor + batch_confs: torch.Tensor + perm_preds: torch.Tensor batch_preds, batch_confs, perm_preds = test_generate( self.model, batch_tensor, @@ -381,31 +457,34 @@ def _process_tiles_batch( top_p=1, ) + vertex_coords: List[Optional[npt.NDArray[np.floating[Any]]]] + confs: List[Optional[npt.NDArray[np.floating[Any]]]] vertex_coords, confs = postprocess(batch_preds, batch_confs, self.tokenizer) - results: List[Dict[str, List[np.ndarray]]] = [] + results: List[TileResult] = [] for j in range(len(tiles)): + coord: torch.Tensor if vertex_coords[j] is not None: coord = torch.from_numpy(vertex_coords[j]) else: coord = torch.tensor([]) - padd = torch.ones((self.model_cfg.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX) + padd: torch.Tensor = torch.ones((self.model_cfg.N_VERTICES - len(coord), 2)).fill_(CFG.PAD_IDX) coord = torch.cat([coord, padd], dim=0) - batch_polygons = permutations_to_polygons( + batch_polygons: List[List[torch.Tensor]] = permutations_to_polygons( perm_preds[j : j + 1], [coord], out="torch" ) - valid_polygons: List[np.ndarray] = [] + valid_polygons: List[PolygonArray] = [] for poly in batch_polygons[0]: - poly = poly[poly[:, 0] != CFG.PAD_IDX] - if len(poly) > 0: + poly_filtered: torch.Tensor = poly[poly[:, 0] != CFG.PAD_IDX] + if len(poly_filtered) > 0: valid_polygons.append( - poly.cpu().numpy()[:, ::-1] + poly_filtered.cpu().numpy()[:, ::-1] ) # Convert to [x,y] format - result = {"polygons": valid_polygons} + result: TileResult = {"polygons": valid_polygons} results.append(result) @@ -414,40 +493,42 @@ def _process_tiles_batch( self._save_to_cache(cache_key, results) # Log processing time per tile - batch_time = time.time() - batch_start_time + batch_time: float = time.time() - batch_start_time log(f"Batch processing time: {batch_time/len(tiles):.3f}s per tile") return results def _create_tile_visualization( self, - tiles: List[np.ndarray], - tile_results: List[Dict[str, List[np.ndarray]]], - positions: List[Tuple[int, int, int, int]], + tiles: List[npt.NDArray[np.uint8]], + tile_results: List[TileResult], + positions: List[TilePosition], ) -> None: """Create a tile visualization showing each tile with its detected polygons and coordinate scales. 
Args: - tiles (List[np.ndarray]): List of tile images - tile_results (List[Dict[str, List[np.ndarray]]]): List of results for each tile - positions (List[Tuple[int, int, int, int]]): List of (x, y, x_end, y_end) tuples for each tile's position + tiles (List[npt.NDArray[np.uint8]]): List of tile images + tile_results (List[TileResult]): List of results for each tile + positions (List[TilePosition]): List of (x, y, x_end, y_end) tuples for each tile's position """ if not tiles: return # Calculate grid dimensions based on actual spatial arrangement # Extract unique x and y starting positions - x_positions = sorted(set(pos[0] for pos in positions)) - y_positions = sorted(set(pos[1] for pos in positions)) + x_positions: List[int] = sorted(set(pos[0] for pos in positions)) + y_positions: List[int] = sorted(set(pos[1] for pos in positions)) - cols = len(x_positions) - rows = len(y_positions) + cols: int = len(x_positions) + rows: int = len(y_positions) # Create mapping from (x, y) position to (row, col) index - x_to_col = {x: i for i, x in enumerate(x_positions)} - y_to_row = {y: i for i, y in enumerate(y_positions)} + x_to_col: Dict[int, int] = {x: i for i, x in enumerate(x_positions)} + y_to_row: Dict[int, int] = {y: i for i, y in enumerate(y_positions)} # Create figure + fig: plt.Figure + axes: Union[plt.Axes, List[plt.Axes], List[List[plt.Axes]]] fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4)) # Handle different subplot layouts @@ -468,10 +549,10 @@ def _create_tile_visualization( x, y, x_end, y_end = pos # Get the grid position for this tile - row = y_to_row[y] - col = x_to_col[x] + row: int = y_to_row[y] + col: int = x_to_col[x] - ax = axes[row][col] + ax: plt.Axes = axes[row][col] # Tiles are already in RGB format, no conversion needed for matplotlib ax.imshow(tile) @@ -481,17 +562,20 @@ def _create_tile_visualization( ax.axis('on') # Get tile dimensions + tile_height: int + tile_width: int tile_height, tile_width = tile.shape[:2] # Set up x-axis ticks and labels (global coordinates) - x_range = x_end - x + x_range: int = x_end - x # Generate tick positions ensuring min and max are included - num_x_ticks = 8 + num_x_ticks: int = 8 + x_tick_positions: List[int] if tile_width > 1: x_tick_positions = [0] # Always include minimum if num_x_ticks > 2: # Add intermediate positions - step = tile_width / (num_x_ticks - 1) + step: float = tile_width / (num_x_ticks - 1) for i in range(1, num_x_ticks - 1): x_tick_positions.append(int(i * step)) x_tick_positions.append(tile_width - 1) # Always include maximum @@ -499,7 +583,7 @@ def _create_tile_visualization( x_tick_positions = [0] # Calculate corresponding global coordinates - x_global_coords = [x + pos * x_range // tile_width for pos in x_tick_positions] + x_global_coords: List[int] = [x + pos * x_range // tile_width for pos in x_tick_positions] # Ensure the last coordinate is exactly x_end if len(x_global_coords) > 1: x_global_coords[-1] = x_end @@ -508,9 +592,10 @@ def _create_tile_visualization( ax.set_xticklabels([str(coord) for coord in x_global_coords], fontsize=8) # Set up y-axis ticks and labels (global coordinates) - y_range = y_end - y + y_range: int = y_end - y # Generate tick positions ensuring min and max are included - num_y_ticks = 8 + num_y_ticks: int = 8 + y_tick_positions: List[int] if tile_height > 1: y_tick_positions = [0] # Always include minimum if num_y_ticks > 2: @@ -523,7 +608,7 @@ def _create_tile_visualization( y_tick_positions = [0] # Calculate corresponding global coordinates - y_global_coords = [y + 
pos * y_range // tile_height for pos in y_tick_positions] + y_global_coords: List[int] = [y + pos * y_range // tile_height for pos in y_tick_positions] # Ensure the last coordinate is exactly y_end if len(y_global_coords) > 1: y_global_coords[-1] = y_end @@ -540,29 +625,29 @@ def _create_tile_visualization( ax.tick_params(axis='both', which='major', labelsize=8, length=3) # Draw polygons on this tile - polygons = tile_result["polygons"] - polygon_valid = tile_result["polygon_valid"] + polygons: List[PolygonArray] = tile_result["polygons"] + polygon_valid: List[bool] = tile_result["polygon_valid"] for poly_idx, (poly, is_valid) in enumerate(zip(polygons, polygon_valid)): if len(poly) > 2: # Use green for valid polygons, red for invalid ones - color = 'g' if is_valid else 'r' - vertex_color = 'red' if is_valid else 'darkred' + color: str = 'g' if is_valid else 'r' + vertex_color: str = 'red' if is_valid else 'darkred' # Close the polygon for visualization - poly_closed = np.vstack([poly, poly[0]]) + poly_closed: PolygonArray = np.vstack([poly, poly[0]]) ax.plot(poly_closed[:, 0], poly_closed[:, 1], f'{color}-', linewidth=2) # Draw vertices ax.scatter(poly[:, 0], poly[:, 1], c=vertex_color, s=20, zorder=5) # Calculate centroid and render polygon index - centroid_x = np.mean(poly[:, 0]) - centroid_y = np.mean(poly[:, 1]) + centroid_x: float = np.mean(poly[:, 0]) + centroid_y: float = np.mean(poly[:, 1]) # Use white text with black outline for visibility - text_color = 'white' - outline_color = 'black' + text_color: str = 'white' + outline_color: str = 'black' # Add text with outline for better visibility ax.text(centroid_x, centroid_y, str(poly_idx), @@ -583,26 +668,26 @@ def _create_tile_visualization( def _validate_all_polygons( self, - tile_results: List[Dict[str, List[np.ndarray]]], - positions: List[Tuple[int, int, int, int]], + tile_results: List[TileResult], + positions: List[TilePosition], image_height: int, image_width: int, merge_tolerance: float - ) -> List[Dict[str, List[np.ndarray]]]: + ) -> List[TileResult]: """Validate all polygons in the tile results and add validation attributes. This method implements a heuristic to validate polygons by checking if their boundary edges have points that are contained in polygons from other tiles. Args: - tile_results (List[Dict[str, List[np.ndarray]]]): List of tile results containing polygons - positions (List[Tuple[int, int, int, int]]): List of (x, y, x_end, y_end) tuples for each tile's position + tile_results (List[TileResult]): List of tile results containing polygons + positions (List[TilePosition]): List of (x, y, x_end, y_end) tuples for each tile's position image_height (int): Height of the original image image_width (int): Width of the original image merge_tolerance (float): Tolerance for point-in-polygon tests during validation (in pixels) Returns: - List[Dict[str, List[np.ndarray]]]: Updated tile results with validation attributes + List[TileResult]: Updated tile results with validation attributes """ # Initialize polygon_valid list for each tile for tile_result in tile_results: @@ -756,12 +841,12 @@ def _validate_all_polygons( def _merge_polygons( self, - tile_results: List[Dict[str, List[np.ndarray]]], - positions: List[Tuple[int, int, int, int]], + tile_results: List[TileResult], + positions: List[TilePosition], image_height: int, image_width: int, debug: bool = False, - ) -> List[np.ndarray]: + ) -> List[PolygonArray]: """Merge polygon predictions from multiple tiles using a bitmap approach. 
This method creates a bitmap where pixels inside any polygon are set to True, @@ -769,27 +854,28 @@ def _merge_polygons( from traditional polygon union operations. Args: - tile_results (list[dict]): List of dictionaries containing 'polygons' for each tile - positions (list[tuple[int, int, int, int]]): List of (x, y, x_end, y_end) tuples for each tile's position + tile_results (list[TileResult]): List of dictionaries containing 'polygons' for each tile + positions (list[TilePosition]): List of (x, y, x_end, y_end) tuples for each tile's position image_height (int): Height of the original image image_width (int): Width of the original image + debug (bool): Whether to save debug images Returns: - list[np.ndarray]: List of merged polygons in original image coordinates + list[PolygonArray]: List of merged polygons in original image coordinates """ # Scale factor for subpixel precision - scale_factor = 16 + scale_factor: int = 16 # Create bitmap at 8x resolution for subpixel precision - bitmap = np.zeros((image_height * scale_factor, image_width * scale_factor), dtype=np.uint8) + bitmap: npt.NDArray[np.uint8] = np.zeros((image_height * scale_factor, image_width * scale_factor), dtype=np.uint8) # Process all valid polygons and fill them immediately for tile_result, (x, y, x_end, y_end) in zip(tile_results, positions): - tile_polygons = tile_result["polygons"] - polygon_valid = tile_result["polygon_valid"] + tile_polygons: List[PolygonArray] = tile_result["polygons"] + polygon_valid: List[bool] = tile_result["polygon_valid"] # Pre-allocate translation vector for this tile - translation_vector = np.array([x, y]) + translation_vector: npt.NDArray[np.floating[Any]] = np.array([x, y]) for poly, is_valid in zip(tile_polygons, polygon_valid): # Skip invalid polygons @@ -797,23 +883,23 @@ def _merge_polygons( continue # Transform polygon from tile coordinates to image coordinates - transformed_poly = poly + translation_vector + transformed_poly: PolygonArray = poly + translation_vector # Scale up coordinates for high-resolution bitmap - scaled_poly = transformed_poly * scale_factor + scaled_poly: PolygonArray = transformed_poly * scale_factor # Ensure coordinates are within scaled bitmap bounds scaled_poly[:, 0] = np.clip(scaled_poly[:, 0], 0, image_width * scale_factor - 1) scaled_poly[:, 1] = np.clip(scaled_poly[:, 1], 0, image_height * scale_factor - 1) # Convert to integer coordinates for rasterization - poly_coords = scaled_poly.astype(np.int32) + poly_coords: npt.NDArray[np.int32] = scaled_poly.astype(np.int32) # Fill polygon immediately to avoid winding order issues cv2.fillPoly(bitmap, [poly_coords], 255) - kernel_size = 32 - kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) + kernel_size: int = 32 + kernel: npt.NDArray[np.uint8] = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size)) bitmap = cv2.morphologyEx(bitmap, cv2.MORPH_CLOSE, kernel) # Save bitmap for debugging (optional) @@ -822,20 +908,22 @@ def _merge_polygons( log("Saved bitmap visualization to bitmap-visualization.png") # Find contours in the bitmap + contours: List[npt.NDArray[np.int32]] + _: Any # hierarchy not used contours, _ = cv2.findContours(bitmap, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) # Collect all valid contours into shapely polygons - shapely_polygons = [] + shapely_polygons: List[Polygon] = [] for contour in contours: # Skip very small contours (area is scaled by scale_factor^2) - area = cv2.contourArea(contour) + area: float = cv2.contourArea(contour) if area < 
CFG.MIN_POLYGON_AREA * (scale_factor ** 2):
                 continue

             # Convert contour to Shapely Polygon
-            contour_points = contour.reshape(-1, 2).astype(np.float64)
+            contour_points: npt.NDArray[np.float64] = contour.reshape(-1, 2).astype(np.float64)
+            shapely_polygon: Polygon = Polygon(contour_points)

             shapely_polygon = make_valid(shapely_polygon)
@@ -844,7 +932,7 @@ def _merge_polygons(
             if shapely_polygon.geom_type == 'MultiPolygon':
                 # Extract individual polygons from MultiPolygon
                 for individual_poly in shapely_polygon.geoms:
-                    simple_poly = Polygon(individual_poly.exterior.coords)
+                    simple_poly: Polygon = Polygon(individual_poly.exterior.coords)
                     if simple_poly.is_valid and simple_poly.area > 0:
                         shapely_polygons.append(simple_poly)
             elif shapely_polygon.geom_type == 'Polygon':
@@ -854,23 +942,23 @@ def _merge_polygons(
             else:
                 log(f"Skipping invalid polygon")

-        merged_polygons: List[np.ndarray] = []
+        merged_polygons: List[PolygonArray] = []

         # Create single GeoDataFrame with all polygons and regularize them all at once
         if shapely_polygons:
             log(f"Regularizing {len(shapely_polygons)} polygons")
-            gdf = gpd.GeoDataFrame({'geometry': shapely_polygons})
-            regularized_gdf = regularize_geodataframe(gdf, simplify_tolerance=20, parallel_threshold=100)
+            gdf: gpd.GeoDataFrame = gpd.GeoDataFrame({'geometry': shapely_polygons})
+            regularized_gdf: gpd.GeoDataFrame = regularize_geodataframe(gdf, simplify_tolerance=20, parallel_threshold=100)

             # Process the regularized polygons
             for regularized_polygon in regularized_gdf.geometry:
                 # Convert back to numpy array for OpenCV format
-                coords = np.array(regularized_polygon.exterior.coords[:-1]) # Remove duplicate last point
+                coords: npt.NDArray[np.floating[Any]] = np.array(regularized_polygon.exterior.coords[:-1]) # Remove duplicate last point

                 # Convert from OpenCV format to our polygon format
                 if len(coords) >= 3: # Valid polygon needs at least 3 points
                     # Scale down coordinates back to original image coordinate system
-                    polygon_coords = coords.astype(np.float32) / scale_factor
+                    polygon_coords: PolygonArray = coords.astype(np.float32) / scale_factor

                     merged_polygons.append(polygon_coords)
@@ -900,8 +988,8 @@ def infer(self, image_data: bytes, debug: bool = False, merge_tolerance: Optiona
         seed_everything(42)

         # Decode image
-        nparr = np.frombuffer(image_data, np.uint8)
-        image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+        nparr: npt.NDArray[np.uint8] = np.frombuffer(image_data, np.uint8)
+        image: Optional[npt.NDArray[np.uint8]] = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
         if image is None:
             raise ValueError("Failed to decode image data")
@@ -912,17 +1000,19 @@ def infer(self, image_data: bytes, debug: bool = False, merge_tolerance: Optiona
         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

         # Split image into tiles
+        height: int
+        width: int
         height, width = image.shape[:2]
         if height == 0 or width == 0:
             raise ValueError("Invalid image dimensions")

         # Use provided parameters or fall back to config defaults
-        effective_merge_tolerance = merge_tolerance if merge_tolerance is not None else CFG.MERGE_TOLERANCE
-        effective_tile_overlap_ratio = tile_overlap_ratio if tile_overlap_ratio is not None else CFG.TILE_OVERLAP_RATIO
+        effective_merge_tolerance: float = merge_tolerance if merge_tolerance is not None else CFG.MERGE_TOLERANCE
+        effective_tile_overlap_ratio: float = tile_overlap_ratio if tile_overlap_ratio is not None else CFG.TILE_OVERLAP_RATIO

-        overlap_ratio = effective_tile_overlap_ratio
+        overlap_ratio: float = effective_tile_overlap_ratio

-        bboxes = calculate_slice_bboxes(
+        bboxes: List[TilePosition] = calculate_slice_bboxes(
             image_height=height,
             image_width=width,
             slice_height=CFG.TILE_SIZE,
@@ -931,11 +1021,11 @@ def infer(self, image_data: bytes, debug: bool = False, merge_tolerance: Optiona
             overlap_width_ratio=overlap_ratio,
         )

-        tiles: List[np.ndarray] = []
+        tiles: List[npt.NDArray[np.uint8]] = []

         for bbox in bboxes:
             x1, y1, x2, y2 = bbox
-            tile = image[y1:y2, x1:x2]
+            tile: npt.NDArray[np.uint8] = image[y1:y2, x1:x2]
             if tile.size == 0:
                 continue
             tiles.append(tile)
@@ -943,15 +1033,15 @@ def infer(self, image_data: bytes, debug: bool = False, merge_tolerance: Optiona
         log(f"Total number of tiles to process: {len(tiles)}")

         # Process tiles in batches
-        all_results: List[Dict[str, List[np.ndarray]]] = []
+        all_results: List[TileResult] = []

         for i in range(0, len(tiles), CFG.PREDICTION_BATCH_SIZE):
-            batch_tiles = tiles[i : i + CFG.PREDICTION_BATCH_SIZE]
-            batch_results = self._process_tiles_batch(batch_tiles, debug)
+            batch_tiles: List[npt.NDArray[np.uint8]] = tiles[i : i + CFG.PREDICTION_BATCH_SIZE]
+            batch_results: List[TileResult] = self._process_tiles_batch(batch_tiles, debug)
             all_results.extend(batch_results)

-            tiles_processed_so_far = i + len(batch_tiles)
-            total_tiles = len(tiles)
+            tiles_processed_so_far: int = i + len(batch_tiles)
+            total_tiles: int = len(tiles)
             log(f"Processed batch of {len(batch_tiles)} tiles ({tiles_processed_so_far}/{total_tiles})")

         # Validate all polygons and add validation attributes
@@ -963,10 +1053,10 @@ def infer(self, image_data: bytes, debug: bool = False, merge_tolerance: Optiona
         if debug:
             self._create_tile_visualization(tiles, all_results, bboxes)

-        merged_polygons = self._merge_polygons(all_results, bboxes, height, width, debug)
+        merged_polygons: List[PolygonArray] = self._merge_polygons(all_results, bboxes, height, width, debug)

         # Convert to list format
-        polygons_list = [poly.tolist() for poly in merged_polygons]
+        polygons_list: List[List[List[float]]] = [poly.tolist() for poly in merged_polygons]
         # Round coordinates to two decimal places
         polygons_list = [
             [[round(x, 2), round(y, 2)] for x, y in polygon]

From 27153d047a7acbecbb33483557a3ea9d5ac79fe5 Mon Sep 17 00:00:00 2001
From: Scott Martin
Date: Wed, 16 Jul 2025 10:18:41 -0500
Subject: [PATCH 41/45] Log the model being used.

---
 api.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/api.py b/api.py
index ff8c6da..8ebb711 100644
--- a/api.py
+++ b/api.py
@@ -201,6 +201,8 @@ def load_model(model_name: str):
         HTTPException: If model name is invalid or model files don't exist
     """
     global predictor, current_model_name, model_dir
+
+    log(f"Using model: {model_name}")

     if not validate_model_name(model_name):
         raise HTTPException(status_code=400, detail="Invalid model name. Only alphanumeric characters, underscores, and hyphens are allowed.")

From 546044a1cc26cbdd4f1b73a6bf4e6e6e6f6af913 Mon Sep 17 00:00:00 2001
From: Scott Martin
Date: Wed, 16 Jul 2025 13:06:29 -0500
Subject: [PATCH 42/45] Handle multipolygons coming out of regularization.
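
regularize_geodataframe() can hand back a MultiPolygon as well as a plain
Polygon for a single input geometry, so the coordinate-extraction step now
dispatches on geom_type before reading exterior rings. Below is a minimal,
self-contained sketch of that dispatch, for illustration only: it uses plain
shapely plus a hypothetical exterior_rings() helper rather than the inline
logic of _merge_polygons.

```python
import numpy as np
from shapely.geometry import MultiPolygon, Polygon


def exterior_rings(geom) -> list[np.ndarray]:
    """Collect exterior rings as (N, 2) float arrays from a Polygon or a
    MultiPolygon; any other geometry type yields no rings."""
    if geom.geom_type == "MultiPolygon":
        parts = list(geom.geoms)
    elif geom.geom_type == "Polygon":
        parts = [geom]
    else:
        parts = []

    rings = []
    for part in parts:
        if part.is_valid and part.area > 0:
            coords = np.asarray(part.exterior.coords[:-1], dtype=np.float32)  # drop the closing point
            if len(coords) >= 3:  # a usable ring needs at least 3 points
                rings.append(coords)
    return rings


# A single regularized geometry may come back as one polygon or as several
# disjoint parts; both cases flow through the same code path.
single = Polygon([(0, 0), (4, 0), (4, 4), (0, 4)])
split = MultiPolygon([single, Polygon([(10, 10), (14, 10), (14, 14), (10, 14)])])
print(len(exterior_rings(single)), len(exterior_rings(split)))  # 1 2
```
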
---
 polygon_inference.py | 25 +++++++++++++++++--------
 1 file changed, 17 insertions(+), 8 deletions(-)

diff --git a/polygon_inference.py b/polygon_inference.py
index 169b922..d0daa8d 100644
--- a/polygon_inference.py
+++ b/polygon_inference.py
@@ -952,15 +952,24 @@ def _merge_polygons(

             # Process the regularized polygons
             for regularized_polygon in regularized_gdf.geometry:
-                # Convert back to numpy array for OpenCV format
-                coords: npt.NDArray[np.floating[Any]] = np.array(regularized_polygon.exterior.coords[:-1]) # Remove duplicate last point
+                # Extract individual polygons (either from MultiPolygon or single Polygon)
+                individual_polygons = []
+                if regularized_polygon.geom_type == 'MultiPolygon':
+                    individual_polygons = list(regularized_polygon.geoms)
+                elif regularized_polygon.geom_type == 'Polygon':
+                    individual_polygons = [regularized_polygon]

-                # Convert from OpenCV format to our polygon format
-                if len(coords) >= 3: # Valid polygon needs at least 3 points
-                    # Scale down coordinates back to original image coordinate system
-                    polygon_coords: PolygonArray = coords.astype(np.float32) / scale_factor
-
-                    merged_polygons.append(polygon_coords)
+                # Process each individual polygon with single code path
+                for individual_polygon in individual_polygons:
+                    if individual_polygon.is_valid and individual_polygon.area > 0:
+                        # Convert back to numpy array for OpenCV format
+                        coords: npt.NDArray[np.floating[Any]] = np.array(individual_polygon.exterior.coords[:-1]) # Remove duplicate last point
+
+                        # Convert from OpenCV format to our polygon format
+                        if len(coords) >= 3: # Valid polygon needs at least 3 points
+                            # Scale down coordinates back to original image coordinate system
+                            polygon_coords: PolygonArray = coords.astype(np.float32) / scale_factor
+                            merged_polygons.append(polygon_coords)

         log(f"Polygons extracted: {len(merged_polygons)}")
         return merged_polygons

From dd8968870f48dd4ee4411862042c24a5b913e8bc Mon Sep 17 00:00:00 2001
From: Scott Martin
Date: Fri, 18 Jul 2025 15:00:38 -0500
Subject: [PATCH 43/45] Listen on all interfaces.

---
 start_api.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/start_api.sh b/start_api.sh
index 7b43bac..f548f99 100755
--- a/start_api.sh
+++ b/start_api.sh
@@ -9,4 +9,4 @@ conda activate pix2poly

 # Start the API server
 echo "Starting API server"
-uvicorn api:app --port 8080 --workers 1 --backlog 10
+uvicorn api:app --host 0.0.0.0 --port 8080 --workers 1 --backlog 10

From bd1596bf099c0c7fa24de1e16f315fea9dab52d1 Mon Sep 17 00:00:00 2001
From: Scott Martin
Date: Mon, 10 Nov 2025 10:28:43 -0600
Subject: [PATCH 44/45] Update build command for cross-platform compatibility.

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 8f71e20..051f4a7 100644
--- a/README.md
+++ b/README.md
@@ -174,7 +174,7 @@ Pix2Poly provides a Docker setup for easy deployment and inference. The Docker c
 ### Building the Docker Image

 ```bash
-docker build -t pix2poly .
+docker buildx build --platform linux/amd64 -t pix2poly .
 ```

 ### Running the API Server

From 7c393e775267189c92dfc014ffe50db8f7c386a4 Mon Sep 17 00:00:00 2001
From: Scott Martin
Date: Mon, 10 Nov 2025 15:19:33 -0600
Subject: [PATCH 45/45] Specify a compatible version of huggingface-hub

---
 environment.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/environment.yml b/environment.yml
index 2499e4d..2aec210 100644
--- a/environment.yml
+++ b/environment.yml
@@ -30,4 +30,5 @@ dependencies:
   - python-multipart>=0.0.5
   - tqdm>=4.62.0
   - diskcache>=5.6.0
+  - huggingface-hub>=0.15.1,<1.0
\ No newline at end of file
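
A quick post-install check for the new pin (a sketch; it assumes the packaging
distribution is importable, which pip-based environments normally provide):

```python
# Confirm the resolved huggingface-hub release satisfies ">=0.15.1,<1.0".
from importlib.metadata import version
from packaging.specifiers import SpecifierSet

hub_version = version("huggingface-hub")
assert hub_version in SpecifierSet(">=0.15.1,<1.0"), f"incompatible huggingface-hub {hub_version}"
print(f"huggingface-hub {hub_version} is within the pinned range")
```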