Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions src/pyfia/core/fia.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,9 @@ def _apply_spatial_filter(self, df: pl.LazyFrame, table_name: str) -> pl.LazyFra
return df
if table_name == "PLOT":
return df.filter(pl.col("CN").is_in(self._spatial_plot_cns))
elif table_name in ["TREE", "COND"]:
# Filter any table with a PLT_CN column by the spatial plot CNs
schema = self._reader.get_table_schema(table_name)
if "PLT_CN" in schema:
return df.filter(pl.col("PLT_CN").is_in(self._spatial_plot_cns))
return df

Expand Down Expand Up @@ -346,9 +348,15 @@ def load_table(
pl.LazyFrame
Polars LazyFrame of the requested table.
"""
# Build base WHERE clause for state filter
# Inspect table schema to determine which filters apply
table_schema = self._reader.get_table_schema(table_name)
table_columns = set(table_schema.keys())
has_plt_cn = "PLT_CN" in table_columns
has_statecd = "STATECD" in table_columns

# Build base WHERE clause for state filter (any table with STATECD)
base_where_clause = None
if self.state_filter and table_name in ["PLOT", "COND", "TREE"]:
if self.state_filter and has_statecd:
state_list = ", ".join(str(s) for s in self.state_filter)
base_where_clause = f"STATECD IN ({state_list})"

Expand All @@ -359,9 +367,9 @@ def load_table(
else:
base_where_clause = where

# EVALID filter via PLT_CN for TREE, COND tables
# This is a critical optimization - it reduces data load by 90%+ for GRM estimates
if self.evalid and table_name in ["TREE", "COND"]:
# EVALID filter via PLT_CN for any table that has a PLT_CN column
# This is a critical optimization - it reduces data load by 90%+
if self.evalid and has_plt_cn:
valid_plot_cns = self._get_valid_plot_cns()
if valid_plot_cns:
from .utils import batch_query_by_values
Expand Down Expand Up @@ -404,7 +412,7 @@ def query_batch(batch: list) -> pl.LazyFrame:
self.tables[table_name] = result
return self.tables[table_name]

# Default path - no EVALID filtering or not a filterable table
# Default path - no EVALID filtering or table has no PLT_CN column
df = self._reader.read_table(
table_name,
columns=columns,
Expand Down
160 changes: 160 additions & 0 deletions tests/unit/test_load_table_filtering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
"""Unit tests for load_table() EVALID and state filtering.

Verifies that load_table() applies PLT_CN-based EVALID filtering and
STATECD-based state filtering to any table that has those columns,
not just a hardcoded allowlist of table names.
"""

from unittest.mock import MagicMock, patch

import polars as pl
import pytest


@pytest.fixture
def mock_fia():
"""Create a mock FIA instance with the real load_table method."""
from pyfia.core.fia import FIA

with patch.object(FIA, "__init__", lambda self: None):
db = FIA()
db.tables = {}
db.evalid = None
db.state_filter = None
db._polygon_attributes = None
db._spatial_plot_cns = None
db._valid_plot_cns = None
db._reader = MagicMock()
return db


class TestEVALIDFilteringByColumn:
"""EVALID filtering should apply to any table with PLT_CN, not just TREE/COND."""

def test_tree_grm_component_gets_evalid_filtered(self, mock_fia):
"""TREE_GRM_COMPONENT has PLT_CN and should be EVALID-filtered."""
mock_fia.evalid = [132303]
mock_fia._valid_plot_cns = ["100", "200", "300"]
mock_fia._reader.get_table_schema.return_value = {
"CN": "VARCHAR",
"TRE_CN": "VARCHAR",
"PLT_CN": "VARCHAR",
"COMPONENT": "VARCHAR",
"TPA_UNADJ": "DOUBLE",
}
mock_fia._reader.read_table.return_value = pl.DataFrame(
{"TRE_CN": ["1"], "PLT_CN": ["100"], "TPA_UNADJ": [1.0]}
).lazy()

mock_fia.load_table("TREE_GRM_COMPONENT")

# Verify read_table was called with a PLT_CN IN (...) WHERE clause
call_args = mock_fia._reader.read_table.call_args
where_clause = call_args.kwargs.get("where", "") or ""
assert "PLT_CN IN" in where_clause

def test_table_without_plt_cn_skips_evalid_filter(self, mock_fia):
"""Tables without PLT_CN (e.g. POP_EVAL) should not get EVALID filtering."""
mock_fia.evalid = [132303]
mock_fia._valid_plot_cns = ["100", "200"]
mock_fia._reader.get_table_schema.return_value = {
"CN": "VARCHAR",
"EVALID": "INTEGER",
"EVAL_DESCR": "VARCHAR",
}
mock_fia._reader.read_table.return_value = pl.DataFrame(
{"CN": ["1"], "EVALID": [132303], "EVAL_DESCR": ["test"]}
).lazy()

mock_fia.load_table("POP_EVAL")

# Should use default path without PLT_CN filtering
call_args = mock_fia._reader.read_table.call_args
where_clause = call_args.kwargs.get("where", "") or ""
assert "PLT_CN IN" not in where_clause

def test_no_evalid_set_skips_filter(self, mock_fia):
"""When no EVALID is set, PLT_CN filtering should be skipped."""
mock_fia.evalid = None
mock_fia._reader.get_table_schema.return_value = {
"CN": "VARCHAR",
"PLT_CN": "VARCHAR",
"TPA_UNADJ": "DOUBLE",
}
mock_fia._reader.read_table.return_value = pl.DataFrame(
{"CN": ["1"], "PLT_CN": ["100"], "TPA_UNADJ": [1.0]}
).lazy()

mock_fia.load_table("TREE_GRM_COMPONENT")

# Should use default path
call_args = mock_fia._reader.read_table.call_args
where_clause = call_args.kwargs.get("where", "") or ""
assert "PLT_CN IN" not in where_clause


class TestStateFilteringByColumn:
"""State filtering should apply to any table with STATECD."""

def test_table_with_statecd_gets_filtered(self, mock_fia):
"""Any table with STATECD should get state filtering."""
mock_fia.state_filter = [13] # Georgia
mock_fia._reader.get_table_schema.return_value = {
"CN": "VARCHAR",
"PLT_CN": "VARCHAR",
"STATECD": "INTEGER",
}
mock_fia._reader.read_table.return_value = pl.DataFrame(
{"CN": ["1"], "PLT_CN": ["100"], "STATECD": [13]}
).lazy()

mock_fia.load_table("SEEDLING")

call_args = mock_fia._reader.read_table.call_args
where_clause = call_args.kwargs.get("where", "") or ""
assert "STATECD IN (13)" in where_clause

def test_table_without_statecd_skips_filter(self, mock_fia):
"""Tables without STATECD should not get state filtering."""
mock_fia.state_filter = [13]
mock_fia._reader.get_table_schema.return_value = {
"CN": "VARCHAR",
"EVALID": "INTEGER",
}
mock_fia._reader.read_table.return_value = pl.DataFrame(
{"CN": ["1"], "EVALID": [132303]}
).lazy()

mock_fia.load_table("POP_EVAL")

call_args = mock_fia._reader.read_table.call_args
where_clause = call_args.kwargs.get("where", "") or ""
assert "STATECD" not in where_clause


class TestSpatialFilteringByColumn:
"""Spatial filtering should apply to any table with PLT_CN."""

def test_spatial_filter_applies_to_grm_table(self, mock_fia):
"""Tables with PLT_CN should get spatial filtering when active."""
mock_fia._spatial_plot_cns = ["100", "200"]
mock_fia._reader.get_table_schema.return_value = {
"CN": "VARCHAR",
"PLT_CN": "VARCHAR",
"TPA_UNADJ": "DOUBLE",
}
data = pl.DataFrame(
{
"CN": ["1", "2", "3"],
"PLT_CN": ["100", "200", "999"],
"TPA_UNADJ": [1.0, 2.0, 3.0],
}
).lazy()
mock_fia._reader.read_table.return_value = data

result = mock_fia.load_table("TREE_GRM_COMPONENT")

# Should filter to only spatial plot CNs
collected = result.collect()
assert collected.shape[0] == 2
assert set(collected["PLT_CN"].to_list()) == {"100", "200"}
Loading