diff --git a/pysheds/_sgrid.py b/pysheds/_sgrid.py index aa5fe9e..4e5f8af 100644 --- a/pysheds/_sgrid.py +++ b/pysheds/_sgrid.py @@ -1939,68 +1939,119 @@ def _wrapper(dem, mask, *args): @pfwrapper @njit(cache=True) def _priority_flood(dem, dem_mask, tuple_type): - open_cells = typedlist.List.empty_list(tuple_type) # Priority queue - pits = typedlist.List.empty_list(tuple_type) # FIFO queue - closed_cells = dem_mask.copy() - isertn = count() - - # Push the edges onto priority queue y, x = dem.shape - edge = _left(dem_mask)[:-1] - for row, col in zip(count(), edge): - if col >= 0: - open_cells.append((dem[row, col], next(isertn), row, col)) - closed_cells[row, col] = True - edge = _bottom(dem_mask)[:-1] - for row, col in zip(edge, count()): - if row >= 0: - open_cells.append((dem[row, col], next(isertn), row, col)) - closed_cells[row, col] = True - edge = np.flip(_right(dem_mask))[:-1] - for row, col in zip(count(y - 1, step=-1), edge): - if col >= 0: - open_cells.append((dem[row, col], next(isertn), row, col)) - closed_cells[row, col] = True - edge = np.flip(_top(dem_mask))[:-1] - for row, col in zip(edge, count(x - 1, step=-1)): - if row >= 0: - open_cells.append((dem[row, col], next(isertn), row, col)) - closed_cells[row, col] = True - heapify(open_cells) - - row_offsets = np.array([-1, -1, 0, 1, 1, 1, 0, -1]) - col_offsets = np.array([0, 1, 1, 1, 0, -1, -1, -1]) - - pits_pos = 0 - while open_cells or pits_pos < len(pits): - if pits_pos < len(pits): - elv, _, i, j = pits[pits_pos] - pits_pos += 1 + # Priority queue + pq = [] + isertn = 0 + + closed = dem_mask.astype(np.uint8) + + # Vectorized edge extraction + left_idx = _left(dem_mask)[:-1] + bottom_idx = _bottom(dem_mask)[:-1] + right_idx = np.flip(_right(dem_mask))[:-1] + top_idx = np.flip(_top(dem_mask))[:-1] + + # Left edge + rows = np.arange(left_idx.size) + mask = left_idx >= 0 + rows_m = rows[mask] + cols_m = left_idx[mask] + for k in range(rows_m.size): + r = rows_m[k] + c = cols_m[k] + pq.append((dem[r, c], isertn, r, c)) + isertn += 1 + closed[r, c] = 1 + + # Bottom edge + cols = np.arange(bottom_idx.size) + mask = bottom_idx >= 0 + rows_m = bottom_idx[mask] + cols_m = cols[mask] + for k in range(rows_m.size): + r = rows_m[k] + c = cols_m[k] + pq.append((dem[r, c], isertn, r, c)) + isertn += 1 + closed[r, c] = 1 + + # Right edge + rows = np.arange(y-1, -1, -1)[:right_idx.size] + mask = right_idx >= 0 + rows_m = rows[mask] + cols_m = right_idx[mask] + for k in range(rows_m.size): + r = rows_m[k] + c = cols_m[k] + pq.append((dem[r, c], isertn, r, c)) + isertn += 1 + closed[r, c] = 1 + + # Top edge + cols = np.arange(x-1, -1, -1)[:top_idx.size] + mask = top_idx >= 0 + rows_m = top_idx[mask] + cols_m = cols[mask] + for k in range(rows_m.size): + r = rows_m[k] + c = cols_m[k] + pq.append((dem[r, c], isertn, r, c)) + isertn += 1 + closed[r, c] = 1 + + heapify(pq) + + pits_i = np.empty(dem.size, np.int32) + pits_j = np.empty(dem.size, np.int32) + pits_elev = np.empty(dem.size, dem.dtype) + pits_head = 0 + pits_tail = 0 + + ROW_OFS = np.array([-1, -1, 0, 1, 1, 1, 0, -1], dtype=np.int8) + COL_OFS = np.array([0, 1, 1, 1, 0, -1, -1, -1], dtype=np.int8) + + while pq or pits_head != pits_tail: + + if pits_head != pits_tail: + elev = pits_elev[pits_head] + i = pits_i[pits_head] + j = pits_j[pits_head] + pits_head += 1 else: - elv, _, i, j = heappop(open_cells) + elev, _, i, j = heappop(pq) + # Neighbor expansion for n in range(8): - row = i + row_offsets[n] - col = j + col_offsets[n] + r = i + ROW_OFS[n] + c = j + COL_OFS[n] - if row < 0 or row >= y or col < 0 or col >= x: + if r < 0 or r >= y or c < 0 or c >= x: continue - - if dem_mask[row, col] or closed_cells[row, col]: + if closed[r, c] or dem_mask[r, c]: continue - if dem[row, col] <= elv: - dem[row, col] = elv - pits.append((elv, next(isertn), row, col)) + dval = dem[r, c] + + if dval <= elev: + # Fill pit + dem[r, c] = elev + pits_elev[pits_tail] = elev + pits_i[pits_tail] = r + pits_j[pits_tail] = c + pits_tail += 1 else: - heappush(open_cells, (dem[row, col], next(isertn), row, col)) - closed_cells[row, col] = True + # Add to PQ + heappush(pq, (dval, isertn, r, c)) + isertn += 1 + + closed[r, c] = 1 # pits book-keeping - if pits_pos == len(pits) and len(pits) > 1024: - # Queue is empty, lets clear it out - pits.clear() - pits_pos = 0 + if pits_head == pits_tail and pits_tail > 1024: + # Queue is empty, clear it out + pits_head = 0 + pits_tail = 0 - return dem \ No newline at end of file + return dem