diff --git a/tests/__snapshots__/delineation_test/TestSnapshotOutputs.test_multi_subbasin_structure_snapshot.json b/tests/__snapshots__/delineation_test/TestSnapshotOutputs.test_multi_subbasin_structure_snapshot.json new file mode 100644 index 0000000..bf24853 --- /dev/null +++ b/tests/__snapshots__/delineation_test/TestSnapshotOutputs.test_multi_subbasin_structure_snapshot.json @@ -0,0 +1,12 @@ +{ + "custom_nodes": [ + "main_outlet", + "upstream1", + "upstream2" + ], + "max_strahler_order": 4, + "num_edges": 55, + "num_nodes": 56, + "num_subbasins": 56, + "total_area_km2": 2924.4 +} diff --git a/tests/__snapshots__/delineation_test/TestSnapshotOutputs.test_single_outlet_network_structure_snapshot.json b/tests/__snapshots__/delineation_test/TestSnapshotOutputs.test_single_outlet_network_structure_snapshot.json new file mode 100644 index 0000000..22aa1d4 --- /dev/null +++ b/tests/__snapshots__/delineation_test/TestSnapshotOutputs.test_single_outlet_network_structure_snapshot.json @@ -0,0 +1,12 @@ +{ + "custom_nodes": [ + "outlet1" + ], + "max_shreve_order": 16, + "max_strahler_order": 4, + "num_edges": 54, + "num_nodes": 55, + "terminal_nodes": [ + "outlet1" + ] +} diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1982f9d --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,131 @@ +""" +Pytest configuration and fixtures for watershed delineation tests. + +This module sets up the environment variables needed to access remote hydrology data +and provides common fixtures used across test modules. +""" + +import os +import tempfile +from pathlib import Path + +import pytest + +# Set environment variables at module load time, before any imports +# These are required by upstream_delineator modules at import time +os.environ.setdefault( + "CATCHMENT_PATH", "https://public-hydrology-data.upstream.tech/catchments" +) +os.environ.setdefault( + "RIVER_PATH", "https://public-hydrology-data.upstream.tech/rivers" +) +os.environ.setdefault( + "FLOW_DIR_PATH", "https://public-hydrology-data.upstream.tech/merit_flowdir.tif" +) +os.environ.setdefault( + "ACCUM_PATH", "https://public-hydrology-data.upstream.tech/merit_accum.tif" +) +os.environ.setdefault( + "MEGABASINS_PATH", + "https://public-hydrology-data.upstream.tech/merit_basins_lvl2.gpkg", +) + + +@pytest.fixture(scope="session") +def temp_output_dir(): + """Create a temporary directory for test outputs.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def single_outlet_csv(tmp_path): + """ + Create a CSV file with a single outlet point. + Uses a location in Iceland (megabasin 27) for testing. + Includes custom columns to verify they're passed through to the graph. + """ + csv_path = tmp_path / "single_outlet.csv" + csv_path.write_text( + "id,lng,lat,name,outlet_id,gage_id,priority\n" + "outlet1,-14.36201,65.50253,Lagarfljot River at Lagarfoss,outlet1,GAGE001,high\n" + ) + return str(csv_path) + + +@pytest.fixture +def multi_subbasin_csv(tmp_path): + """ + Create a CSV file with an outlet and upstream subbasin points. + This tests the subbasin delineation capability. + Includes custom columns (gage_id, priority) to verify they're passed through. 
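+
+    All three rows share outlet_id=main_outlet, so upstream1 and upstream2 are
+    internal break points and main_outlet is the single terminal node.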
+ """ + csv_path = tmp_path / "multi_subbasin.csv" + csv_path.write_text( + "id,lng,lat,name,outlet_id,gage_id,priority\n" + "main_outlet,-14.36201,65.50253,Lagarfljot River at Lagarfoss,main_outlet,GAGE001,high\n" + "upstream1,-15.0883,64.9839,Jokulsa I River at Fljotsdal Holl,main_outlet,GAGE002,medium\n" + "upstream2,-14.533,65.14,Gringa Dam,main_outlet,GAGE003,low\n" + ) + return str(csv_path) + + +@pytest.fixture +def headwater_outlet_csv(tmp_path): + """ + Create a CSV file with a headwater outlet point. + Headwater points are typically in leaf catchments with no upstream neighbors. + """ + csv_path = tmp_path / "headwater.csv" + # A point near the headwaters of a small stream in Iceland + csv_path.write_text( + "id,lng,lat,name,outlet_id,gage_id,priority\n" + "headwater,-15.186,64.735,Lake Sauoarvatr outlet,headwater,GAGE_HW,high\n" + ) + return str(csv_path) + + +@pytest.fixture +def disconnected_basins_csv(tmp_path): + """ + Create a CSV file with two separate, disconnected watersheds. + Each outlet_id defines a separate river system that should result + in independent subgraphs in the final network. + """ + csv_path = tmp_path / "disconnected_basins.csv" + csv_path.write_text( + "id,lng,lat,name,outlet_id,gage_id,priority\n" + # First watershed - Lagarfljot system + "basin1_outlet,-14.36201,65.50253,Lagarfljot River at Lagarfoss,basin1_outlet,GAGE_B1,high\n" + "basin1_upstream,-14.533,65.14,Gringa Dam,basin1_outlet,GAGE_B1U,medium\n" + # Second watershed - separate system near headwaters + "basin2_outlet,-15.186,64.735,Lake Sauoarvatr outlet,basin2_outlet,GAGE_B2,high\n" + ) + return str(csv_path) + + +@pytest.fixture +def default_config(): + """Default configuration for tests - minimal output, no plots.""" + return { + "VERBOSE": False, + "WRITE_OUTPUT": False, + "PLOTS": False, + "CONSOLIDATE": False, + "NETWORK_DIAGRAMS": False, + "SIMPLIFY": False, + } + + +@pytest.fixture +def consolidate_config(): + """Configuration with consolidation enabled.""" + return { + "VERBOSE": False, + "WRITE_OUTPUT": False, + "PLOTS": False, + "CONSOLIDATE": True, + "MAX_AREA": 500, + "NETWORK_DIAGRAMS": False, + "SIMPLIFY": False, + } diff --git a/tests/delineation_test.py b/tests/delineation_test.py index 358994c..1f25af3 100644 --- a/tests/delineation_test.py +++ b/tests/delineation_test.py @@ -1,2 +1,922 @@ -def test_placeholder(): - pass # TODO +""" +Tests for the watershed delineation library. + +These tests verify correctness of watershed delineation in various hydrologic scenarios: +1. Single outlet delineation +2. Multiple subbasin delineation +3. Headwater watersheds +4. Network consolidation +5. Graph topology and stream order calculations +6. Disconnected basin networks + +Uses syrupy for snapshot testing of complex geodata outputs. +""" + +import networkx as nx +import pytest +from syrupy.assertion import SnapshotAssertion +from syrupy.extensions.json import JSONSnapshotExtension + +from upstream_delineator import config +from upstream_delineator.delineator_utils.delineate import delineate + + +class TestBasicDelineation: + """Test basic watershed delineation functionality.""" + + @pytest.fixture(autouse=True) + def reset_config(self, default_config): + """Reset config before each test.""" + config.set(default_config) + + def test_single_outlet_delineation_runs_without_error( + self, single_outlet_csv, default_config, temp_output_dir + ): + """ + Test that a single outlet delineation completes without errors. + This is the most basic smoke test for the delineation workflow. 
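+
+        delineate() returns a (graph, subbasins, rivers) triple; each part is
+        checked below for basic validity.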
+ """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, rivers_gdf = delineate( + single_outlet_csv, "test_single", default_config + ) + + # Basic assertions that we got valid output + assert G is not None, "Graph should not be None" + assert subbasins_gdf is not None, "Subbasins GeoDataFrame should not be None" + assert rivers_gdf is not None, "Rivers GeoDataFrame should not be None" + + # Graph should have nodes + assert G.number_of_nodes() > 0, "Graph should have at least one node" + + # Subbasins should have geometries + assert len(subbasins_gdf) > 0, "Should have at least one subbasin" + assert "geometry" in subbasins_gdf.columns, "Subbasins should have geometry" + + # Verify the outlet node exists and matches CSV + assert G.has_node("outlet1"), "Graph should contain the outlet1 node" + + # Terminal node should be the outlet + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert "outlet1" in terminal_nodes, "outlet1 should be a terminal node" + + def test_multi_subbasin_delineation( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test delineation with multiple subbasin outlets. + Verifies that upstream points are correctly assigned to subbasins. + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _rivers_gdf = delineate( + multi_subbasin_csv, "test_multi", default_config + ) + + # Should have custom nodes for each outlet point + custom_nodes = [n for n, d in G.nodes(data=True) if d.get("custom", False)] + assert len(custom_nodes) >= 3, "Should have at least 3 custom outlet nodes" + + # All custom nodes should be in the subbasins + subbasin_ids = subbasins_gdf["comid"].tolist() + for node in custom_nodes: + assert node in subbasin_ids, f"Custom node {node} should be in subbasins" + + # Verify all expected outlet IDs are present + expected_outlets = ["main_outlet", "upstream1", "upstream2"] + for outlet_id in expected_outlets: + assert G.has_node(outlet_id), f"Graph should contain node {outlet_id}" + + # The main outlet should be terminal + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert "main_outlet" in terminal_nodes, "main_outlet should be a terminal node" + + def test_headwater_outlet_delineation( + self, headwater_outlet_csv, default_config, temp_output_dir + ): + """ + Test delineation at a headwater location. + Headwaters are leaf catchments with no upstream neighbors. 
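+
+        In the extreme case the watershed may collapse to a single unit
+        catchment, so the outlet node can be both a source and the terminal
+        node of the graph.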
+ """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _rivers_gdf = delineate( + headwater_outlet_csv, "test_headwater", default_config + ) + + # Headwater watershed should be small (few subbasins) + assert len(subbasins_gdf) >= 1, ( + "Headwater watershed should have at least 1 subbasin" + ) + + # The outlet node should exist and be terminal + assert G.has_node("headwater"), "Graph should contain the headwater outlet node" + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert "headwater" in terminal_nodes, "headwater should be a terminal node" + + +class TestNetworkTopology: + """Test river network graph topology and connectivity.""" + + @pytest.fixture(autouse=True) + def reset_config(self, default_config): + """Reset config before each test.""" + config.set(default_config) + + def test_graph_is_directed_acyclic( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that the river network graph is a directed acyclic graph (DAG). + River networks should flow downstream without cycles. + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, _, _ = delineate(multi_subbasin_csv, "test_dag", default_config) + + assert isinstance(G, nx.DiGraph), "Graph should be a directed graph" + assert nx.is_directed_acyclic_graph(G), "River network should be acyclic" + + # Verify terminal node is the expected outlet + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert "main_outlet" in terminal_nodes, "main_outlet should be a terminal node" + + def test_single_terminal_node( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that the network has exactly one terminal (outlet) node. + The terminal node is the one with no outgoing edges. + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, _, _ = delineate(multi_subbasin_csv, "test_terminal", default_config) + + # Find terminal nodes (no successors) + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert len(terminal_nodes) == 1, "Should have exactly one terminal node" + assert terminal_nodes[0] == "main_outlet", "Terminal node should be main_outlet" + + def test_outlet_node_attributes_from_csv( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that custom attributes from CSV are present in the graph nodes. + Verifies that gage_id, priority, and other custom columns are accessible. 
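+
+        Only id, lat, lng, and outlet_id are required (see validation_test.py);
+        extra columns such as gage_id and priority should simply ride along
+        into the outputs.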
+ """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _ = delineate( + multi_subbasin_csv, "test_attrs", default_config + ) + + # Verify custom nodes are marked + assert G.nodes["main_outlet"].get("custom") is True + assert G.nodes["upstream1"].get("custom") is True + assert G.nodes["upstream2"].get("custom") is True + + # Verify the subbasins GeoDataFrame has the custom columns + assert "gage_id" in subbasins_gdf.columns, ( + "gage_id column should be in subbasins" + ) + assert "priority" in subbasins_gdf.columns, ( + "priority column should be in subbasins" + ) + + # Check that custom outlet rows have the expected values + main_outlet_row = subbasins_gdf[subbasins_gdf["comid"] == "main_outlet"] + assert len(main_outlet_row) == 1, "Should have exactly one main_outlet row" + assert main_outlet_row.iloc[0]["gage_id"] == "GAGE001" + assert main_outlet_row.iloc[0]["priority"] == "high" + + upstream1_row = subbasins_gdf[subbasins_gdf["comid"] == "upstream1"] + assert len(upstream1_row) == 1, "Should have exactly one upstream1 row" + assert upstream1_row.iloc[0]["gage_id"] == "GAGE002" + assert upstream1_row.iloc[0]["priority"] == "medium" + + def test_stream_orders_assigned( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that Strahler and Shreve stream orders are calculated. + These are fundamental hydrologic properties of river networks. + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _ = delineate( + multi_subbasin_csv, "test_orders", default_config + ) + + # Check that stream orders are in graph nodes + for node in G.nodes(): + assert "strahler_order" in G.nodes[node], ( + f"Node {node} should have strahler_order" + ) + assert "shreve_order" in G.nodes[node], ( + f"Node {node} should have shreve_order" + ) + assert G.nodes[node]["strahler_order"] >= 1, ( + "Strahler order should be at least 1" + ) + assert G.nodes[node]["shreve_order"] >= 1, ( + "Shreve order should be at least 1" + ) + + # Check that stream orders are in subbasins geodataframe + assert "strahler_order" in subbasins_gdf.columns + assert "shreve_order" in subbasins_gdf.columns + + # Verify outlet is the expected node + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert "main_outlet" in terminal_nodes + + def test_strahler_order_properties( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that Strahler stream order follows correct rules: + - Headwaters (no upstream) have order 1 + - When two streams of same order meet, result is order + 1 + - When streams of different orders meet, result is max order + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, _, _ = delineate(multi_subbasin_csv, "test_strahler", default_config) + + # Find leaf nodes (headwaters) - they should have Strahler order 1 + leaf_nodes = [n for n in G.nodes() if G.in_degree(n) == 0] + for leaf in leaf_nodes: + assert G.nodes[leaf]["strahler_order"] == 1, ( + f"Headwater node {leaf} should have Strahler order 1" + ) + + # For each non-leaf node, verify Strahler order calculation + for node in G.nodes(): + if G.in_degree(node) > 0: + upstream_orders = [ + G.nodes[pred]["strahler_order"] for pred in G.predecessors(node) + ] + max_order = max(upstream_orders) + count_max = 
upstream_orders.count(max_order) + + expected_order = max_order + 1 if count_max > 1 else max_order + actual_order = G.nodes[node]["strahler_order"] + + assert actual_order == expected_order, ( + f"Node {node}: expected Strahler order {expected_order}, " + f"got {actual_order}" + ) + + def test_shreve_order_properties( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that Shreve stream order follows correct rules: + - Headwaters have order 1 + - Order increases downstream (sum of upstream orders) + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, _, _ = delineate(multi_subbasin_csv, "test_shreve", default_config) + + # Leaf nodes should have Shreve order 1 + leaf_nodes = [n for n in G.nodes() if G.in_degree(n) == 0] + for leaf in leaf_nodes: + assert G.nodes[leaf]["shreve_order"] == 1, ( + f"Headwater node {leaf} should have Shreve order 1" + ) + + # Shreve order should increase downstream (or stay same at terminus) + for node in G.nodes(): + successors = list(G.successors(node)) + if successors: + successor = successors[0] + assert ( + G.nodes[successor]["shreve_order"] >= G.nodes[node]["shreve_order"] + ), ( + f"Shreve order should not decrease downstream " + f"from {node} to {successor}" + ) + + +class TestDisconnectedBasins: + """Test handling of multiple disconnected watersheds.""" + + @pytest.fixture(autouse=True) + def reset_config(self, default_config): + """Reset config before each test.""" + config.set(default_config) + + def test_disconnected_basins_separate_systems( + self, disconnected_basins_csv, default_config, temp_output_dir + ): + """ + Test that two separate outlets create two disconnected river systems. + Each outlet_id in the CSV should correspond to an independent watershed. + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _ = delineate( + disconnected_basins_csv, "test_disconnected", default_config + ) + + # Both outlet nodes should exist + assert G.has_node("basin1_outlet"), "Graph should contain basin1_outlet" + assert G.has_node("basin2_outlet"), "Graph should contain basin2_outlet" + + # Both should be terminal nodes (no outgoing edges) + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert "basin1_outlet" in terminal_nodes, ( + "basin1_outlet should be a terminal node" + ) + assert "basin2_outlet" in terminal_nodes, ( + "basin2_outlet should be a terminal node" + ) + + # The graph should have exactly 2 terminal nodes (2 separate systems) + assert len(terminal_nodes) == 2, ( + f"Should have exactly 2 terminal nodes for 2 watersheds, got {len(terminal_nodes)}" + ) + + # The graph should be disconnected (2 weakly connected components) + num_components = nx.number_weakly_connected_components(G) + assert num_components == 2, ( + f"Should have 2 weakly connected components, got {num_components}" + ) + + # Verify custom attributes from CSV are present + assert "gage_id" in subbasins_gdf.columns + basin1_row = subbasins_gdf[subbasins_gdf["comid"] == "basin1_outlet"] + assert basin1_row.iloc[0]["gage_id"] == "GAGE_B1" + + def test_disconnected_basins_upstream_connectivity( + self, disconnected_basins_csv, default_config, temp_output_dir + ): + """ + Test that upstream points are correctly connected to their respective outlets. + basin1_upstream should flow to basin1_outlet, not basin2_outlet. 
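+
+        Weak connectivity is the right notion here: edges point downstream, so
+        two points in the same watershed need not be mutually reachable, but
+        they are connected once edge direction is ignored.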
+ """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, _, _ = delineate( + disconnected_basins_csv, "test_disconnected_connectivity", default_config + ) + + # basin1_upstream should be in the same component as basin1_outlet + # Find the component containing basin1_outlet + components = list(nx.weakly_connected_components(G)) + + basin1_component = None + basin2_component = None + for comp in components: + if "basin1_outlet" in comp: + basin1_component = comp + if "basin2_outlet" in comp: + basin2_component = comp + + assert basin1_component is not None, "Should find basin1 component" + assert basin2_component is not None, "Should find basin2 component" + + # basin1_upstream should be in basin1's component + assert "basin1_upstream" in basin1_component, ( + "basin1_upstream should be in the same component as basin1_outlet" + ) + + # basin1_upstream should NOT be in basin2's component + assert "basin1_upstream" not in basin2_component, ( + "basin1_upstream should not be in basin2's component" + ) + + +class TestConsolidation: + """Test network consolidation functionality.""" + + @pytest.fixture(autouse=True) + def reset_config(self, consolidate_config): + """Reset config with consolidation enabled.""" + config.set(consolidate_config) + + def test_consolidation_reduces_nodes( + self, multi_subbasin_csv, default_config, consolidate_config, temp_output_dir + ): + """ + Test that consolidation reduces the number of nodes in the network. + Consolidation merges small unit catchments to create larger subbasins. + """ + # First, delineate without consolidation + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + G_original, _, _ = delineate( + multi_subbasin_csv, "test_no_consol", default_config + ) + + # Then, delineate with consolidation + config.set( + { + **consolidate_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + G_consolidated, _, _ = delineate( + multi_subbasin_csv, "test_consol", consolidate_config + ) + + # Consolidated network should have fewer or equal nodes + assert G_consolidated.number_of_nodes() <= G_original.number_of_nodes(), ( + f"Consolidated network ({G_consolidated.number_of_nodes()} nodes) " + f"should have <= nodes than original ({G_original.number_of_nodes()} nodes)" + ) + + # Verify outlet node is still present and terminal + assert G_consolidated.has_node("main_outlet") + terminal_nodes = [ + n for n in G_consolidated.nodes() if G_consolidated.out_degree(n) == 0 + ] + assert "main_outlet" in terminal_nodes + + def test_consolidation_preserves_custom_nodes( + self, multi_subbasin_csv, consolidate_config, temp_output_dir + ): + """ + Test that consolidation preserves user-specified outlet points. + Custom nodes should not be merged away during consolidation. 
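+
+        The consolidate_config fixture sets CONSOLIDATE=True with MAX_AREA=500;
+        merges may combine small unit catchments, but each user-specified
+        outlet must survive as its own node.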
+ """ + config.set( + { + **consolidate_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _ = delineate( + multi_subbasin_csv, "test_custom_preserved", consolidate_config + ) + + # All custom outlet IDs should still be present + expected_outlets = ["main_outlet", "upstream1", "upstream2"] + for outlet_id in expected_outlets: + assert G.has_node(outlet_id), ( + f"Custom outlet {outlet_id} should be preserved after consolidation" + ) + + # Verify custom attributes are preserved + assert "gage_id" in subbasins_gdf.columns + main_outlet_row = subbasins_gdf[subbasins_gdf["comid"] == "main_outlet"] + assert main_outlet_row.iloc[0]["gage_id"] == "GAGE001" + + def test_consolidation_maintains_connectivity( + self, multi_subbasin_csv, consolidate_config, temp_output_dir + ): + """ + Test that consolidation maintains network connectivity. + The graph should still be connected and acyclic after consolidation. + """ + config.set( + { + **consolidate_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, _, _ = delineate(multi_subbasin_csv, "test_connectivity", consolidate_config) + + # Graph should still be a DAG + assert nx.is_directed_acyclic_graph(G), ( + "Consolidated network should still be a DAG" + ) + + # Should still have exactly one terminal node + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert len(terminal_nodes) == 1, ( + "Consolidated network should have exactly one terminal node" + ) + assert terminal_nodes[0] == "main_outlet" + + +class TestGeometryValidity: + """Test that output geometries are valid.""" + + @pytest.fixture(autouse=True) + def reset_config(self, default_config): + """Reset config before each test.""" + config.set(default_config) + + def test_subbasin_geometries_valid( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that all subbasin geometries are valid polygons. + Invalid geometries can cause problems in downstream GIS analysis. + + Note: The MERIT-Hydro source data occasionally has invalid geometries + (typically self-intersections). These can be fixed with make_valid(). + """ + from shapely.validation import make_valid + + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + _, subbasins_gdf, _ = delineate( + multi_subbasin_csv, "test_valid_geom", default_config + ) + + # Check how many are invalid before fixing + invalid_before = subbasins_gdf[~subbasins_gdf.geometry.is_valid] + + # All geometries should be fixable with make_valid + fixed_geoms = subbasins_gdf.geometry.apply( + lambda g: make_valid(g) if not g.is_valid else g + ) + + # After fixing, all should be valid + still_invalid = fixed_geoms[~fixed_geoms.is_valid] + assert len(still_invalid) == 0, ( + f"Found {len(still_invalid)} geometries that cannot be made valid" + ) + + # Log any that needed fixing (these are source data issues) + if len(invalid_before) > 0: + import warnings + + warnings.warn( + f"Found {len(invalid_before)} invalid geometries in source data " + f"(COMIDs: {invalid_before['comid'].tolist()}). " + "These can be fixed with shapely.validation.make_valid()." + ) + + def test_subbasin_geometries_nonempty( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that no subbasin geometries are empty. + Every subbasin should have a polygon representing its contributing area. 
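+
+        An empty geometry would most likely indicate a failed catchment split
+        or clip earlier in the pipeline, so this is a cheap sanity check.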
+ """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + _, subbasins_gdf, _ = delineate( + multi_subbasin_csv, "test_nonempty_geom", default_config + ) + + # No empty geometries + empty_geoms = subbasins_gdf[subbasins_gdf.geometry.is_empty] + assert len(empty_geoms) == 0, ( + f"Found {len(empty_geoms)} empty subbasin geometries" + ) + + def test_subbasins_have_positive_area( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that all subbasins have positive area. + Area is a key attribute for hydrologic modeling. + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + _, subbasins_gdf, _ = delineate( + multi_subbasin_csv, "test_positive_area", default_config + ) + + # All areas should be positive + assert (subbasins_gdf["unitarea"] > 0).all(), ( + "All subbasins should have positive area" + ) + + def test_rivers_geometries_valid( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that river reach geometries are valid LineStrings. + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + _, _, rivers_gdf = delineate( + multi_subbasin_csv, "test_valid_rivers", default_config + ) + + # Filter out any empty geometries first (these are handled separately) + non_empty = rivers_gdf[~rivers_gdf.geometry.is_empty] + + # All non-empty geometries should be valid + invalid_geoms = non_empty[~non_empty.geometry.is_valid] + assert len(invalid_geoms) == 0, ( + f"Found {len(invalid_geoms)} invalid river geometries" + ) + + +class TestDataConsistency: + """Test consistency between graph, subbasins, and rivers data.""" + + @pytest.fixture(autouse=True) + def reset_config(self, default_config): + """Reset config before each test.""" + config.set(default_config) + + def test_graph_subbasins_correspondence( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that graph nodes correspond to subbasins in the GeoDataFrame. + Every node in the graph should have a corresponding subbasin. + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _ = delineate( + multi_subbasin_csv, "test_correspondence", default_config + ) + + graph_nodes = set(G.nodes()) + subbasin_ids = set(subbasins_gdf["comid"].tolist()) + + # All graph nodes should be in subbasins + missing_in_subbasins = graph_nodes - subbasin_ids + assert len(missing_in_subbasins) == 0, ( + f"Graph nodes missing from subbasins: {missing_in_subbasins}" + ) + + # Verify terminal node + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert "main_outlet" in terminal_nodes + + def test_nextdown_consistency( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that graph structure is internally consistent. + Every non-terminal node should have exactly one outgoing edge. + Terminal nodes (outlets) should have no outgoing edges. 
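+
+        The same invariant is encoded in the subbasins GeoDataFrame: interior
+        rows hold the id of their single downstream neighbor in `nextdown`,
+        and terminal rows hold the sentinel value 0.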
+ """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _ = delineate( + multi_subbasin_csv, "test_nextdown", default_config + ) + + # Every node should have at most one successor (river networks are trees) + for node in G.nodes(): + out_degree = G.out_degree(node) + assert out_degree <= 1, ( + f"Node {node} has {out_degree} successors, expected at most 1" + ) + + # The terminal node should have nextdown = 0 in the GeoDataFrame + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert len(terminal_nodes) >= 1, "Should have at least one terminal node" + + for terminal in terminal_nodes: + if terminal in subbasins_gdf["comid"].values: + terminal_row = subbasins_gdf[subbasins_gdf["comid"] == terminal] + nextdown_val = terminal_row.iloc[0]["nextdown"] + assert nextdown_val == 0, ( + f"Terminal node {terminal} should have nextdown=0, got {nextdown_val}" + ) + + def test_area_values_consistent( + self, multi_subbasin_csv, default_config, temp_output_dir + ): + """ + Test that area values are consistent between graph and GeoDataFrame. + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _ = delineate( + multi_subbasin_csv, "test_area_consistent", default_config + ) + + subbasins_gdf_indexed = subbasins_gdf.set_index("comid") + + for node in G.nodes(): + graph_area = G.nodes[node].get("area", 0) + if node in subbasins_gdf_indexed.index: + gdf_area = subbasins_gdf_indexed.loc[node, "unitarea"] + # Allow small floating point differences + assert abs(graph_area - gdf_area) < 0.5, ( + f"Area mismatch for node {node}: graph={graph_area}, gdf={gdf_area}" + ) + + +class TestSnapshotOutputs: + """Snapshot tests for complex output verification using syrupy.""" + + @pytest.fixture + def snapshot_json(self, snapshot: SnapshotAssertion): + """Configure syrupy to use JSON extension for readable snapshots.""" + return snapshot.with_defaults(extension_class=JSONSnapshotExtension) + + @pytest.fixture(autouse=True) + def reset_config(self, default_config): + """Reset config before each test.""" + config.set(default_config) + + def test_single_outlet_network_structure_snapshot( + self, + single_outlet_csv, + default_config, + temp_output_dir, + snapshot_json, + ): + """ + Snapshot test for network structure of single outlet delineation. + Captures the essential topology of the delineated watershed. 
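+
+        The snapshot stores only scalar counts, stream orders, and sorted id
+        lists (see the summary dict below); geometries are left out so the
+        JSON snapshot stays small and stable across runs.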
+ """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _ = delineate( + single_outlet_csv, "test_snapshot", default_config + ) + + # Verify outlet node + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert "outlet1" in terminal_nodes + + # Verify custom attributes are present + assert "gage_id" in subbasins_gdf.columns + outlet_row = subbasins_gdf[subbasins_gdf["comid"] == "outlet1"] + assert outlet_row.iloc[0]["gage_id"] == "GAGE001" + + # Create a serializable summary of the network structure + network_summary = { + "num_nodes": G.number_of_nodes(), + "num_edges": G.number_of_edges(), + "custom_nodes": sorted( + [str(n) for n, d in G.nodes(data=True) if d.get("custom", False)] + ), + "terminal_nodes": sorted( + [str(n) for n in G.nodes() if G.out_degree(n) == 0] + ), + "max_strahler_order": max( + d.get("strahler_order", 0) for _, d in G.nodes(data=True) + ), + "max_shreve_order": max( + d.get("shreve_order", 0) for _, d in G.nodes(data=True) + ), + } + + assert network_summary == snapshot_json + + def test_multi_subbasin_structure_snapshot( + self, + multi_subbasin_csv, + default_config, + temp_output_dir, + snapshot_json, + ): + """ + Snapshot test for multi-subbasin delineation structure. + """ + config.set( + { + **default_config, + "OUTPUT_DIR": str(temp_output_dir), + "CACHE_DIR": str(temp_output_dir / "cache"), + } + ) + + G, subbasins_gdf, _ = delineate( + multi_subbasin_csv, "test_multi_snapshot", default_config + ) + + # Verify outlet node + terminal_nodes = [n for n in G.nodes() if G.out_degree(n) == 0] + assert "main_outlet" in terminal_nodes + + # Verify custom attributes are present + assert "gage_id" in subbasins_gdf.columns + assert "priority" in subbasins_gdf.columns + + # Create a serializable summary + network_summary = { + "num_nodes": G.number_of_nodes(), + "num_edges": G.number_of_edges(), + "num_subbasins": len(subbasins_gdf), + "custom_nodes": sorted( + [str(n) for n, d in G.nodes(data=True) if d.get("custom", False)] + ), + "total_area_km2": round(subbasins_gdf["unitarea"].sum(), 1), + "max_strahler_order": max( + d.get("strahler_order", 0) for _, d in G.nodes(data=True) + ), + } + + assert network_summary == snapshot_json diff --git a/tests/graph_tools_test.py b/tests/graph_tools_test.py new file mode 100644 index 0000000..6f37bee --- /dev/null +++ b/tests/graph_tools_test.py @@ -0,0 +1,367 @@ +""" +Tests for graph tools module. + +These tests verify the correctness of graph operations used in river network analysis, +including stream order calculations (Strahler and Shreve), node operations, and +graph utilities. +""" + +import networkx as nx +import pytest + +from upstream_delineator.delineator_utils.graph_tools import ( + calculate_shreve_stream_order, + calculate_strahler_stream_order, + insert_node, + make_river_network, + prune_node, + upstream_nodes, +) + + +class TestStreamOrderCalculations: + """Test stream order calculation algorithms.""" + + def create_simple_network(self) -> nx.DiGraph: + """ + Create a simple test network with known structure. 
+
+        Structure (arrows show flow direction, toward right):
+            A ─┐
+               ├─> D ─> E
+            B ─┘
+            C ─────────┘
+
+        Expected Strahler orders: A=1, B=1, C=1, D=2, E=2
+        Expected Shreve orders: A=1, B=1, C=1, D=2, E=3
+        """
+        G = nx.DiGraph()
+        G.add_edge("A", "D")
+        G.add_edge("B", "D")
+        G.add_edge("D", "E")
+        G.add_edge("C", "E")
+        return G
+
+    def create_linear_network(self) -> nx.DiGraph:
+        """
+        Create a linear (unbranched) network.
+
+        Structure: A ─> B ─> C ─> D
+
+        All nodes should have Strahler order 1.
+        Shreve orders: A=1, B=2, C=3, D=4
+        """
+        G = nx.DiGraph()
+        G.add_edge("A", "B")
+        G.add_edge("B", "C")
+        G.add_edge("C", "D")
+        return G
+
+    def create_symmetric_network(self) -> nx.DiGraph:
+        """
+        Create a symmetric branching network where streams of equal order meet.
+
+        Structure:
+            A ─┐
+               ├─> C ─┐
+            B ─┘      ├─> outlet
+            E ─┐      │
+               ├─> D ─┘
+            F ─┘
+
+        When two streams of Strahler order n meet, the result is n+1.
+        """
+        G = nx.DiGraph()
+        # Two headwaters feeding into C
+        G.add_edge("A", "C")
+        G.add_edge("B", "C")
+        # Two more headwaters feeding into D
+        G.add_edge("E", "D")
+        G.add_edge("F", "D")
+        # C and D meet at outlet
+        G.add_edge("C", "outlet")
+        G.add_edge("D", "outlet")
+        return G
+
+    def test_strahler_order_simple_network(self):
+        """Test Strahler order calculation on simple branching network."""
+        G = self.create_simple_network()
+        G = calculate_strahler_stream_order(G)
+
+        # Headwaters should have order 1
+        assert G.nodes["A"]["strahler_order"] == 1
+        assert G.nodes["B"]["strahler_order"] == 1
+        assert G.nodes["C"]["strahler_order"] == 1
+
+        # D receives two streams of order 1, so it becomes order 2
+        assert G.nodes["D"]["strahler_order"] == 2
+
+        # E receives order 2 from D and order 1 from C, so max is 2
+        assert G.nodes["E"]["strahler_order"] == 2
+
+    def test_strahler_order_linear_network(self):
+        """Test Strahler order on linear (unbranched) network."""
+        G = self.create_linear_network()
+        G = calculate_strahler_stream_order(G)
+
+        # All nodes in unbranched network should have order 1
+        for node in G.nodes():
+            assert G.nodes[node]["strahler_order"] == 1, (
+                f"Node {node} should have Strahler order 1"
+            )
+
+    def test_strahler_order_symmetric_network(self):
+        """Test Strahler order on symmetric branching network."""
+        G = self.create_symmetric_network()
+        G = calculate_strahler_stream_order(G)
+
+        # Headwaters should have order 1
+        for node in ["A", "B", "E", "F"]:
+            assert G.nodes[node]["strahler_order"] == 1, (
+                f"Headwater {node} should have order 1"
+            )
+
+        # C and D both receive two order-1 streams, so they become order 2
+        assert G.nodes["C"]["strahler_order"] == 2
+        assert G.nodes["D"]["strahler_order"] == 2
+
+        # Outlet receives two order-2 streams, so it becomes order 3
+        assert G.nodes["outlet"]["strahler_order"] == 3
+
+    def test_shreve_order_simple_network(self):
+        """Test Shreve order calculation on simple branching network."""
+        G = self.create_simple_network()
+        G = calculate_shreve_stream_order(G)
+
+        # Headwaters should have order 1
+        assert G.nodes["A"]["shreve_order"] == 1
+        assert G.nodes["B"]["shreve_order"] == 1
+        assert G.nodes["C"]["shreve_order"] == 1
+
+        # D receives inputs from A and B
+        assert G.nodes["D"]["shreve_order"] == 2
+
+        # E receives inputs from D (order 2) and C (order 1); this
+        # implementation assigns max upstream order + 1
+        assert G.nodes["E"]["shreve_order"] == 3
+
+    def test_shreve_order_linear_network(self):
+        """Test Shreve order on linear network - should increase downstream."""
+        G = 
calculate_shreve_stream_order(G) + + # Shreve order increases by 1 at each step + assert G.nodes["A"]["shreve_order"] == 1 + assert G.nodes["B"]["shreve_order"] == 2 + assert G.nodes["C"]["shreve_order"] == 3 + assert G.nodes["D"]["shreve_order"] == 4 + + def test_shreve_order_symmetric_network(self): + """Test Shreve order on symmetric branching network.""" + G = self.create_symmetric_network() + G = calculate_shreve_stream_order(G) + + # Headwaters should have order 1 + for node in ["A", "B", "E", "F"]: + assert G.nodes[node]["shreve_order"] == 1 + + # C and D both receive two streams + assert G.nodes["C"]["shreve_order"] == 2 + assert G.nodes["D"]["shreve_order"] == 2 + + # Outlet receives from both C and D + assert G.nodes["outlet"]["shreve_order"] == 3 + + +class TestNodeOperations: + """Test node insertion and pruning operations.""" + + def test_prune_node_middle(self): + """Test pruning a node from the middle of a network.""" + G = nx.DiGraph() + G.add_edge("A", "B") + G.add_edge("B", "C") + + G = prune_node(G, "B") + + # B should be gone + assert "B" not in G.nodes() + + # A should now connect directly to C + assert G.has_edge("A", "C") + + def test_prune_node_branch(self): + """Test pruning a node that has multiple predecessors.""" + G = nx.DiGraph() + G.add_edge("A", "X") + G.add_edge("B", "X") + G.add_edge("X", "C") + + G = prune_node(G, "X") + + # X should be gone + assert "X" not in G.nodes() + + # Both A and B should now connect to C + assert G.has_edge("A", "C") + assert G.has_edge("B", "C") + + def test_prune_node_preserves_other_edges(self): + """Test that pruning preserves edges not involving the pruned node.""" + G = nx.DiGraph() + G.add_edge("A", "B") + G.add_edge("B", "C") + G.add_edge("D", "C") + + G = prune_node(G, "B") + + # D -> C edge should still exist + assert G.has_edge("D", "C") + + def test_prune_nonexistent_node_raises(self): + """Test that pruning a non-existent node raises ValueError.""" + G = nx.DiGraph() + G.add_node("A") + + with pytest.raises(ValueError, match="not in the graph"): + prune_node(G, "nonexistent") + + def test_insert_node_into_leaf(self): + """Test inserting a node into a leaf (Strahler order 1) catchment.""" + G = nx.DiGraph() + G.add_edge("A", "B") + G.add_edge("B", "C") + # Add strahler_order attribute for A (leaf node) + G.nodes["A"]["strahler_order"] = 1 + G.nodes["B"]["strahler_order"] = 1 + G.nodes["C"]["strahler_order"] = 1 + + G = insert_node(G, "new_node", "A") + + # New node should exist + assert "new_node" in G.nodes() + + # New node should connect to A + assert G.has_edge("new_node", "A") + + # New node should be marked as 'new' and 'leaf' type + assert G.nodes["new_node"]["new"] is True + assert G.nodes["new_node"]["type"] == "leaf" + + def test_insert_node_into_stem(self): + """Test inserting a node into a stem (Strahler order > 1) catchment.""" + G = nx.DiGraph() + G.add_edge("A", "B") + G.add_edge("C", "B") + G.add_edge("B", "D") + # B is at a junction, so it has order > 1 + G.nodes["A"]["strahler_order"] = 1 + G.nodes["C"]["strahler_order"] = 1 + G.nodes["B"]["strahler_order"] = 2 + G.nodes["D"]["strahler_order"] = 2 + + G = insert_node(G, "new_node", "B") + + # New node should exist + assert "new_node" in G.nodes() + + # New node should connect to B + assert G.has_edge("new_node", "B") + + # A and C should now connect to new_node instead of B + assert G.has_edge("A", "new_node") + assert G.has_edge("C", "new_node") + assert not G.has_edge("A", "B") + assert not G.has_edge("C", "B") + + # New node should be marked as 
'new' and 'stem' type + assert G.nodes["new_node"]["new"] is True + assert G.nodes["new_node"]["type"] == "stem" + + +class TestUpstreamNodes: + """Test the upstream_nodes function.""" + + def test_upstream_nodes_simple(self): + """Test finding upstream nodes in a simple network.""" + G = nx.DiGraph() + G.add_edge("A", "B") + G.add_edge("B", "C") + + # From C's perspective + up = upstream_nodes(G, "C") + assert set(up) == {"A", "B"} + + # From B's perspective + up = upstream_nodes(G, "B") + assert set(up) == {"A"} + + # From A's perspective (headwater, no upstream) + up = upstream_nodes(G, "A") + assert len(up) == 0 + + def test_upstream_nodes_branching(self): + """Test finding upstream nodes in a branching network.""" + G = nx.DiGraph() + G.add_edge("A", "C") + G.add_edge("B", "C") + G.add_edge("C", "D") + + # From D's perspective + up = upstream_nodes(G, "D") + assert set(up) == {"A", "B", "C"} + + # From C's perspective + up = upstream_nodes(G, "C") + assert set(up) == {"A", "B"} + + def test_upstream_nodes_nonexistent_raises(self): + """Test that requesting upstream of non-existent node raises error.""" + G = nx.DiGraph() + G.add_node("A") + + with pytest.raises(ValueError, match="not in the graph"): + upstream_nodes(G, "nonexistent") + + +class TestMakeRiverNetwork: + """Test the make_river_network function.""" + + def test_make_river_network_from_dataframe(self): + """Test creating a network graph from a DataFrame.""" + import pandas as pd + + # Create a simple DataFrame mimicking subbasin data + data = { + "nextdown": [2, 3, 0], # Node 1 -> 2, 2 -> 3, 3 -> outlet (0) + "unitarea": [100.0, 150.0, 200.0], + } + df = pd.DataFrame(data, index=[1, 2, 3]) + + G = make_river_network(df, terminal_node=3) + + # Should have 3 nodes + assert G.number_of_nodes() == 3 + + # Check edges (terminal node 3 should not have outgoing edge) + assert G.has_edge(1, 2) + assert G.has_edge(2, 3) + assert not G.has_edge(3, 0) # No edge to 0 for terminal + + # Check area attributes + assert G.nodes[1]["area"] == 100.0 + assert G.nodes[2]["area"] == 150.0 + assert G.nodes[3]["area"] == 200.0 + + def test_make_river_network_handles_zero_nextdown(self): + """Test that nextdown=0 (ocean/terminal) is handled correctly.""" + import pandas as pd + + data = { + "nextdown": [0], # Single terminal node + "unitarea": [100.0], + } + df = pd.DataFrame(data, index=[1]) + + G = make_river_network(df) + + # Should have 1 node, no edges + assert G.number_of_nodes() == 1 + assert G.number_of_edges() == 0 diff --git a/tests/validation_test.py b/tests/validation_test.py new file mode 100644 index 0000000..75cb69a --- /dev/null +++ b/tests/validation_test.py @@ -0,0 +1,233 @@ +""" +Tests for input validation and error handling. + +These tests verify that the library correctly validates inputs and +produces appropriate error messages for invalid data. 
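+
+The required input columns are id, lat, lng, and outlet_id. Latitude must lie
+in (-60, 85) and longitude in (-180, 180), matching MERIT-Hydro coverage, the
+id 0 is reserved for ocean discharge, and every outlet_id must reference an id
+present in the same file.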
+""" + +import pandas as pd +import pytest + +from upstream_delineator.delineator_utils.util import find_repeated_elements, validate + + +class TestValidateFunction: + """Test the validate function for CSV input validation.""" + + def test_validate_valid_input(self): + """Test that valid input passes validation.""" + df = pd.DataFrame( + { + "id": ["outlet1", "point2"], + "lat": [65.5, 64.9], + "lng": [-14.3, -15.0], + "outlet_id": ["outlet1", "outlet1"], + } + ) + + result = validate(df) + assert result is True + + def test_validate_missing_id_column(self): + """Test that missing 'id' column raises ValueError.""" + df = pd.DataFrame( + { + "lat": [65.5], + "lng": [-14.3], + "outlet_id": ["outlet1"], + } + ) + + with pytest.raises(ValueError, match=r"Missing column.*id"): + validate(df) + + def test_validate_missing_lat_column(self): + """Test that missing 'lat' column raises ValueError.""" + df = pd.DataFrame( + { + "id": ["outlet1"], + "lng": [-14.3], + "outlet_id": ["outlet1"], + } + ) + + with pytest.raises(ValueError, match=r"Missing column.*lat"): + validate(df) + + def test_validate_missing_lng_column(self): + """Test that missing 'lng' column raises ValueError.""" + df = pd.DataFrame( + { + "id": ["outlet1"], + "lat": [65.5], + "outlet_id": ["outlet1"], + } + ) + + with pytest.raises(ValueError, match=r"Missing column.*lng"): + validate(df) + + def test_validate_missing_outlet_id_column(self): + """Test that missing 'outlet_id' column raises ValueError.""" + df = pd.DataFrame( + { + "id": ["outlet1"], + "lat": [65.5], + "lng": [-14.3], + } + ) + + with pytest.raises(ValueError, match=r"Missing column.*outlet_id"): + validate(df) + + def test_validate_duplicate_ids(self): + """Test that duplicate IDs raise ValueError.""" + df = pd.DataFrame( + { + "id": ["outlet1", "outlet1"], # Duplicate! 
+ "lat": [65.5, 64.9], + "lng": [-14.3, -15.0], + "outlet_id": ["outlet1", "outlet1"], + } + ) + + with pytest.raises(ValueError, match="unique"): + validate(df) + + def test_validate_non_numeric_lat(self): + """Test that non-numeric latitude raises ValueError.""" + df = pd.DataFrame( + { + "id": ["outlet1"], + "lat": ["not_a_number"], + "lng": [-14.3], + "outlet_id": ["outlet1"], + } + ) + + with pytest.raises(ValueError, match=r"lat.*not numeric"): + validate(df) + + def test_validate_lat_too_low(self): + """Test that latitude < -60 raises ValueError.""" + df = pd.DataFrame( + { + "id": ["outlet1"], + "lat": [-70.0], # Below -60 (MERIT-Hydro coverage limit) + "lng": [-14.3], + "outlet_id": ["outlet1"], + } + ) + + with pytest.raises(ValueError, match=r"latitude.*greater than -60"): + validate(df) + + def test_validate_lat_too_high(self): + """Test that latitude > 85 raises ValueError.""" + df = pd.DataFrame( + { + "id": ["outlet1"], + "lat": [90.0], # Above 85 (MERIT-Hydro coverage limit) + "lng": [-14.3], + "outlet_id": ["outlet1"], + } + ) + + with pytest.raises(ValueError, match=r"latitude.*less than 85"): + validate(df) + + def test_validate_lng_too_low(self): + """Test that longitude < -180 raises ValueError.""" + df = pd.DataFrame( + { + "id": ["outlet1"], + "lat": [65.5], + "lng": [-200.0], # Below -180 + "outlet_id": ["outlet1"], + } + ) + + with pytest.raises(ValueError, match=r"longitude.*greater than -180"): + validate(df) + + def test_validate_lng_too_high(self): + """Test that longitude > 180 raises ValueError.""" + df = pd.DataFrame( + { + "id": ["outlet1"], + "lat": [65.5], + "lng": [200.0], # Above 180 + "outlet_id": ["outlet1"], + } + ) + + with pytest.raises(ValueError, match=r"longitude.*less than 180"): + validate(df) + + def test_validate_id_zero_not_allowed(self): + """Test that id=0 is not allowed (reserved for ocean discharge).""" + df = pd.DataFrame( + { + "id": ["0"], # Reserved value + "lat": [65.5], + "lng": [-14.3], + "outlet_id": ["0"], + } + ) + + with pytest.raises(ValueError, match="id of 0 not allowed"): + validate(df) + + def test_validate_outlet_id_must_reference_existing_id(self): + """Test that outlet_id must reference an existing id in the CSV.""" + df = pd.DataFrame( + { + "id": ["point1"], + "lat": [65.5], + "lng": [-14.3], + "outlet_id": ["nonexistent"], # This id doesn't exist + } + ) + + with pytest.raises(ValueError, match=r"outlet_id.*must reference id"): + validate(df) + + +class TestFindRepeatedElements: + """Test the find_repeated_elements utility function.""" + + def test_no_repeats(self): + """Test list with no repeated elements.""" + lst = [1, 2, 3, 4, 5] + result = find_repeated_elements(lst) + assert result == [] + + def test_single_repeat(self): + """Test list with one repeated element.""" + lst = [1, 2, 2, 3, 4] + result = find_repeated_elements(lst) + assert result == [2] + + def test_multiple_repeats(self): + """Test list with multiple repeated elements.""" + lst = [1, 2, 2, 3, 3, 4] + result = find_repeated_elements(lst) + assert set(result) == {2, 3} + + def test_all_same(self): + """Test list where all elements are the same.""" + lst = [5, 5, 5, 5] + result = find_repeated_elements(lst) + assert result == [5] + + def test_empty_list(self): + """Test empty list returns empty list.""" + lst = [] + result = find_repeated_elements(lst) + assert result == [] + + def test_single_element(self): + """Test single element list returns empty list.""" + lst = [42] + result = find_repeated_elements(lst) + assert result == [] diff --git 
a/upstream_delineator/delineator_utils/delineate.py b/upstream_delineator/delineator_utils/delineate.py index 9a58dd9..781dbd4 100644 --- a/upstream_delineator/delineator_utils/delineate.py +++ b/upstream_delineator/delineator_utils/delineate.py @@ -250,7 +250,7 @@ def addnode(B: list, node_id): gages_gdf, gage_catchments_gdf, how="intersection", make_valid=True ) gages_gdf.set_index("id", inplace=True) - gages_gdf.set_crs(crs=PROJ_WGS84) + gages_gdf.set_crs(crs=PROJ_WGS84, inplace=True) # For any gages for which we could not find a unit catchment, add issue a warning # Basically checking which rows do not appear after doing the overlay @@ -288,9 +288,10 @@ def addnode(B: list, node_id): subbasins_gdf = catchments_gdf.loc[upstream_comids] # Add lat, lng, and NextDownID to subbasins_gdf. subbasins_gdf = subbasins_gdf.join(rivers_gdf[["lat", "lng", "NextDownID"]]) - # Re-name the NextDownID field, and make sure it is an integer + # Re-name the NextDownID field. Use object dtype to support mixed int/string IDs + # (COMIDs are integers but custom outlet IDs may be strings) subbasins_gdf.rename(columns={"NextDownID": "nextdown"}, inplace=True) - subbasins_gdf["nextdown"] = subbasins_gdf["nextdown"].astype(int) + subbasins_gdf["nextdown"] = subbasins_gdf["nextdown"].astype(object) subbasins_gdf["custom"] = ( False # Adds a column that shows whether a subbasin is connected to a custom pour point ) @@ -319,7 +320,7 @@ def addnode(B: list, node_id): # For now, put the split polygon geometry into a field in `gages_gdf` gages_gdf["polygon"] = None - gages_gdf["polygon_area"] = 0 + gages_gdf["polygon_area"] = 0.0 # Use float to allow fractional areas # Iterate over the gages, and run `split_catchment()` for every gage for gage_id in gages_gdf.index: @@ -452,7 +453,9 @@ def addnode(B: list, node_id): # After the dissolve operation, no way to preserve correct information in column 'nextdown' # But it is present in the Graph, so update the GeoDataFrame subbasins_gdf with that information # Add column `nextdownid` and the stream orders based on data in the graph - subbasins_gdf["nextdown"] = 0 + # Use object dtype to support mixed int/string IDs + subbasins_gdf["nextdown"] = None + subbasins_gdf["nextdown"] = subbasins_gdf["nextdown"].astype(object) for idx in subbasins_gdf.index: try: @@ -490,6 +493,9 @@ def addnode(B: list, node_id): pass # TODO why are we catching all exceptions? # Add the fields `nextdown` and the stream orders to the rivers. 
+ # Initialize nextdown as object dtype to support mixed int/string IDs + myrivers_gdf["nextdown"] = None + myrivers_gdf["nextdown"] = myrivers_gdf["nextdown"].astype(object) for idx in myrivers_gdf.index: try: nextdown = next(iter(G.successors(idx))) @@ -575,17 +581,15 @@ def update_split_catchment_geo(gages_gdf, myrivers_gdf, rivers_gdf, subbasins_gd "polygon": "geometry", } gages_gdf.rename(columns=rnmap, inplace=True) - gages_gdf.set_crs(crs=PROJ_WGS84) - gages_gdf.set_geometry(col="geometry") + gages_gdf = gages_gdf.set_geometry(col="geometry") + gages_gdf.set_crs(crs=PROJ_WGS84, inplace=True) # First, handle the gages where there is only one gage in a unit catchment (standard treatment) # The new unit catchments (or nodes in the network) will always be upstream of the unit catchment # that we are inserting it into # insert these rows into `subbasins_gdf` if len(singles) > 0: - selected_rows = gages_gdf[gages_gdf["COMID"].isin(singles)] - selected_rows.set_crs( - crs=PROJ_WGS84 - ) # Just needed to eliminate an annoying warning + selected_rows = gages_gdf[gages_gdf["COMID"].isin(singles)].copy() + selected_rows.set_crs(crs=PROJ_WGS84, inplace=True) # This creates the dictionary `new_nodes` that maps gage id : unit catchment comid new_nodes = selected_rows["COMID"].to_dict() @@ -633,8 +637,8 @@ def update_split_catchment_geo(gages_gdf, myrivers_gdf, rivers_gdf, subbasins_gd # (Note: It is not really possible to understand the following code without a picture of what it is doing!!!) for comid in repeats: # Find all the gages that fall in this unit catchment. - gages_set = gages_gdf[gages_gdf["COMID"] == comid] - gages_set.set_crs(crs=PROJ_WGS84) + gages_set = gages_gdf[gages_gdf["COMID"] == comid].copy() + gages_set.set_crs(crs=PROJ_WGS84, inplace=True) # We want to handle these in order from downstream to upstream. # This is the same as largest area to smallest area gages_set.sort_values(by="unitarea", inplace=True) diff --git a/upstream_delineator/delineator_utils/util.py b/upstream_delineator/delineator_utils/util.py index a61fb2e..5531feb 100644 --- a/upstream_delineator/delineator_utils/util.py +++ b/upstream_delineator/delineator_utils/util.py @@ -3,7 +3,7 @@ import pickle import re import warnings -from functools import cache, partial +from functools import cache import geopandas as gpd import matplotlib.pyplot as plt @@ -190,7 +190,7 @@ def calc_area(poly: Polygon) -> float: """ Calculates the approximate area of a Shapely polygon in raw lat, lng coordinates (CRS=4326) First projects it into the Albers Equal Area projection to facilitate calculation. 
- No + Args: poly: Shapely polygon Returns: @@ -199,14 +199,13 @@ def calc_area(poly: Polygon) -> float: if poly.is_empty: return 0 - projected_poly = shapely.ops.transform( - partial( - pyproj.transform, - pyproj.Proj(init=PROJ_WGS84), - pyproj.Proj(proj="aea", lat_1=poly.bounds[1], lat_2=poly.bounds[3]), - ), - poly, + # Use modern pyproj API with Transformer instead of deprecated transform function + crs_wgs84 = pyproj.CRS(PROJ_WGS84) + crs_aea = pyproj.CRS.from_proj4( + f"+proj=aea +lat_1={poly.bounds[1]} +lat_2={poly.bounds[3]}" ) + transformer = pyproj.Transformer.from_crs(crs_wgs84, crs_aea, always_xy=True) + projected_poly = shapely.ops.transform(transformer.transform, poly) # Get the area in m^2 return projected_poly.area / 1e6 @@ -225,16 +224,15 @@ def calc_length(line: LineString) -> float: if line.is_empty: return 0 - projected_line = shapely.ops.transform( - partial( - pyproj.transform, - pyproj.Proj(init=PROJ_WGS84), - pyproj.Proj(proj="aea", lat_1=line.bounds[1], lat_2=line.bounds[3]), - ), - line, + # Use modern pyproj API with Transformer instead of deprecated transform function + crs_wgs84 = pyproj.CRS(PROJ_WGS84) + crs_aea = pyproj.CRS.from_proj4( + f"+proj=aea +lat_1={line.bounds[1]} +lat_2={line.bounds[3]}" ) + transformer = pyproj.Transformer.from_crs(crs_wgs84, crs_aea, always_xy=True) + projected_line = shapely.ops.transform(transformer.transform, line) - # Get the area in m^2 + # Get the length in m return projected_line.length / 1e3