From c042fe4ac00a00963f9096de91d06eb08afbd00a Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Sun, 2 Mar 2025 18:30:55 +0530 Subject: [PATCH 1/2] Added kangas features with tests --- dev_requirements.txt | 1 + environment.yml | 1 + src/deepforest/main.py | 223 +++++++++++++++++++++++++++++++++++++++-- tests/test_main.py | 101 ++++++++++++++++++- 4 files changed, 317 insertions(+), 9 deletions(-) diff --git a/dev_requirements.txt b/dev_requirements.txt index a19c1a94c..5d890eacb 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -8,6 +8,7 @@ pydata-sphinx-theme geopandas huggingface_hub>=0.25.0 h5py +kangas matplotlib nbmake nbsphinx diff --git a/environment.yml b/environment.yml index 37d14329f..3ba260ded 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,7 @@ dependencies: - geopandas - huggingface_hub>=0.25.0 - h5py + - kangas - matplotlib - nbmake - nbsphinx diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 1db109d76..70eaa203b 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -3,7 +3,7 @@ import os import typing import warnings - +import kangas as kg import numpy as np import pandas as pd import pytorch_lightning as pl @@ -35,6 +35,10 @@ class deepforest(pl.LightningModule, PyTorchModelHubMixin): existing_train_dataloader: a Pytorch dataloader that yields a tuple path, images, targets existing_val_dataloader: a Pytorch dataloader that yields a tuple path, images, targets + Notes: + Kangas visualization is supported in predict_* and evaluate methods with visualize_with="kangas". + Install Kangas with `pip install kangas` to enable this optional feature. + Returns: self: a deepforest pytorch lightning module """ @@ -126,6 +130,134 @@ def __init__(self, self.save_hyperparameters() + def visualize_evaluation_kangas( + self, + predictions: pd.DataFrame, + ground_df: pd.DataFrame, + root_dir: str, + evaluation_results: typing.Optional[dict] = None) -> None: + """Use Kangas to visualize predictions and ground truth data from + evaluation. + + Args: + predictions: Predictions DataFrame with columns "image_path", "xmin", "ymin", "xmax", "ymax", "label", "score". + ground_df: Ground truth DataFrame with columns "image_path", "xmin", "ymin", "xmax", "ymax", "label". + root_dir: Directory where images are stored. + evaluation_results: Optional dictionary of metrics (e.g., precision, recall) to display alongside. + + Returns: + None + """ + if kg is None: + print( + "Kangas is not installed. Run 'pip install kangas' to enable visualization." + ) + return + + kangas_data = [] + for _, row in ground_df.iterrows(): + image_path = os.path.join(root_dir, row["image_path"]) + kangas_data.append({ + "image_path": + image_path, + "bounding_boxes": [{ + "xmin": row["xmin"], + "ymin": row["ymin"], + "xmax": row["xmax"], + "ymax": row["ymax"], + "label": f"GT_{row['label']}", + "score": 1.0 + }] + }) + + for _, row in predictions.iterrows(): + image_path = os.path.join(root_dir, row["image_path"]) + kangas_data.append({ + "image_path": + image_path, + "bounding_boxes": [{ + "xmin": row["xmin"], + "ymin": row["ymin"], + "xmax": row["xmax"], + "ymax": row["ymax"], + "label": f"Pred_{row['label']}", + "score": row["score"] + }] + }) + + try: + grid = kg.DataGrid(kangas_data, name="DeepForest Evaluation") + grid.show() + except Exception as e: + print(f"Kangas evaluation failed :{e}") + + if evaluation_results: + metrics_data = [{ + "Metric": key, + "Value": value + } + for key, value in evaluation_results.items() + if isinstance(value, (int, float))] + if metrics_data: + metrics_grid = kg.DataGrid(metrics_data, name="Evaluation Metrics") + metrics_grid.show() + + def visualize_kangas(self, + predictions: typing.Union[pd.DataFrame, + typing.List[pd.DataFrame]], + image_paths: typing.Optional[typing.List[str]] = None) -> None: + """Visualize predictions using Kangas. + + Args: + predictions: DataFrame or list of DataFrames with "image_path", "xmin", "ymin", "xmax", "ymax", "label", "score" + image_paths: Optional list of image paths if not included in predictions + """ + if kg is None: + print( + "Kangas is not installed. Run 'pip install kangas' to enable visualization." + ) + return + + # Handle different prediction formats + if isinstance(predictions, pd.DataFrame): + df = predictions + elif isinstance(predictions, list) and all( + isinstance(p, pd.DataFrame) for p in predictions): + df = pd.concat(predictions, ignore_index=True) + else: + raise ValueError("Predictions must be a DataFrame or list of DataFrames") + + # Ensure image paths are available + if "image_path" not in df.columns: + if not image_paths: + raise ValueError("image_paths must be provided if not in predictions") + if len(image_paths) != (len(predictions) + if isinstance(predictions, list) else 1): + raise ValueError("Length of image_paths must match predictions") + df["image_path"] = image_paths if isinstance( + predictions, list) else [image_paths[0]] * len(df) + + # Use root_dir if available + root_dir = getattr(df, "root_dir", None) if hasattr(df, "root_dir") else None + if root_dir and not os.path.isabs(df["image_path"].iloc[0]): + df["image_path"] = df["image_path"].apply(lambda x: os.path.join(root_dir, x)) + + # Group predictions by image_path + grouped = df.groupby("image_path") + kangas_data = [{ + "image_path": + image_path, + "bounding_boxes": + group[["xmin", "ymin", "xmax", "ymax", "label", + "score"]].to_dict(orient="records") + } for image_path, group in grouped] + + try: + grid = kg.DataGrid(kangas_data, name="DeepForest Predictions") + grid.show() + except Exception as e: + print(f"Kangas visualization failed: {e}") + def load_model(self, model_name="weecology/deepforest-tree", revision='main'): """Loads a model that has already been pretrained for a specific task, like tree crown detection. @@ -336,7 +468,10 @@ def val_dataloader(self): batch_size=self.config["batch_size"]) return loader - def predict_dataloader(self, ds): + def predict_dataloader( + self, + ds: torch.utils.data.Dataset, + visualize_with: typing.Optional[str] = None) -> torch.utils.data.DataLoader: """Create a PyTorch dataloader for prediction. Args: @@ -350,6 +485,20 @@ def predict_dataloader(self, ds): shuffle=False, num_workers=self.config["workers"]) + if visualize_with == "kangas": + # Generate predictions and visualize + predictions = self.trainer.predict(self, loader) + if predictions: + flattened_predictions = [item for batch in predictions for item in batch] + # Extract image paths from dataset if available, else use placeholder + image_paths = [ + getattr(ds, "paths", + [f"image_{i}" + for i in range(len(flattened_predictions))])[i] + for i in range(len(flattened_predictions)) + ] + self.visualize_kangas(flattened_predictions, image_paths) + return loader def predict_image(self, @@ -357,7 +506,8 @@ def predict_image(self, path: typing.Optional[str] = None, return_plot: bool = False, color: typing.Optional[tuple] = (0, 165, 255), - thickness: int = 1): + thickness: int = 1, + visualize_with: typing.Optional[str] = None): """Predict a single image with a deepforest model. Deprecation warning: The 'return_plot', and related 'color' and 'thickness' arguments @@ -428,10 +578,19 @@ def predict_image(self, else: root_dir = os.path.dirname(path) result = utilities.read_file(result, root_dir=root_dir) + # Visualize if requested + if visualize_with == "kangas": + self.visualize_kangas(result) return result - def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1): + def predict_file(self, + csv_file, + root_dir, + savedir=None, + color=None, + thickness=1, + visualize_with=None): """Create a dataset and predict entire annotation file Csv file format is .csv file with the columns "image_path", "xmin","ymin","xmax","ymax" for the image name and bounding box position. Image_path is the @@ -469,6 +628,9 @@ def predict_file(self, csv_file, root_dir, savedir=None, color=None, thickness=1 thickness=thickness) results.root_dir = root_dir + # Visualize if requested + if visualize_with == "kangas": + self.visualize_kangas(results) return results @@ -487,7 +649,8 @@ def predict_tile(self, thickness=1, crop_model=None, crop_transform=None, - crop_augment=False): + crop_augment=False, + visualize_with=None): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and reassambles into a single array. @@ -625,6 +788,11 @@ def predict_tile(self, root_dir = os.path.dirname(raster_path) results = utilities.read_file(results, root_dir=root_dir) + if visualize_with == "kangas" and mosaic: + self.visualize_kangas(results) + elif visualize_with == "kangas": + print("Kangas visualization only supported with mosaic=True.") + return results def training_step(self, batch, batch_idx): @@ -819,6 +987,14 @@ def on_validation_epoch_end(self): if empty_accuracy is not None: results["empty_frame_accuracy"] = empty_accuracy + # Check config for Kangas visualization + if self.config.get("visualize_with") == "kangas": + self.visualize_evaluation_kangas( + predictions=self.predictions_df, + ground_df=ground_df, + root_dir=self.config["validation"]["root_dir"], + evaluation_results=results) + # Log each key value pair of the results dict if not results["class_recall"] is None: for key, value in results.items(): @@ -844,16 +1020,32 @@ def on_validation_epoch_end(self): except MisconfigurationException: pass - def predict_step(self, batch, batch_idx): + def predict_step( + self, + batch: typing.Any, + batch_idx: int, + visualize_with: typing.Optional[str] = None, + image_paths: typing.Optional[typing.List[str]] = None + ) -> typing.List[pd.DataFrame]: + batch_results = self.model(batch) results = [] for result in batch_results: boxes = visualize.format_boxes(result) results.append(boxes) + + if visualize_with == "kangas": + self.visualize_kangas(results, image_paths) return results - def predict_batch(self, images, preprocess_fn=None): + def predict_batch( + self, + images: typing.Union[torch.Tensor, np.ndarray], + preprocess_fn: typing.Optional[typing.Callable] = None, + visualize_with: typing.Optional[str] = None, + image_paths: typing.Optional[typing.List[str]] = None + ) -> typing.List[pd.DataFrame]: """Predict a batch of images with the deepforest model. Args: @@ -887,6 +1079,9 @@ def predict_batch(self, images, preprocess_fn=None): #convert predictions to dataframes results = [utilities.read_file(pred) for pred in predictions if pred is not None] + if visualize_with == "kangas": + self.visualize_kangas(results, image_paths=image_paths) + return results def configure_optimizers(self): @@ -945,7 +1140,12 @@ def configure_optimizers(self): else: return optimizer - def evaluate(self, csv_file, root_dir, iou_threshold=None, savedir=None): + def evaluate(self, + csv_file, + root_dir, + iou_threshold=None, + savedir=None, + visualize_with=None): """Compute intersection-over-union and precision/recall for a given iou_threshold. @@ -971,5 +1171,12 @@ def evaluate(self, csv_file, root_dir, iou_threshold=None, savedir=None): root_dir=root_dir, iou_threshold=iou_threshold, numeric_to_label_dict=self.numeric_to_label_dict) + # If user wants Kangas visualization + if visualize_with == "kangas": + self.visualize_evaluation_kangas(predictions=predictions, + ground_df=ground_df, + root_dir=root_dir, + evaluation_results=results) return results + diff --git a/tests/test_main.py b/tests/test_main.py index 96fa67c1a..a3aee3369 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -864,4 +864,103 @@ def test_evaluate_on_epoch_interval(m): m.create_trainer() m.trainer.fit(m) assert m.trainer.logged_metrics["box_precision"] - assert m.trainer.logged_metrics["box_recall"] \ No newline at end of file + assert m.trainer.logged_metrics["box_recall"] + +# Test predict_dataloader with Kangas +def test_predict_dataloader_kangas(m, tmpdir): + """Test predict_dataloader triggers Kangas visualization.""" + csv_file = get_data("example.csv") + ds = dataset.TreeDataset(csv_file=csv_file, root_dir=os.path.dirname(csv_file), transforms=None, train=False) + loader = m.predict_dataloader(ds, visualize_with="kangas") + assert isinstance(loader, torch.utils.data.DataLoader), "Should return a DataLoader" + # Kangas UI should open; we verify no crash and correct type + +# Test predict_image with Kangas +def test_predict_image_kangas(m, tmpdir): + """Test predict_image triggers Kangas visualization and returns correct output.""" + image_path = get_data("2019_YELL_2_528000_4978000_image_crop2.png") + result = m.predict_image(path=image_path, visualize_with="kangas") + assert isinstance(result, pd.DataFrame) or result is None, "Should return DataFrame or None" + if result is not None: + assert "image_path" in result.columns, "Result should include image_path" + assert not result.empty, "Should predict trees with pre-trained model" + +# Test predict_file with Kangas +def test_predict_file_kangas(m, tmpdir): + """Test predict_file triggers Kangas visualization and returns correct output.""" + csv_file = get_data("OSBS_029.csv") + root_dir = os.path.dirname(csv_file) + result = m.predict_file(csv_file=csv_file, root_dir=root_dir, visualize_with="kangas") + assert isinstance(result, pd.DataFrame), "Should return a DataFrame" + assert "image_path" in result.columns, "Result should include image_path" + assert not result.empty, "Should predict trees with pre-trained model" + +# Test predict_tile with Kangas +def test_predict_tile_kangas(m, raster_path): + """Test predict_tile triggers Kangas visualization with mosaic=True.""" + result = m.predict_tile(raster_path=raster_path, patch_size=300, patch_overlap=0.1, visualize_with="kangas") + assert isinstance(result, pd.DataFrame) or result is None, "Should return DataFrame or None" + if result is not None: + assert "image_path" in result.columns, "Result should include image_path" + assert not result.empty, "Should predict trees with pre-trained model" + +# Test predict_tile no visualization with mosaic=False +def test_predict_tile_kangas_no_mosaic(m, raster_path): + """Test predict_tile doesn’t visualize with mosaic=False.""" + result = m.predict_tile(raster_path=raster_path, patch_size=300, patch_overlap=0.1, mosaic=False, visualize_with="kangas") + assert isinstance(result, list), "Should return a list of (prediction, crop) tuples" + assert all(isinstance(r[0], pd.DataFrame) and isinstance(r[1], np.ndarray) for r in result), "Each item should be (DataFrame, array)" + # Kangas won’t trigger; we verify output type only + +# Test predict_step with Kangas +def test_predict_step_kangas(m, tmpdir): + """Test predict_step triggers Kangas visualization.""" + image_path = get_data("2019_YELL_2_528000_4978000_image_crop2.png") + image = np.array(Image.open(image_path).convert("RGB")).astype("float32") + batch = torch.tensor(image).permute(2, 0, 1).unsqueeze(0) / 255 + result = m.predict_step(batch, 0, visualize_with="kangas", image_paths=[image_path]) + assert isinstance(result, list), "Should return a list of predictions" + assert all(isinstance(r, pd.DataFrame) for r in result), "Each item should be a DataFrame" + # Kangas UI should open; we verify no crash + +# Test predict_batch with Kangas +def test_predict_batch_kangas(m, tmpdir): + """Test predict_batch triggers Kangas visualization.""" + image_path = get_data("2019_YELL_2_528000_4978000_image_crop2.png") + image = np.array(Image.open(image_path).convert("RGB")).astype("float32") + images = np.stack([image] * 2) # Batch of 2 images + image_paths = [image_path, image_path] + result = m.predict_batch(images, visualize_with="kangas", image_paths=image_paths) + assert isinstance(result, list), "Should return a list of DataFrames" + assert all(isinstance(r, pd.DataFrame) for r in result), "Each item should be a DataFrame" + assert all("image_path" in r.columns for r in result if not r.empty), "Results should include image_path" + # Kangas UI should open; we verify no crash + +# Test evaluate with Kangas +def test_evaluate_kangas(m, tmpdir): + """Test evaluate triggers Kangas visualization with ground truth and metrics.""" + csv_file = get_data("OSBS_029.csv") + root_dir = os.path.dirname(csv_file) + result = m.evaluate(csv_file=csv_file, root_dir=root_dir, visualize_with="kangas") + assert isinstance(result, dict), "Should return a dict of evaluation metrics" + assert "box_precision" in result, "Should include precision" + assert "box_recall" in result, "Should include recall" + assert isinstance(result["results"], pd.DataFrame), "Results should be a DataFrame" + # Kangas UI should open with predictions, ground truth, and metrics; we verify no crash + +# Test without Kangas installed (mock Kangas unavailable) +def test_predict_image_no_kangas(m, tmpdir, monkeypatch): + """Test predict_image handles missing Kangas gracefully.""" + monkeypatch.setattr("deepforest.main.kg", None) # Simulate Kangas not installed + image_path = get_data("2019_YELL_2_528000_4978000_image_crop2.png") + result = m.predict_image(path=image_path, visualize_with="kangas") + assert isinstance(result, pd.DataFrame) or result is None, "Should return DataFrame or None despite no Kangas" + +# Test empty predictions with Kangas +def test_predict_image_empty_kangas(m, tmpdir): + """Test predict_image with empty predictions still triggers Kangas.""" + image = np.zeros((400, 400, 3), dtype=np.float32) # Black image, likely no predictions + result = m.predict_image(image=image, visualize_with="kangas") + assert result is None or (isinstance(result, pd.DataFrame) and result.empty), "Should return None or empty DataFrame" + + \ No newline at end of file From e89a87e0eb89fd02d0a6c7802dab3699ccbbd084 Mon Sep 17 00:00:00 2001 From: Dhiraj BM Date: Mon, 3 Mar 2025 17:36:20 +0530 Subject: [PATCH 2/2] Corrected and verified the test cases --- deepforest-evaluation.datagrid | Bin 0 -> 49152 bytes deepforest-predictions.datagrid | Bin 0 -> 32768 bytes evaluation-metrics.datagrid | Bin 0 -> 24576 bytes src/deepforest/main.py | 68 ++++++++++++++++++++++---------- tests/test_main.py | 23 ++++++----- 5 files changed, 58 insertions(+), 33 deletions(-) create mode 100644 deepforest-evaluation.datagrid create mode 100644 deepforest-predictions.datagrid create mode 100644 evaluation-metrics.datagrid diff --git a/deepforest-evaluation.datagrid b/deepforest-evaluation.datagrid new file mode 100644 index 0000000000000000000000000000000000000000..d645d2e867bf85b069ea0f9925854b4da6bc7990 GIT binary patch literal 49152 zcmeI5S!^BGdB?emB1Q7Xab!ofC|)aZV#V@s=FD!#iDg@I1%e`dE07`&1&X2z>I4P)QXoJOG%rDsKD0pEqClGhO&?km{e9=0 zx%XVz%v{$eL17FHle!CezkT_C-<(-E^X%%*O8e^O*80*;J0AGZz~JD(W9|09z`z9m zPx62DwU2*{R1f$$==^)He@+Ypi-)JCPYsMty*<$SMyoab{M6gi2lsk=^L!(LMgoll z8VNKKXe7``ppif$fkpz21R4pvy9CaU?7MqrX7H;!OPAMHwpVs`RySVVu6~X_arWei zb0^#9PCWkX$#(T@`+<$6^_BLylP{iYzi{^Sb0^MTY(I1I;_>zyOKUgH^T*~!_uc)( z%;3Q4#`4Nr+pn+DAucZ6+}YHR7puRt7*+q0&Tzj_xciymF(K${^O~adm7S&Kr5!%c z{2b|_sCw3+=y?0;>dM;kVu!@z?VT^(FjO6Hudi;buHRg5pL%xw#5o==z2!b!dShj4 z>D86$)i;*5R+l!ebRO?)?kufUFK+KFFR#2&J=)y4wzAcJdSU+gW4q%+V|O1qG`PF2 z8CViwZkzAJJ;a;m+7IZTi>u38pvT+hp+mEILWIzM{Mq@(t9NvGmrt#p+_`ph{qn}r z>Y9G*u`7dGmb-^EDe@;?-CA9)eh&4JQ$5>$;L7IO&Gn7NuzmXZb0?oVdG>hQJ`-6I z?L%c49z|=Tt)GbUCR_j2`dRCzt$%6#Wa57&E{*^B_@Vu8j{R^f+xOd}KN~$i@<$^d z9=f_msj#czhWK~@?dT0^2!=79ewKDVtAA%wy*F3JP`%qtxLPJA2k$y)GLLm zRGJP$ZkF?)qA->RUozw-VaS6z3hy=)-tCvd*horKs1gyVrBJ0tx>|-x(_Tm6U53KD zyiyoNqS#8yt_-6zdXF-kIjAX&4|=6ADIEfHN6C5m9_gF?h>_q&ywVqij-N!O5M^~7 z_2@Y1$4~Aw6yE8V!qlvl|SP4c^ACczBhYW!q@=9Q4oI%qUxq&8%oKP5Peb-3v za7z;ywY(B&f{6))kt;wa$|!B9b>vMO@}_@LhQb4WDa;%jj8ZpHmnO%U zC{xPY>L{Er6i)c1Fmk;(6rwt*#Fb$xbx#e2N5(aU>9|)4ZN4yJM!Q2rpmO{;N`#LZ z0%!Ld5$^X&V4OL*PKENoQrHuzRnkW)IXtE*jK;iDsC4?M9HmQN9GPm?RMir+YXms6 zPm>q#^GaUqW)iK`cavYl?rK>#aA388s1Qwbqg&T^#j-VHgSP(zz0|YQ?0KL3{U@Pp!MkVk6M5B0cNye zt&u<@fkpz21R4o65@;mQNT88GBY{Q&jRZcR5?C6Vyl=){X})`4b)9Pi7H=%=T&vdM z<2ALZN2N4Iyjj$XRO+QgO3n;Sc_K(4op^rOX_8>_G1 z)LZ4ZK0kQi(9EIn!NIYiv4yS8Hy>JEuGa6*&Rgk0ihYuX9-t^2~Z<>F8 z4;I%}w|C(2mdyW0THhRK{afpAT7S$xny*FzjRYDAG!kee&`6+>KqG-h0*wS32{aOD zB+y9Umo0&VBg1p8D_bi|oWRgI9$-5>JUo0bmqQY|7uB~<$^i~bf@It?H2mNU7u5r6 ztgkGZqy6r`ICua0-2KbbgTwc=uC3hMqW@pLvK`7HeKqG-h0*wS32{aODBv31XgTr%!?|=tJ_75K%6m($t&fy0K zZvz5`riSku?*afa|3AcGeyyeH?@k|?`s(E0O@8*k?=|QD^?gENbN=7^KjpDG|L^^t zli8gA=fMiy+`ta=KMboVX!&>@BL=u z=KR0+d+fKHNt*Nj-utoXJ@RYL|9j5@{Ggrx$2AP^ek#7{|K|r~t17y^B$b-fLAM{F~D?f8WcqKsR*lX2~2hV=WSm39;(x>bp z>srBa?7E0P^50sZxiLzO1ZR&K3Xl1vFgC7gC=8u2t#ZeSW9{B`6h2@me84M(k-m3_ zje>FL+yX#F=FD#GCb7PGK^=vkG!%Z)FNO9Pcq2o7z_2w#Wh11aP~ZNhj>1nE3P0hO zLaR$dVd~_%UKy5sFKA!!Uq3~l~DU98GXJWUdFfN^fA?oK0`y)p+h3Qeh6zULYWvFieRT0>8 zuLpgfk2*)W&xr6o&jdOpO{AP)Qwi7PWPMABIttr{!nR)ub+ytlL&lQWzD^QSl;9OJ7eg?CTizBZj;qe#vvsi5DSR@}k128Tu(1W@a>b@r+;c zLUTI@6RMqj9K{ZO`nEzf_IK}LP2b#Muk@vc3PYYFJxjKBkF<{vc<;T2x_iA+=XkHi zQ#!5zrPa-rbdggdyL;}@)ID_1FN(UznHAU)Op7k7??t5l|H#1i2U?5Me=$8Xb!+m+ zlaC(w-HHF6cxn7k#_!sHY^>OKdi14{uMGdj&^HJFkT?E6|Jr@!lnDE4=_#+mX38JL zNuI_@l$Tl2Q;BCqki}^fM@5{L=1xe~QYg}cJPb>f=TwyLJ(38of!Kb>a$eY4SzbK1 zwX)*gw=fK{qDYENMMaXTB#kp}V|e|fW^n$bR|ZoRBykuQ6v#;#s#xDB&L)Oj1)(Bh zL@Fy&?hxa|xgrhVuZ+U9;MOPh-bN;?g@`50!kibmXPW+{T5iiChk`=>7G0rmyY{5! z_RN!Bxh?Y`OOq@_m0^^p?#*z>bwQZtocK{jWvQZ`;!&ZH5(1Mf6El#EUe%IW9H=l) zXms)-z3nD(dnlI=$+9FbQ|u{C${2RnpU~_sJmHmHl@?epHzMk)!m2hZi$DQU(KL$E zd#qQLQIPUfreG|zdbjMVI1I|7L{HqRM}D!V*58xgOeJNiU^Wj`F6zDbxaRl4$G!3! z}7(am8wx1?WHQjbPeUQ3YWNr6fg4 zAnU!s;a>a_mO_*ntD+2JCG0+C*nP|^yRk|G!m3g#iBuGp&W*K_B!B{p8oI;82IVU9 z;*JR>i(!_T6@*o9p#qf@vEue{+^Npo@Tx8_C~}A0jn8U!&wth{yDG|p zD2MK{hj63aFbh)RBDAO6K+{lOX*Vf&4W007-0IHgu=3SJ;m0Yt3ERmdle}sWRMkeA_kdgUf!i*w9JGFm|d*^iWq=-1H z7^<9?pp{{o7(`s~qJ-}{UJGIcQCM3tIZTXs_EBMWH-FSCvjkF|F|0bBh7gT*bBXHf zF0G$76n@$(g~X07g*X-uDv#2Qeq~Y`EbGX1kNYh@VkmsXD}^d{DRfcLp1T=Gechh$ zVNGHBuvZH0I6x~o)+M+VVblXLc5lwL=WuR!rql$+rB?ziJYuX+-wM+b=uQd@z+9u^ zBZU!Q;g!C;XVT(0K6%ygl741EhjUF}lzS!6X$%Rej*qxbtc6)4znRSFH}gxL1idPZ z>gL|+*Ny;n(l-E49et^xFZD{FJL#hw>B=v43+lYSdEK2&G<|vEl|Ir?oF z`V~_)>EKf%!6UJzF#QGI{!h0dR>KsaT>{CSJoxzpO8wd^)Xo< zL`lNh1{G4(npI(!ZZA#fU6Zi;S|^cF`cf!{F%3!8|8yj_N9{`=p5}QjRTh;-oKi1b z|GegS{_|crrYj35^KNUvZ1s=_xhg1$lY}nS6m^ziN+yJn7b=$?pa*xLRiiF0RalhO zEA`yc&%$S((F<|QTDV<2uer_7d*#-2W>%rpvpJ1o9%gxw+q%DE4_HK%aT(@OmZqXM z$KSzm1TTb3A?W@Xm+1k@X9Vg2`y(y zWsSv23Gaf@0xD%0fP{c1z4oh6V)w~TVNZvH+`^85}HN*b92qYs7CwX;UC zXT5StUO=%F@oABjkwFZsVv9HsAH~E37Jzb$wK8Pj!AKyeyvVATEVoRHK%ofWU_%<2 z5gV^%7uX`nRFSBXX1Pd&-7|*WGk)16ihv2xr(k>1dYim$3(5@lRPesKwS5{DhnsAab(fe8UQz$}SN;o?#TY8N$dRtQ*I4KmsL}-aLy5W*vB{o>_3S8~_TiAr-TI<$29& z`n*?GdlC%QLLm8r;Hf#6*klb=@Dh+~@B_*7l0&-*2Cjk;7G)C4(8k0sW578EOc@>9zt}VR-iu-4p!^CoR~+R^N~HB? zjDnx>$}oeZfFV?z!gFWXWEl#bn-$+!6Bmy;GnRe;O!k<$Ph=P zpgpY8BM69_6RMh{0F)%^Qj*N|)0)-!(_UGPaI`!I_)TJ@3!-h5q@6cT zvVb{NmX#TOhndY;Q%-X{fKU{rAr(%Apc+dqXai)fBsO7-l6rQjmQvCHH5sEF@!ZcD z&;6WNZgCSP70kcDrEga(Ll@{ysRrzlj90U~ zcXX%KiJni*1^q!OYI5jG|5VHc5Z9bPqt@v2v5@mg{bO%0!@ z=TTk5pmRt2gg9H9Nx^L0J}zk&`5g>^EmouaDx&jJNl%f)NvOkpA9j%-HATWEfRy47 zdK7(Cv%7HBE4wtmOkB|)1(DPhJy>8EP{u<9)0h}|Mej&jhFL!dcv=GEZEnw`kMR9$EIv9}$gWnJNo#P1+l$SW{`V!ov9ZgR2)Ty;*hJkorYw7euLs})#VjQ+(3!dGy3#F&FpJ6!Alo%%;5GwBu3_((@8@+}%>kZxRL+lO@$L3bE*o z<(FxMhC38xQr~NRjMECjJh`3d7Yo|4aZ&So{-RfYp&LixV+g-aAFuhP&<9W_W>(U0 zYK)SdKZ<-}n~dy;rzo*XLuBewh5>Bq5%LXcTW`;>70oukZi%*xUQQKl%?N|G`VmS0jN&0v}`v?5^K5N$sZB zm`80z^-U|mB@ucYEP>o6aT%N{+0rFf>p>we*_=j&38u{ej+Jj-R!)qU%W$|vgz#=k zpnfn2M!?f0#JewE7NQn2%ps@wpfVPpo(nL*JW=RwlVqR zavUVi1WW3(>sy-F`7OV^LK;gNe3DyN8qLZ|m0<9c#xhuSsueD)0o4;+uX|xdmA<7Q zNEl>PLv(~!%PtiW9VCz$YUPpyG33{cA;0dGTZ*%g4Gw`YWFEy&2iC)DMHF;97T4-} zuFRXTd;CoA+LV5gi4+Q2zD>s(Co%Hyt1pWC%_)< zGKIQt{T4Wv=$Y1tu;tQ;jLU$LGh?Bi|UI-uFwE5Vtz#Tjd=Ox(yS!jRFOCP~#oCGi ztz6w!Q-n^u#YaJVs`T4meql?N={i>#74(~hOowk=*Zf_)?v+1j^aN7rlnRMX%lT%*_5RL02FEY4EECxzC&xxG*O0dSTUF3%4)~jYntUVuX$xz z!*enT1EIUER-%O$z=Nrdv#{;NV2TE=?#d)HxJ!hRn5e(lA?&P!)cPyUIIE%TXR@9} zY|}X9s$qB4E4vJFfi)0oMum6VF03G|EZ$YuDw-Po?O4wFx?kQ8XS WLX7J81vdoXp|XxhsKkL=m;M)$qHh)e literal 0 HcmV?d00001 diff --git a/deepforest-predictions.datagrid b/deepforest-predictions.datagrid new file mode 100644 index 0000000000000000000000000000000000000000..8bf952cd4e9c0497fd437f8c75778f5a54787fce GIT binary patch literal 32768 zcmeI5-;Nx`6~<@1#>OG`M2alrD3%6GmOwbY^{=bCa)AQaM1Zl0v6WZ|jmEnJo@oD= z-B~a~C^GMmt3-K=+~fjYBDq6(h6wQh`Bitj{?(OeP<|94$*=WbrHyybJpcOTwZUw6Yx?-$<83;`@#Jv6S3e%dBaLPJ z=?tNTPI2($m0M>+KwsXjBifnHCtH&_=WD*7P9loWb|Shq`fxVg-Wo4Lyf&JD_Na;K z+GuCCJKH(j8NG4y_N#Y!IC9cw-+xRY;x}A(HGZ7|LY@u5c+>E z`sc5VKt>=VkP*lTWCSt-8G(#IMj#`Q5y%K+1b)N_oFD#quzumY!SLRN?}q>Q5ic|^ zEhCT-$OvQvG6ETaj6g;pBajiu2xJ8Q?+_TD`soXo+uuQda<;i>|7bFQ82_Ps_}z%l zuWat^JldYlXM4K`Tz>!j3V*$SFvHZ-M-bcknP#QoRxeLuK%AN{%tV)&+wnazw<->$_QixG6ETa zj6g;pBajiu2xJ5@0vUmfKt|vPM&RP<^^M`?{&d0(7#4R2I1rcD*Dv~TkA&r~DvaLU zpKi@I>v6>|POZOuncr1lc(F4bH@Eh?dVk~U&c@ZP3v25y3?EJp_v!y%s8r;lGF94*&H7FEF1jBajiu2xJ5@0vUmfKt>=VkP*lTWCSt-8G(}#xVXNtc4R(q z`q}l1YhgUF{`2*h*H#7rr_QgxaBevO2&b^ddE=F=s`@4c^{oh$)~=9NRYp2*D_8PT zyfJ~`=XhHJQOK(FURRB3s6N>f#;;(I4D zuT7z3sf71d_#`yxtSY22&XwA;p5|i>WUQ#Dv{t3`rEZSc2bs}T>73=A9LOYrRX`?$ zQ$k4t5cLAvJ4Z%}WzHjvIw_CBijpI{QW&GkZcs{C z4sVCxRolnG-msZG>q%dzG@R5K z&9=7*rQk&2rFYE4FJJ}*OJ(-dInJ#{P6xeY65iAE6ETru-~QE_Z`sT5`4 zRuUC*QCOwH&=Xm#WDQV;Q(JE(e}&|B3!oIlb(H1^aFDmk#huxoKJ=JPQaN_LA+m(& zm9{_yeu&YEfGVSP!LvstA?P?o8O=ya`S()ux*&QK1JwvRm25j?w~a;9)pm?lps#ryQxBv6M|F!jbuxf>`EQ@!C_)l zRYB=NYpDVd7p#_1A*E>J%37MZELx!}Y9%R&mZFDFSTRaG(yJ6AWo?2MP;#K;Fbm;C zj}(RTDRi>bKI%GKpF*;z-)-&;Y;a|H(%!?ec&V(QidE1$p*hcM4pQ;jN{yj*!!=b3 zD}@G5yF^O@F@DK*`QPq63`-H%!tTI{u>JbeXjHHQoYz{4yHAKNJvx=Lq!CYFih`q} zCChEUD1B@)CzfU=suUiF&9~gDJ=uc;)&UYmgH3D`(+T}iIy^$Xrmk_pzhP;RTvH*o zRIv`L0{v^R6v2_UV7DwRVt8v`FjNQ9jQ}-vsY@kAQVOpHT#ot{Dee^TV`QTuMrc+7 z3Dv!*ED1QtTEgT|z63S^<(5g3@MNx_eQ0Z3*IUvWSUg+p%DVE%&cF&dMTJUzKphb~ zzg64%Gd<9t`Gil+r?!E7LT82871ArK?83>n*cgHhy9KqXYjQ`Ea4uD$C;_1wFN}M= z$#4k;rBavH^B2aWKV_)Cv$}MalhYZcyB}vsZN<{J8@XxN8cvrLoJNpzMVV+}Srp!q znv9z`TIEJG@Mx#j>o`PnB#%;M%`tm4U?X2n6Saq3%rCs1>3HnR%i5#y0i*QjaL>8S{>JmL? zp=(Q(CXvRjK=tuy%T@LSCFm?e0m2v}a#cz$pyI$NCY`4XhvL*RDtaXnb+o^ugT{d( z!aFmMjpvz*gvLnYUC{0U?ffW~K+|ZfP767%Roo z;pH7xm|&WNYfKatu_s4tZ9Gz~N}U}-TfIy&qRJarc!&3K^lxn7Xma>KPYhTOMheE! z&GGKUwyP4SWE3Rs^eJuOK_PPo)D zYUr?1{qU}!EMbRfDQekYGEWQ09Ttujh8ow=H7sNS^Y9$);Gw~^CLjF|0?kl>B3!F+ zKVTB0%b*EQN@zX-j1w4*|3~jl-8~1!v}_7<0!Wx>ON&?)x-wK;P%=%kgOQ4MWs2PJx4_Shn~D+O(0O3{Z-3p(c5;>O`(% z`$I}-^_r1kSW!>%5O>TH$*YjG#8u!-44nflj%MORZR<>laMxf)9Rr$>q>esOCxz3{ zrerR%m@PF_ufm)L=0eiCZ|V;V$f$Q@l6|z2Na$=~x)R(lcExsSHaSjBQ}>p0j0hOp z^T^bL&ZGze?Xb&>+S-E@B?We1NG*Hc%Wm<8sEIkL$=oI}Ok9}zv0^%&X1LhNqLhWE z!(=t>B&bbca65bs;zVE9j}_Tr>r>H4Hbbm#VC0l+V43J(86tBI4$}q}MhU^OByKD6 zUb%WWf|AkBwstFq=7^0=M3Z zsuuYrUdk8-Oi&1M1~-PA^~ahjPSGSHyohRTEwAL&KtyYqLx&>03xx~CN-hpa1Y1w-4fv_cZ2*T4(bz@@R6>JwKI|v8$w^?*8=*rUa z18ZmK(z404svYIYli-d)>x+RgED{z*HM}{QcyQuSGhst=d7}vfe*~)CMmmxjXT!t} z;xw0Jnvq1KOkpo*gtj|qe|lVD_KbN)z#IYCR0kLm7&y&_w)JGarbHbUt`r$C;d%;H zhQT^7MP;Jg$zUizHNMnBh4W?qa?ygCUvGtgcTZdteiN=X1wySDW~wvAk0B|8!loTq z1_m99V73Jlv(q(aHz3AG??EKnGZEEGptQE|*Q z0wcm`36o3_=$1W-LMcUa1K#pKB9ArdPzN{ zETB&9NWYUhAjh2pHy0A?)}>>DDP0ha-09BztA|t!s`E^KXb z-MQr2LPlGHW-{}nn_QH%HZ*U-5DBdmfhDne*BrynQvDDi8qDT8XFtanc2GuODP|kH z*G(q9GU~iIcXF*XB!v_MN~bkxBV$G3LML5CJH}bEANhluLC74nwH;XYV8$S{q$1&8 znd&6uSG4O~J@MUMHj<3%>Rdi#(BAQpaa$5I3};ryR3?nJmvK==3g~5-@`MWwCOD41 Zut+(0bLkx^NiULcNu3MZjW21?`Y)%#9r6GG literal 0 HcmV?d00001 diff --git a/evaluation-metrics.datagrid b/evaluation-metrics.datagrid new file mode 100644 index 0000000000000000000000000000000000000000..6d532261d25d07f3b534a0f7c7f9be5033335325 GIT binary patch literal 24576 zcmeI)&u`mg7zc1WcI$N`@g|6(0G8eaRkT`I+HpYQK-O-iDs2~*g{{>zS#JEUwd5b2 z{kjy5357p_BXAhIfjCd%Fzv*R0|z(_AtZ!g69>d;N8Z=YTUxrp1!)L#gU^Fy;zSt;MGJagFYh(iDZ5P$##AOHafKmY;|fWTt~-qOcwmZiSO-1UH_lrcZr zNM${Cwr9`vZPK5cU$%*ilBviIDe2p*eez~+ab>QzMwaZg7TI=#tcXvy$Hr@CEmiR& zkM5+`1Ac?(xEYJ{z9Y}%%*vLyF+3s?9KNO*BBA%Yc}5{+uIDm-y`t9lWhA3}GHQ`) zJ`Fr)H^mlV?`#!WwMghkewc;i_2uqdpLgAzypT*1#KRDuZk@3)dsp5At27JvFx+K9{E=PO*{ro~%pz(veXzHHd>OjGP&=c)o8h z*u54RMP^4Sj=HCd(NCB8&sH+OQp}(D7YYO*009U<00Izz00bZa0SG_<0{?k|T1i(M zMrl3XaS}RkgP`;CM+>Ws4}Pqx`oxQdh;Q)?e(J|j=fR!(;hp=Dxc}!5{ZSwQ0SG_< z0uX=z1Rwwb2tWV=5cnSpl+8~Rt@5s7E>s@!3IzfXfB*y_009U<00Izz00bcLuL}g) zv6@vg^0IVE{Qp3{oSwhpFP~n^N7bEQK0UGi>nl^83tzr)`_$!+JC8>FTi@OK{Iheb zH=2Vui&*nja=Jx|wv$Ev^^A(p%`2s+>Q=QF#L!rwEb#~O-FWeL2yuA06;Q^Xj;H*H z8%_RorJ1s%dG)3|;bA9VRgcxJ`h;p2sChE*RO{kXiu^MBpEt(d==-9Zy6 zq{Wv6;8PhPb&93^CfC}NVSCc6s9M9^q*>yp%pasPVv#@Hyr-D=% 1: + images, _, paths = batch + else: + images = batch if isinstance(batch, (tuple, list)) else batch + paths = [ds.paths[i] for i in range(batch_idx * self.config["batch_size"], + min((batch_idx + 1) * self.config["batch_size"], len(ds.paths)))] + batch_predictions = self.predict_step(images, batch_idx) + for pred, path in zip(batch_predictions, paths): + pred['image_path'] = path + predictions.append(pred) if predictions: - flattened_predictions = [item for batch in predictions for item in batch] - # Extract image paths from dataset if available, else use placeholder - image_paths = [ - getattr(ds, "paths", - [f"image_{i}" - for i in range(len(flattened_predictions))])[i] - for i in range(len(flattened_predictions)) - ] - self.visualize_kangas(flattened_predictions, image_paths) - + self.visualize_kangas(predictions) return loader def predict_image(self, @@ -1028,15 +1030,28 @@ def predict_step( image_paths: typing.Optional[typing.List[str]] = None ) -> typing.List[pd.DataFrame]: - batch_results = self.model(batch) + self.model.eval() # Ensure evaluation mode + with torch.no_grad(): + batch_results = self.model(batch) # Returns list of dicts: [{"boxes": tensor, "labels": tensor, "scores": tensor}, ...] results = [] for result in batch_results: - boxes = visualize.format_boxes(result) - results.append(boxes) + # Ensure result is a dict from RetinaNet + if isinstance(result, dict): + boxes = visualize.format_boxes(result) + results.append(boxes) + else: + # Handle unexpected format (e.g., empty or malformed output) + empty_df = pd.DataFrame(columns=["xmin", "ymin", "xmax", "ymax", "label", "score"]) + results.append(empty_df) + + if visualize_with == "kangas" and image_paths: + if len(image_paths) != len(results): + raise ValueError(f"Length of image_paths ({len(image_paths)}) must match predictions ({len(results)})") + for pred, path in zip(results, image_paths): + pred['image_path'] = path + self.visualize_kangas(results) - if visualize_with == "kangas": - self.visualize_kangas(results, image_paths) return results def predict_batch( @@ -1076,12 +1091,23 @@ def predict_batch( with torch.no_grad(): predictions = self.predict_step(images, 0) - #convert predictions to dataframes - results = [utilities.read_file(pred) for pred in predictions if pred is not None] + # Handle predictions, including empty ones + results = [] + for i, pred in enumerate(predictions): + if pred is None or pred.empty: + # Create an empty DataFrame with expected columns + empty_df = pd.DataFrame(columns=["xmin", "ymin", "xmax", "ymax", "label", "score", "image_path"]) + if image_paths and i < len(image_paths): + empty_df["image_path"] = [image_paths[i]] + results.append(empty_df) + else: + pred_df = utilities.read_file(pred) + if image_paths and i < len(image_paths): + pred_df["image_path"] = image_paths[i] + results.append(pred_df) if visualize_with == "kangas": - self.visualize_kangas(results, image_paths=image_paths) - + self.visualize_kangas(results) return results def configure_optimizers(self): diff --git a/tests/test_main.py b/tests/test_main.py index a3aee3369..7e3cc07db 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -348,7 +348,7 @@ def test_predict_tile_from_array(m, raster_path): m.create_trainer() prediction = m.predict_tile(image=image, patch_size=300) - + assert not prediction.empty @@ -719,13 +719,13 @@ def test_batch_prediction(m, raster_path): tile = np.array(Image.open(raster_path)) ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300) dl = DataLoader(ds, batch_size=3) - + # Perform prediction predictions = [] for batch in dl: prediction = m.predict_batch(batch) predictions.append(prediction) - + # Check results assert len(predictions) == len(dl) for batch_pred in predictions: @@ -739,21 +739,21 @@ def test_batch_inference_consistency(m, raster_path): tile = np.array(Image.open(raster_path)) ds = dataset.TileDataset(tile=tile, patch_overlap=0.1, patch_size=300) dl = DataLoader(ds, batch_size=4) - + batch_predictions = [] for batch in dl: prediction = m.predict_batch(batch) batch_predictions.extend(prediction) - + single_predictions = [] for image in ds: image = image.permute(1,2,0).numpy() * 255 prediction = m.predict_image(image=image) single_predictions.append(prediction) - + batch_df = pd.concat(batch_predictions, ignore_index=True) single_df = pd.concat(single_predictions, ignore_index=True) - + # Make all xmin, ymin, xmax, ymax integers for col in ["xmin", "ymin", "xmax", "ymax"]: batch_df[col] = batch_df[col].astype(int) @@ -865,12 +865,13 @@ def test_evaluate_on_epoch_interval(m): m.trainer.fit(m) assert m.trainer.logged_metrics["box_precision"] assert m.trainer.logged_metrics["box_recall"] - + # Test predict_dataloader with Kangas def test_predict_dataloader_kangas(m, tmpdir): """Test predict_dataloader triggers Kangas visualization.""" csv_file = get_data("example.csv") ds = dataset.TreeDataset(csv_file=csv_file, root_dir=os.path.dirname(csv_file), transforms=None, train=False) + ds.paths = [os.path.join(os.path.dirname(csv_file), img) for img in ds.annotations.image_path.unique()] loader = m.predict_dataloader(ds, visualize_with="kangas") assert isinstance(loader, torch.utils.data.DataLoader), "Should return a DataLoader" # Kangas UI should open; we verify no crash and correct type @@ -921,7 +922,7 @@ def test_predict_step_kangas(m, tmpdir): result = m.predict_step(batch, 0, visualize_with="kangas", image_paths=[image_path]) assert isinstance(result, list), "Should return a list of predictions" assert all(isinstance(r, pd.DataFrame) for r in result), "Each item should be a DataFrame" - # Kangas UI should open; we verify no crash + # Test predict_batch with Kangas def test_predict_batch_kangas(m, tmpdir): @@ -934,7 +935,7 @@ def test_predict_batch_kangas(m, tmpdir): assert isinstance(result, list), "Should return a list of DataFrames" assert all(isinstance(r, pd.DataFrame) for r in result), "Each item should be a DataFrame" assert all("image_path" in r.columns for r in result if not r.empty), "Results should include image_path" - # Kangas UI should open; we verify no crash + assert any(not r.empty for r in result), "At least one result should have predictions" # Test evaluate with Kangas def test_evaluate_kangas(m, tmpdir): @@ -962,5 +963,3 @@ def test_predict_image_empty_kangas(m, tmpdir): image = np.zeros((400, 400, 3), dtype=np.float32) # Black image, likely no predictions result = m.predict_image(image=image, visualize_with="kangas") assert result is None or (isinstance(result, pd.DataFrame) and result.empty), "Should return None or empty DataFrame" - - \ No newline at end of file