diff --git a/src/deepforest/datasets/prediction.py b/src/deepforest/datasets/prediction.py index db754ba85..a7efcb0cc 100644 --- a/src/deepforest/datasets/prediction.py +++ b/src/deepforest/datasets/prediction.py @@ -444,8 +444,8 @@ def __init__(self, path, patch_size, patch_overlap): def prepare_items(self): # Get raster shape without keeping file open with rio.open(self.path) as src: - width = src.shape[0] - height = src.shape[1] + height = src.shape[0] + width = src.shape[1] # Check is tiled if not src.is_tiled: @@ -460,8 +460,8 @@ def prepare_items(self): # Generate sliding windows self.windows = slidingwindow.generateForSize( - height, width, + height, dimOrder=slidingwindow.DimOrder.ChannelHeightWidth, maxWindowSize=self.patch_size, overlapPercent=self.patch_overlap, diff --git a/src/deepforest/preprocess.py b/src/deepforest/preprocess.py index d8bb51692..b51249eed 100644 --- a/src/deepforest/preprocess.py +++ b/src/deepforest/preprocess.py @@ -206,8 +206,8 @@ def split_raster( ) numpy_image = numpy_image.transpose(2, 0, 1) - # Check that it's 3 bands - bands = numpy_image.shape[2] + # Check that it's 3 bands (image is now channels-first: C x H x W) + bands = numpy_image.shape[0] if not bands == 3: warnings.warn( f"Input image had non-3 band shape of {numpy_image.shape}, selecting first 3 bands", diff --git a/tests/test_datasets_prediction.py b/tests/test_datasets_prediction.py index f98ab3f72..5c42dd634 100644 --- a/tests/test_datasets_prediction.py +++ b/tests/test_datasets_prediction.py @@ -2,28 +2,58 @@ import os import numpy as np -from PIL import Image import pytest +from PIL import Image from deepforest import get_data -from deepforest.datasets.prediction import TiledRaster, SingleImage, MultiImage, FromCSVFile, PredictionDataset +from deepforest.datasets.prediction import ( + FromCSVFile, + MultiImage, + SingleImage, + TiledRaster, +) def test_TiledRaster(): tile_path = get_data("test_tiled.tif") - ds = TiledRaster(path=tile_path, - patch_size=300, - patch_overlap=0) + ds = TiledRaster(path=tile_path, patch_size=300, patch_overlap=0) assert len(ds) == 16 # assert crop shape assert ds[1].shape == (3, 300, 300) + +def test_TiledRaster_non_square(tmp_path): + import rasterio as rio + from rasterio.transform import from_origin + + transform = from_origin(0, 0, 1, 1) + raster_path = str(tmp_path / "non_square.tif") + # create a 3 band, 500 height, 800 width raster + with rio.open( + raster_path, + "w", + driver="GTiff", + height=500, + width=800, + count=3, + dtype=np.uint8, + crs="+proj=latlong", + transform=transform, + tiled=True, + blockxsize=256, + blockysize=256, + ) as dst: + dst.write(np.zeros((3, 500, 800), dtype=np.uint8)) + + ds = TiledRaster(path=raster_path, patch_size=400, patch_overlap=0) + + # width 800 -> 2 patches. height 500 -> 2 patches. Total = 4 patches + assert len(ds) == 4 + + def test_SingleImage_path(): - ds = SingleImage( - path=get_data("OSBS_029.png"), - patch_size=300, - patch_overlap=0) + ds = SingleImage(path=get_data("OSBS_029.png"), patch_size=300, patch_overlap=0) assert len(ds) == 4 assert ds[0].shape == (3, 300, 300) @@ -31,31 +61,40 @@ def test_SingleImage_path(): for i in range(len(ds)): assert ds.get_crop(i).shape == (3, 300, 300) + def test_invalid_image_shape(): # Not 3 channels test_data = (np.random.rand(300, 300, 4) * 255).astype(np.uint8) with pytest.raises(ValueError): SingleImage(image=Image.fromarray(test_data)) + def test_valid_image(): # 8-bit, HWC - test_data = np.random.randint(0, 256, (300,300,3)).astype(np.uint8) + test_data = np.random.randint(0, 256, (300, 300, 3)).astype(np.uint8) SingleImage(image=Image.fromarray(test_data)) + def test_valid_array(): # 8-bit, HWC - test_data = np.random.randint(0, 256, (300,300,3)).astype(np.uint8) + test_data = np.random.randint(0, 256, (300, 300, 3)).astype(np.uint8) SingleImage(image=test_data) + def test_MultiImage(): - ds = MultiImage(paths=[get_data("OSBS_029.png"), get_data("OSBS_029.png")], - patch_size=300, - patch_overlap=0) + ds = MultiImage( + paths=[get_data("OSBS_029.png"), get_data("OSBS_029.png")], + patch_size=300, + patch_overlap=0, + ) # 2 windows each image 2 * 2 = 4 assert len(ds) == 2 assert ds[0][0].shape == (3, 300, 300) + def test_FromCSVFile(): - ds = FromCSVFile(csv_file=get_data("example.csv"), - root_dir=os.path.dirname(get_data("example.csv"))) + ds = FromCSVFile( + csv_file=get_data("example.csv"), + root_dir=os.path.dirname(get_data("example.csv")), + ) assert len(ds) == 1