Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 103 additions & 52 deletions pysheds/_sgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -1939,68 +1939,119 @@ def _wrapper(dem, mask, *args):
@pfwrapper
@njit(cache=True)
def _priority_flood(dem, dem_mask, tuple_type):
open_cells = typedlist.List.empty_list(tuple_type) # Priority queue
pits = typedlist.List.empty_list(tuple_type) # FIFO queue
closed_cells = dem_mask.copy()
isertn = count()

# Push the edges onto priority queue
y, x = dem.shape

edge = _left(dem_mask)[:-1]
for row, col in zip(count(), edge):
if col >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = _bottom(dem_mask)[:-1]
for row, col in zip(edge, count()):
if row >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = np.flip(_right(dem_mask))[:-1]
for row, col in zip(count(y - 1, step=-1), edge):
if col >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
edge = np.flip(_top(dem_mask))[:-1]
for row, col in zip(edge, count(x - 1, step=-1)):
if row >= 0:
open_cells.append((dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
heapify(open_cells)

row_offsets = np.array([-1, -1, 0, 1, 1, 1, 0, -1])
col_offsets = np.array([0, 1, 1, 1, 0, -1, -1, -1])

pits_pos = 0
while open_cells or pits_pos < len(pits):
if pits_pos < len(pits):
elv, _, i, j = pits[pits_pos]
pits_pos += 1
# Priority queue
pq = []
isertn = 0

closed = dem_mask.astype(np.uint8)

# Vectorized edge extraction
left_idx = _left(dem_mask)[:-1]
bottom_idx = _bottom(dem_mask)[:-1]
right_idx = np.flip(_right(dem_mask))[:-1]
top_idx = np.flip(_top(dem_mask))[:-1]

# Left edge
rows = np.arange(left_idx.size)
mask = left_idx >= 0
rows_m = rows[mask]
cols_m = left_idx[mask]
for k in range(rows_m.size):
r = rows_m[k]
c = cols_m[k]
pq.append((dem[r, c], isertn, r, c))
isertn += 1
closed[r, c] = 1

# Bottom edge
cols = np.arange(bottom_idx.size)
mask = bottom_idx >= 0
rows_m = bottom_idx[mask]
cols_m = cols[mask]
for k in range(rows_m.size):
r = rows_m[k]
c = cols_m[k]
pq.append((dem[r, c], isertn, r, c))
isertn += 1
closed[r, c] = 1

# Right edge
rows = np.arange(y-1, -1, -1)[:right_idx.size]
mask = right_idx >= 0
rows_m = rows[mask]
cols_m = right_idx[mask]
for k in range(rows_m.size):
r = rows_m[k]
c = cols_m[k]
pq.append((dem[r, c], isertn, r, c))
isertn += 1
closed[r, c] = 1

# Top edge
cols = np.arange(x-1, -1, -1)[:top_idx.size]
mask = top_idx >= 0
rows_m = top_idx[mask]
cols_m = cols[mask]
for k in range(rows_m.size):
r = rows_m[k]
c = cols_m[k]
pq.append((dem[r, c], isertn, r, c))
isertn += 1
closed[r, c] = 1

heapify(pq)

pits_i = np.empty(dem.size, np.int32)
pits_j = np.empty(dem.size, np.int32)
pits_elev = np.empty(dem.size, dem.dtype)
pits_head = 0
pits_tail = 0

ROW_OFS = np.array([-1, -1, 0, 1, 1, 1, 0, -1], dtype=np.int8)
COL_OFS = np.array([0, 1, 1, 1, 0, -1, -1, -1], dtype=np.int8)

while pq or pits_head != pits_tail:

if pits_head != pits_tail:
elev = pits_elev[pits_head]
i = pits_i[pits_head]
j = pits_j[pits_head]
pits_head += 1
else:
elv, _, i, j = heappop(open_cells)
elev, _, i, j = heappop(pq)

# Neighbor expansion
for n in range(8):
row = i + row_offsets[n]
col = j + col_offsets[n]
r = i + ROW_OFS[n]
c = j + COL_OFS[n]

if row < 0 or row >= y or col < 0 or col >= x:
if r < 0 or r >= y or c < 0 or c >= x:
continue

if dem_mask[row, col] or closed_cells[row, col]:
if closed[r, c] or dem_mask[r, c]:
continue

if dem[row, col] <= elv:
dem[row, col] = elv
pits.append((elv, next(isertn), row, col))
dval = dem[r, c]

if dval <= elev:
# Fill pit
dem[r, c] = elev
pits_elev[pits_tail] = elev
pits_i[pits_tail] = r
pits_j[pits_tail] = c
pits_tail += 1
else:
heappush(open_cells, (dem[row, col], next(isertn), row, col))
closed_cells[row, col] = True
# Add to PQ
heappush(pq, (dval, isertn, r, c))
isertn += 1

closed[r, c] = 1

# pits book-keeping
if pits_pos == len(pits) and len(pits) > 1024:
# Queue is empty, lets clear it out
pits.clear()
pits_pos = 0
if pits_head == pits_tail and pits_tail > 1024:
# Queue is empty, clear it out
pits_head = 0
pits_tail = 0

return dem
return dem