Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/deepforest/datasets/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,8 @@ def __init__(self, path, patch_size, patch_overlap):
def prepare_items(self):
# Get raster shape without keeping file open
with rio.open(self.path) as src:
width = src.shape[0]
height = src.shape[1]
height = src.shape[0]
width = src.shape[1]

# Check is tiled
if not src.is_tiled:
Expand All @@ -460,8 +460,8 @@ def prepare_items(self):

# Generate sliding windows
self.windows = slidingwindow.generateForSize(
height,
width,
height,
dimOrder=slidingwindow.DimOrder.ChannelHeightWidth,
maxWindowSize=self.patch_size,
overlapPercent=self.patch_overlap,
Expand Down
4 changes: 2 additions & 2 deletions src/deepforest/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,8 +206,8 @@ def split_raster(
)
numpy_image = numpy_image.transpose(2, 0, 1)

# Check that it's 3 bands
bands = numpy_image.shape[2]
# Check that it's 3 bands (image is now channels-first: C x H x W)
bands = numpy_image.shape[0]
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please take some time to understand this part of the code, you may be able to make more meaningful improvements to the logic.

if not bands == 3:
warnings.warn(
f"Input image had non-3 band shape of {numpy_image.shape}, selecting first 3 bands",
Expand Down
71 changes: 55 additions & 16 deletions tests/test_datasets_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,60 +2,99 @@
import os

import numpy as np
from PIL import Image
import pytest
from PIL import Image

from deepforest import get_data
from deepforest.datasets.prediction import TiledRaster, SingleImage, MultiImage, FromCSVFile, PredictionDataset
from deepforest.datasets.prediction import (
FromCSVFile,
MultiImage,
SingleImage,
TiledRaster,
)


def test_TiledRaster():
tile_path = get_data("test_tiled.tif")
ds = TiledRaster(path=tile_path,
patch_size=300,
patch_overlap=0)
ds = TiledRaster(path=tile_path, patch_size=300, patch_overlap=0)
assert len(ds) == 16

# assert crop shape
assert ds[1].shape == (3, 300, 300)


def test_TiledRaster_non_square(tmp_path):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I’m having difficulty seeing how this test is necessary given the changes made.

import rasterio as rio
from rasterio.transform import from_origin

transform = from_origin(0, 0, 1, 1)
raster_path = str(tmp_path / "non_square.tif")
# create a 3 band, 500 height, 800 width raster
with rio.open(
raster_path,
"w",
driver="GTiff",
height=500,
width=800,
count=3,
dtype=np.uint8,
crs="+proj=latlong",
transform=transform,
tiled=True,
blockxsize=256,
blockysize=256,
) as dst:
dst.write(np.zeros((3, 500, 800), dtype=np.uint8))

ds = TiledRaster(path=raster_path, patch_size=400, patch_overlap=0)

# width 800 -> 2 patches. height 500 -> 2 patches. Total = 4 patches
assert len(ds) == 4


def test_SingleImage_path():
ds = SingleImage(
path=get_data("OSBS_029.png"),
patch_size=300,
patch_overlap=0)
ds = SingleImage(path=get_data("OSBS_029.png"), patch_size=300, patch_overlap=0)

assert len(ds) == 4
assert ds[0].shape == (3, 300, 300)

for i in range(len(ds)):
assert ds.get_crop(i).shape == (3, 300, 300)


def test_invalid_image_shape():
# Not 3 channels
test_data = (np.random.rand(300, 300, 4) * 255).astype(np.uint8)
with pytest.raises(ValueError):
SingleImage(image=Image.fromarray(test_data))


def test_valid_image():
# 8-bit, HWC
test_data = np.random.randint(0, 256, (300,300,3)).astype(np.uint8)
test_data = np.random.randint(0, 256, (300, 300, 3)).astype(np.uint8)
SingleImage(image=Image.fromarray(test_data))


def test_valid_array():
# 8-bit, HWC
test_data = np.random.randint(0, 256, (300,300,3)).astype(np.uint8)
test_data = np.random.randint(0, 256, (300, 300, 3)).astype(np.uint8)
SingleImage(image=test_data)


def test_MultiImage():
ds = MultiImage(paths=[get_data("OSBS_029.png"), get_data("OSBS_029.png")],
patch_size=300,
patch_overlap=0)
ds = MultiImage(
paths=[get_data("OSBS_029.png"), get_data("OSBS_029.png")],
patch_size=300,
patch_overlap=0,
)
# 2 windows each image 2 * 2 = 4
assert len(ds) == 2
assert ds[0][0].shape == (3, 300, 300)


def test_FromCSVFile():
ds = FromCSVFile(csv_file=get_data("example.csv"),
root_dir=os.path.dirname(get_data("example.csv")))
ds = FromCSVFile(
csv_file=get_data("example.csv"),
root_dir=os.path.dirname(get_data("example.csv")),
)
assert len(ds) == 1
Loading