diff --git a/src/deepforest/main.py b/src/deepforest/main.py index 0c6fa9b1a..9e62a09eb 100644 --- a/src/deepforest/main.py +++ b/src/deepforest/main.py @@ -515,6 +515,7 @@ def predict_tile( iou_threshold=0.15, dataloader_strategy="single", crop_model=None, + project=False, ): """For images too large to input into the model, predict_tile cuts the image into overlapping windows, predicts trees on each window and @@ -531,9 +532,10 @@ def predict_tile( - "batch" loads the entire image into GPU memory and creates views of an image as batch, requires in the entire tile to fit into GPU memory. CPU parallelization is possible for loading images. - "window" loads only the desired window of the image from the raster dataset. Most memory efficient option, but cannot parallelize across windows. crop_model: a deepforest.model.CropModel object to predict on crops + project (bool): If True, return a geopandas.GeoDataFrame with geometry column projected to the image CRS. Defaults to False. Returns: - pd.DataFrame or tuple: Predictions dataframe or (predictions, crops) tuple + pd.DataFrame or tuple: Predictions dataframe or (predictions, crops) tuple. If project=True, returns a geopandas.GeoDataFrame. 
""" self.model.eval() self.model.nms_thresh = self.config.nms_thresh @@ -660,6 +662,14 @@ def predict_tile( formatted_results = utilities.read_file(cropmodel_results, root_dir=root_dir) + if project: + if root_dir is None and isinstance(paths[0], str): + root_dir = os.path.dirname(paths[0]) + + formatted_results = utilities.image_to_geo_coordinates( + formatted_results, root_dir=root_dir + ) + return formatted_results def training_step(self, batch, batch_idx): diff --git a/tests/test_main.py b/tests/test_main.py index 4aea9480a..535016def 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -473,6 +473,24 @@ def test_predict_tile(m, path, dataloader_strategy): plot_results(prediction, show=False) +def test_predict_tile_projected(m): + """Check that predict_tile(project=True) returns a GeoDataFrame with a CRS, a geometry column and a large first xmin.""" + m.create_model() + m.create_trainer() + m.load_model("weecology/deepforest-tree") + + raster_path = get_data("OSBS_029.tif") + + results = m.predict_tile(path=raster_path, patch_size=300, patch_overlap=0.1, project=True) + + import geopandas as gpd + assert isinstance(results, gpd.GeoDataFrame) + assert results.crs is not None + assert "geometry" in results.columns + + # The first box's xmin should be a projected map coordinate, not a pixel offset. + # NOTE(review): the 10000 cutoff presumes a UTM-like CRS and a tile narrower than 10000 px; confirm. + assert results.iloc[0]["xmin"] > 10000 # Add predict_tile for serial single dataloader strategy def test_predict_tile_serial_single(m):