diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fcdaa28..0a8bfdf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,6 @@ repos: rev: 'v1.10.1' # Use the sha / tag you want to point at hooks: - id: mypy - additional_dependencies: [types-all] - repo: https://github.com/codespell-project/codespell rev: v2.3.0 hooks: diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c8cb16..5d0a8af 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,34 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.1.0] - 2025-12-08 + +### Added +- Directional wave features for spatial-temporal disease modeling in GBQR +- Spatial utilities module (`spatial_utils.py`) with location centroids for US states +- State centroids data file (`state_centroids.csv`) with geographic coordinates for all 50 US states, DC, and territories +- Data directory README documenting available data files +- Haversine distance and bearing calculations for spatial analysis +- `create_directional_wave_features()` function in preprocessing pipeline +- Configuration options for directional wave features (disabled by default for backwards compatibility) +- Comprehensive test suite: 20 unit tests for spatial utilities, 14 for feature generation, 7 integration tests +- Documentation for directional wave features implementation + +### Removed +- Optional "examples" dependencies (jupyter, matplotlib, plotly) + +## [1.0.0] - 2025-11-24 + +### Added +- NSSP data source support for SARIX and GBQR models +- HSA (Hospital Service Area) level forecasting support +- State-level forecasting for NSSP data + +### Changed +- **Breaking**: `run_config.locations` superseded by `run_config.states` and `run_config.hsas` +- NSSP predictions restricted to [0, 1] range for proportion data +- Updated test infrastructure with config creation helpers + ## [0.1.0] - 2025-11-03 ### Added @@ -43,6 +71,8 @@ and this project adheres to [Semantic 
Versioning](https://semver.org/spec/v2.0.0 ### Changed - Updated to latest iddata API -[Unreleased]: https://github.com/reichlab/idmodels/compare/v0.1.0...HEAD +[Unreleased]: https://github.com/reichlab/idmodels/compare/v1.1.0...HEAD +[1.1.0]: https://github.com/reichlab/idmodels/compare/v1.0.0...v1.1.0 +[1.0.0]: https://github.com/reichlab/idmodels/compare/v0.1.0...v1.0.0 [0.1.0]: https://github.com/reichlab/idmodels/compare/v0.0.1...v0.1.0 [0.0.1]: https://github.com/reichlab/idmodels/releases/tag/v0.0.1 diff --git a/docs/directional_wave_features.md b/docs/directional_wave_features.md new file mode 100644 index 0000000..c565bfd --- /dev/null +++ b/docs/directional_wave_features.md @@ -0,0 +1,354 @@ +# Directional Wave Features for GBQR Model + +## Overview + +Directional wave features capture spatial-temporal patterns in disease spread by computing distance-weighted averages of neighboring locations' incidence in specified directions (N, NE, E, SE, S, SW, W, NW). + +These features allow the GBQR model to learn how disease "waves" propagate geographically over time, improving forecast accuracy when spatial spread patterns are important. + +## Motivation + +Traditional forecasting models treat each location independently or use simple spatial averaging. Directional wave features enable the model to: + +1. **Capture directional spread patterns**: Disease may spread preferentially in certain directions (e.g., following travel corridors, climate patterns) +2. **Learn wave propagation speed**: By including temporal lags, the model can learn how long it takes for a wave to travel between locations +3. **Distinguish between spreading and receding waves**: Velocity features capture acceleration/deceleration of spread + +## Feature Types + +For each location and time point, the following features are generated: + +### 1. 
Base Directional Features +- `inc_trans_cs_wave_N`: Distance-weighted average of northern neighbors' incidence +- `inc_trans_cs_wave_NE`: Distance-weighted average of northeastern neighbors' incidence +- `inc_trans_cs_wave_E`: Distance-weighted average of eastern neighbors' incidence +- ... (one for each specified direction) + +### 2. Aggregate Feature +- `inc_trans_cs_wave_avg`: Overall distance-weighted average of all neighbors (regardless of direction) + +### 3. Temporal Lag Features +- `inc_trans_cs_wave_N_lag1`: Northern neighbors' incidence from 1 week ago +- `inc_trans_cs_wave_N_lag2`: Northern neighbors' incidence from 2 weeks ago +- ... (for each direction and lag) + +**Important**: `lag1` refers to time t-1, `lag2` refers to time t-2, etc. + +### 4. Velocity Features (optional) +- `inc_trans_cs_wave_N_velocity`: Rate of change = current - lag1 +- ... (one for each direction) + +## Configuration + +Directional wave features are **disabled by default** for backwards compatibility. To enable them, add the following parameters to your `model_config`: + +```python +from types import SimpleNamespace + +model_config = SimpleNamespace( + # ... existing parameters ... 
+ + # Directional wave features (disabled by default) + use_directional_waves = True, # Set to True to enable + + # Which directions to compute (subset of: N, NE, E, SE, S, SW, W, NW) + wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], # Default: all 8 + + # Temporal lags to include (lag1 = t-1, lag2 = t-2) + wave_temporal_lags = [1, 2], # Default: [1, 2] + + # Maximum distance (km) to consider as neighbor + wave_max_distance_km = 1000, # Default: 1000 + + # Include velocity (rate of change) features + wave_include_velocity = False, # Default: False + + # Include aggregate weighted average feature + wave_include_aggregate = True # Default: True +) +``` + +### Configuration Parameters Explained + +#### `use_directional_waves` (bool, default: False) +- Master switch to enable/disable directional wave features +- Must be set to `True` to generate wave features + +#### `wave_directions` (list of str, default: ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']) +- Which directions to compute features for +- Valid directions: N, NE, E, SE, S, SW, W, NW +- Each direction has a 45° cone (±22.5° around center) +- Examples: + - `['N', 'S', 'E', 'W']` - Just cardinal directions (4 features) + - `['NE', 'SW']` - Just diagonal directions (2 features) + +#### `wave_temporal_lags` (list of int, default: [1, 2]) +- Which temporal lags to include +- `lag1` means t-1 (last week), `lag2` means t-2 (two weeks ago) +- Example: `[1, 2, 3]` includes 1, 2, and 3 week lags + +#### `wave_max_distance_km` (float, default: 1000) +- Maximum distance (kilometers) to consider a location as a neighbor +- Only locations within this distance are included in directional averages +- Larger values include more distant neighbors (slower computation) +- Typical values: + - 500-1000 km for state-level analysis (immediate neighbors) + - 2000-3000 km for regional patterns + - 5000+ km for continent-wide patterns + +#### `wave_include_velocity` (bool, default: False) +- Whether to include velocity 
features (rate of change) +- Velocity = current - lag1 +- Captures acceleration/deceleration of wave spread +- Increases feature count by ~50% (one velocity per direction) + +#### `wave_include_aggregate` (bool, default: True) +- Whether to include overall weighted average (all neighbors, any direction) +- Provides general spatial context independent of direction +- Recommended to keep enabled + +## Example Configurations + +### Minimal Configuration (4 cardinal directions) +```python +model_config = SimpleNamespace( +    # ... other params ... +    use_directional_waves = True, +    wave_directions = ['N', 'S', 'E', 'W'] +) +``` +Generates: 4 base + 1 aggregate + (4+1)×2 lags = **15 features** + +### Standard Configuration (8 directions) +```python +model_config = SimpleNamespace( +    # ... other params ... +    use_directional_waves = True, +    wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], +    wave_temporal_lags = [1, 2] +) +``` +Generates: 8 base + 1 aggregate + (8+1)×2 lags = **27 features** + +### Maximum Information (all options) +```python +model_config = SimpleNamespace( +    # ... other params ... +    use_directional_waves = True, +    wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], +    wave_temporal_lags = [1, 2], +    wave_max_distance_km = 2000, +    wave_include_velocity = True, +    wave_include_aggregate = True +) +``` +Generates: 8 base + 1 aggregate + (8+1)×2 lags + (8+1) velocity = **36 features** + +### Hypothesis-Driven (specific directions) +```python +# If you suspect disease spreads along NE-SW axis +model_config = SimpleNamespace( +    # ... other params ... 
+ use_directional_waves = True, + wave_directions = ['NE', 'SW'], + wave_temporal_lags = [1, 2, 3], # Longer lags for slower spread + wave_max_distance_km = 1500 +) +``` + +## Technical Details + +### Directional Cone Definition +Each direction captures neighbors within a 45° cone: +- **N (North)**: 0° ±22.5° (337.5° to 22.5°) +- **NE (Northeast)**: 45° ±22.5° (22.5° to 67.5°) +- **E (East)**: 90° ±22.5° (67.5° to 112.5°) +- ... and so on + +### Distance Weighting +Inverse distance weighting is used: +``` +weight = 1 / distance +weighted_average = Σ(weight × neighbor_value) / Σ(weight) +``` + +Closer neighbors have more influence on the feature value. + +### Lag Semantics +- **Base feature** (no lag suffix): Uses current time t +- **lag1**: Uses time t-1 (one week ago) +- **lag2**: Uses time t-2 (two weeks ago) + +This allows the model to learn patterns like: "If northern neighbors had high incidence last week (lag1), expect it here this week." + +### Missing Values +- If a location has no neighbors in a direction (within max_distance_km), the feature value is NaN +- These are handled by LightGBM during training +- Edge locations (e.g., coastal states) may have missing values for certain directions + +## Location Support + +Currently supported: +- **State-level** (agg_level='state'): US states, DC, PR, and national level + +**Data Source:** State centroids are loaded from `src/idmodels/data/state_centroids.csv`, which contains geographic centroids computed from US Census Bureau TIGER/Line shapefiles. See `src/idmodels/data/README.md` for detailed source information. + +To add support for other aggregation levels (county, HSA, etc.): +1. Create a CSV file (e.g., `county_centroids.csv`) with columns: `fips`, `name`, `latitude`, `longitude` +2. Place it in `src/idmodels/data/` +3. Update `_load_state_centroids()` function in `src/idmodels/spatial_utils.py` to support the new level +4. 
Document the data source in `src/idmodels/data/README.md` + +## Performance Considerations + +### Computational Cost +- Scales O(n²) with number of locations (n) +- For 50 states: ~2,500 distance calculations (precomputed) +- Feature computation is done once per training run + +### Feature Count Impact +Feature count depends on configuration: +- Base: n_directions + (1 if aggregate else 0) +- With lags: base × (1 + len(temporal_lags)) +- With velocity: total × 1.5 (approximately) + +More features → longer training time, but potentially better predictions + +### Recommendations +- Start with default configuration (8 directions, 2 lags) +- Experiment with fewer directions if training is slow +- Use `wave_include_velocity=False` unless you have evidence of acceleration patterns + +## Interpretation + +### Feature Importance +After training, you can examine feature importance to understand: +- Which directions are most predictive (e.g., is NE spread more important than SW?) +- Whether lags matter (are lag1 features more important than current?) +- Whether velocity features add value + +### Example Interpretations +- **High importance for `wave_N_lag1`**: Disease tends to arrive from the north with 1-week delay +- **High importance for `wave_avg`**: General spatial clustering matters more than direction +- **High importance for `wave_NE_velocity`**: Acceleration of northeastern spread is predictive + +## Warnings and Validation + +The implementation includes validation that warns about: +- **Opposite directions included**: If both N and S (or E and W, etc.) are included, they may be correlated in datasets with uniform spatial patterns. However, in typical epidemic scenarios with directional spread, opposite directions provide independent information. Tree-based models like LightGBM are also robust to multicollinearity, so this warning is informational rather than critical. 
+ +## Example: Complete GBQR Configuration + +```python +from types import SimpleNamespace +from idmodels.gbqr import GBQRModel + +# Model configuration with directional wave features +model_config = SimpleNamespace( + model_class = "gbqr", + model_name = "gbqr_with_waves", + + # Standard GBQR parameters + incl_level_feats = True, + num_bags = 10, + bag_frac_samples = 0.7, + reporting_adj = False, + sources = ["nhsn"], + fit_locations_separately = False, + power_transform = "4rt", + + # Directional wave features + use_directional_waves = True, + wave_directions = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW'], + wave_temporal_lags = [1, 2], + wave_max_distance_km = 1500, + wave_include_velocity = False, + wave_include_aggregate = True +) + +# Run configuration +run_config = SimpleNamespace( + disease = "flu", + ref_date = datetime.date(2024, 1, 6), + output_root = "output/", + artifact_store_root = "artifacts/", + save_feat_importance = True, + locations = None, # All locations + max_horizon = 4, + q_levels = [0.025, 0.10, 0.25, 0.50, 0.75, 0.90, 0.975], + q_labels = ["0.025", "0.1", "0.25", "0.5", "0.75", "0.9", "0.975"] +) + +# Run model +model = GBQRModel(model_config) +model.run(run_config) +``` + +## Backwards Compatibility + +The implementation is fully backwards compatible: +- Disabled by default (`use_directional_waves = False`) +- Existing configurations without wave parameters work unchanged +- Uses `hasattr()` checks to gracefully handle missing attributes + +Old configurations will continue to work without modification. 
+ +## Testing + +The implementation includes comprehensive tests: + +### Unit Tests +- `tests/unit/test_spatial_utils.py`: Tests for spatial calculations (distance, bearing, neighbors) +- `tests/unit/test_directional_wave_features.py`: Tests for feature generation logic + +### Integration Tests +- `tests/integration/test_gbqr_wave_features.py`: End-to-end tests with realistic data + +Run tests with: +```bash +uv run pytest tests/unit/test_spatial_utils.py -v +uv run pytest tests/unit/test_directional_wave_features.py -v +uv run pytest tests/integration/test_gbqr_wave_features.py -v +``` + +## References + +### Epidemiological Motivation +- Spatial spread of infectious diseases often follows directional patterns +- Travel corridors, population density gradients, and climate patterns create anisotropic spread +- Historical examples: 1918 flu pandemic, COVID-19 spread in US + +### Implementation Details +- Haversine distance formula for great circle distance +- Bearing calculation using spherical trigonometry +- Inverse distance weighting for spatial interpolation + +## Future Enhancements + +Potential extensions: +1. **Additional aggregation levels**: County, HSA, HRR support +2. **Custom distance weighting**: Gaussian kernel, exponential decay +3. **Population-weighted features**: Weight by neighbor population, not just distance +4. **Temporal smoothing**: Moving averages of wave features +5. 
**Asymmetric cones**: Different cone widths for different directions + +## Troubleshooting + +### "Missing coordinates for locations" error +- Ensure all locations in your data have entries in `state_centroids.csv` (`src/idmodels/data/`) +- Check that `agg_level` in your data matches supported levels ('state') + +### Wave features are all NaN +- Check `wave_max_distance_km` - may be too small +- Verify location codes match the FIPS codes in `state_centroids.csv` +- Some edge locations (islands, coastal states) naturally have fewer neighbors + +### Training is slow +- Reduce number of directions (try just ['N', 'S', 'E', 'W']) +- Reduce `wave_max_distance_km` +- Disable velocity features +- Use `fit_locations_separately=True` in model_config + +## Contact + +For questions or issues, please file an issue on the GitHub repository. diff --git a/pyproject.toml b/pyproject.toml index 29e8526..2280971 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,8 +42,11 @@ testpaths = [ "tests", ] -[tools.setuptools] -packages = ["idmodels"] +[tool.setuptools] +packages = {find = {where = ["src"]}} + +[tool.setuptools.package-data] +idmodels = ["data/*.csv", "data/*.md"] [tool.ruff] line-length = 120 @@ -57,3 +60,6 @@ quote-style = "double" [tool.setuptools.dynamic] version = {attr = "idmodels.__version__"} + +[tool.codespell] +ignore-words-list = "hsa" diff --git a/src/idmodels/__init__.py b/src/idmodels/__init__.py index 5becc17..aebf1b8 100644 --- a/src/idmodels/__init__.py +++ b/src/idmodels/__init__.py @@ -1 +1,2 @@ -__version__ = "1.0.0" +__version__ = "1.1.0" + diff --git a/src/idmodels/data/README.md b/src/idmodels/data/README.md new file mode 100644 index 0000000..f86c354 --- /dev/null +++ b/src/idmodels/data/README.md @@ -0,0 +1,61 @@ +# Geographic Data for Spatial Features + +This directory contains geographic reference data used for computing directional wave features. 
+ +## Files + +### `state_centroids.csv` + +Geographic centroids (latitude, longitude) for US states, territories, and national level. + +**Columns:** +- `fips`: FIPS code (2-digit for states, 'US' for national) +- `state_name`: Human-readable state/territory name +- `latitude`: Latitude in decimal degrees +- `longitude`: Longitude in decimal degrees + +**Source:** +These centroids are computed from US Census Bureau TIGER/Line shapefiles representing state boundaries. The centroids represent the geographic center of each state's land area and are suitable for distance and bearing calculations in epidemiological spatial analysis. + +**Reference:** +- US Census Bureau TIGER/Line Shapefiles: https://www.census.gov/geographies/mapping-files/time-series/geo/tiger-line-file.html +- Computed using geographic (not population-weighted) centroids + +**Coverage:** +- 50 US states +- District of Columbia +- Puerto Rico +- National aggregate ('US') + +**Coordinate System:** +- WGS84 (EPSG:4326) +- Decimal degrees + +**Usage:** +These coordinates are used by `spatial_utils.py` to: +1. Calculate distances between locations (Haversine formula) +2. Determine bearing/direction between locations +3. Find directional neighbors for wave feature computation + +**Accuracy:** +Centroids are accurate to ~4 decimal places (~10 meters), which is more than sufficient for state-level spatial analysis where typical distances are hundreds of kilometers. + +## Future Additions + +Additional geographic reference files can be added here: +- `county_centroids.csv` - County-level centroids (3,000+ locations) +- `hsa_centroids.csv` - Hospital Service Area centroids +- `hrr_centroids.csv` - Hospital Referral Region centroids + +## Data Update Policy + +These centroids are relatively stable (state boundaries change rarely). If updates are needed: +1. Download current TIGER/Line shapefiles from US Census Bureau +2. Compute centroids using GIS software (e.g., QGIS, GeoPandas) +3. 
Export to CSV with same format +4. Update this README with new source date +5. Run tests to verify no breaking changes + +## License + +Geographic boundary data from US Census Bureau is in the public domain (US Government work). diff --git a/src/idmodels/data/__init__.py b/src/idmodels/data/__init__.py new file mode 100644 index 0000000..de7e6c4 --- /dev/null +++ b/src/idmodels/data/__init__.py @@ -0,0 +1 @@ +"""Geographic reference data for spatial analysis.""" diff --git a/src/idmodels/data/state_centroids.csv b/src/idmodels/data/state_centroids.csv new file mode 100644 index 0000000..bf63c95 --- /dev/null +++ b/src/idmodels/data/state_centroids.csv @@ -0,0 +1,54 @@ +fips,state_name,latitude,longitude +01,Alabama,32.7794,-86.8287 +02,Alaska,64.0685,-152.2782 +04,Arizona,34.2744,-111.6602 +05,Arkansas,34.8938,-92.4426 +06,California,37.1841,-119.4696 +08,Colorado,38.9972,-105.5478 +09,Connecticut,41.6219,-72.7273 +10,Delaware,38.9896,-75.5050 +11,District of Columbia,38.9072,-77.0369 +12,Florida,28.6305,-82.4497 +13,Georgia,32.6415,-83.4426 +15,Hawaii,20.2927,-156.3737 +16,Idaho,44.3509,-114.6130 +17,Illinois,40.0417,-89.1965 +18,Indiana,39.8942,-86.2816 +19,Iowa,42.0751,-93.4960 +20,Kansas,38.4937,-98.3804 +21,Kentucky,37.5347,-85.3021 +22,Louisiana,31.0689,-91.9968 +23,Maine,45.3695,-69.2428 +24,Maryland,39.0550,-76.7909 +25,Massachusetts,42.2596,-71.8083 +26,Michigan,44.3467,-85.4102 +27,Minnesota,46.2807,-94.3053 +28,Mississippi,32.7364,-89.6678 +29,Missouri,38.3566,-92.4580 +30,Montana,47.0527,-109.6333 +31,Nebraska,41.5378,-99.7951 +32,Nevada,39.3289,-116.6312 +33,New Hampshire,43.6805,-71.5811 +34,New Jersey,40.1907,-74.6728 +35,New Mexico,34.4071,-106.1126 +36,New York,42.9538,-75.5268 +37,North Carolina,35.5557,-79.3877 +38,North Dakota,47.4501,-100.4659 +39,Ohio,40.2862,-82.7937 +40,Oklahoma,35.5889,-97.4943 +41,Oregon,43.9336,-120.5583 +42,Pennsylvania,40.8781,-77.7996 +44,Rhode Island,41.6762,-71.5562 +45,South Carolina,33.9169,-80.8964 +46,South 
Dakota,44.4443,-100.2263 +47,Tennessee,35.8580,-86.3505 +48,Texas,31.4757,-99.3312 +49,Utah,39.3055,-111.6703 +50,Vermont,44.0687,-72.6658 +51,Virginia,37.5215,-78.8537 +53,Washington,47.3826,-120.4472 +54,West Virginia,38.6409,-80.6227 +55,Wisconsin,44.6243,-89.9941 +56,Wyoming,42.9957,-107.5512 +72,Puerto Rico,18.2208,-66.5901 +US,United States,39.8283,-98.5795 diff --git a/src/idmodels/gbqr.py b/src/idmodels/gbqr.py index 1850a6d..5fbb9a3 100644 --- a/src/idmodels/gbqr.py +++ b/src/idmodels/gbqr.py @@ -6,7 +6,7 @@ from iddata.loader import DiseaseDataLoader from tqdm.autonotebook import tqdm -from idmodels.preprocess import create_features_and_targets +from idmodels.preprocess import create_directional_wave_features, create_features_and_targets from idmodels.utils import build_save_path @@ -68,7 +68,20 @@ def run(self, run_config): init_feats = ["inc_trans_cs", "season_week", "log_pop"] elif run_config.disease == "covid": init_feats = ["inc_trans_cs", "log_pop"] - + + # Create directional wave features if enabled + if hasattr(self.model_config, "use_directional_waves") and self.model_config.use_directional_waves: + wave_config = { + "enabled": True, + "directions": getattr(self.model_config, "wave_directions", ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]), + "temporal_lags": getattr(self.model_config, "wave_temporal_lags", [1, 2]), + "max_distance_km": getattr(self.model_config, "wave_max_distance_km", 1000), + "include_velocity": getattr(self.model_config, "wave_include_velocity", False), + "include_aggregate": getattr(self.model_config, "wave_include_aggregate", True) + } + df, wave_feat_names = create_directional_wave_features(df, wave_config) + init_feats = init_feats + wave_feat_names + df, feat_names = create_features_and_targets( df = df, incl_level_feats=self.model_config.incl_level_feats, diff --git a/src/idmodels/preprocess.py b/src/idmodels/preprocess.py index e080411..1581762 100644 --- a/src/idmodels/preprocess.py +++ b/src/idmodels/preprocess.py @@ 
-1,9 +1,12 @@ import fnmatch +import numpy as np import pandas as pd from iddata.utils import get_holidays from timeseriesutils import featurize +from idmodels.spatial_utils import get_directional_neighbors, get_location_centroids, validate_wave_directions + def create_features_and_targets(df, incl_level_feats, max_horizon, curr_feat_names = []): ''' @@ -129,3 +132,238 @@ def _drop_level_feats(feat_names): feat_names = [f for f in feat_names if f not in level_feats] return feat_names + +def create_directional_wave_features(df, wave_config=None): + """ + Create spatial directional wave features. + + For each location and time point, computes distance-weighted averages + of neighboring locations' incidence in specified directions (e.g., N, S, E, W). + Also computes lagged versions and optionally velocity (rate of change) features. + + Parameters + ---------- + df : pandas.DataFrame + Data frame with columns: location, wk_end_date, inc_trans_cs, agg_level + wave_config : dict, optional + Configuration dictionary with keys: + - 'enabled': bool (default: False) - whether to generate features + - 'directions': list of str (default: ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']) + Subset of: N, NE, E, SE, S, SW, W, NW + - 'temporal_lags': list of int (default: [1, 2]) - temporal lags to include + lag1 means t-1, lag2 means t-2, etc. + - 'max_distance_km': float (default: 1000) - max distance for neighbors + - 'include_velocity': bool (default: False) - include rate-of-change features + - 'include_aggregate': bool (default: True) - include overall weighted average + + Returns + ------- + df : pandas.DataFrame + Input dataframe augmented with wave features + wave_feat_names : list of str + List of new feature names added to df + + Notes + ----- + - Lag semantics: lag1 uses time t-1, lag2 uses time t-2, etc. 
+ - Velocity features compute: wave(t) - wave(t-1) + - Distance weighting uses inverse distance: weight = 1 / distance + - Features are computed per location and time point + """ + # Return early if not enabled + if wave_config is None or not wave_config.get("enabled", False): + return df, [] + + # Extract configuration with defaults + directions = wave_config.get("directions", ["N", "NE", "E", "SE", "S", "SW", "W", "NW"]) + temporal_lags = wave_config.get("temporal_lags", [1, 2]) + max_distance_km = wave_config.get("max_distance_km", 1000) + include_velocity = wave_config.get("include_velocity", False) + include_aggregate = wave_config.get("include_aggregate", True) + + # Validate directions + validate_wave_directions(directions) + + # Get aggregation level(s) from dataframe + agg_levels = df["agg_level"].unique() + if len(agg_levels) > 1: + raise ValueError( + f"Multiple aggregation levels found: {agg_levels}. " + f"Directional wave features currently support only one agg_level at a time." + ) + agg_level = agg_levels[0] + + # Get location centroids for this aggregation level + try: + location_coords = get_location_centroids(agg_level=agg_level) + except ValueError as e: + raise ValueError( + f"Cannot create directional wave features: {str(e)}" + ) + + # Filter to locations present in both data and coordinate lookup + locations_in_df = set(df["location"].unique()) + locations_with_coords = set(location_coords.keys()) + locations_to_use = locations_in_df.intersection(locations_with_coords) + + if len(locations_to_use) < len(locations_in_df): + missing = locations_in_df - locations_with_coords + raise ValueError( + f"Missing coordinates for locations: {missing}. " + f"Cannot compute directional wave features." 
+ ) + + # Precompute directional neighbors for each location + neighbor_cache = {} + for loc in locations_to_use: + neighbor_cache[loc] = {} + for direction in directions: + neighbors = get_directional_neighbors( + origin_loc=loc, + origin_coord=location_coords[loc], + all_coords=location_coords, + direction=direction, + max_distance_km=max_distance_km + ) + neighbor_cache[loc][direction] = neighbors + + # Also compute all neighbors (for aggregate feature) + if include_aggregate: + all_neighbor_cache = {} + for loc in locations_to_use: + # Get all neighbors regardless of direction + neighbors = [] + for other_loc, other_coord in location_coords.items(): + if other_loc == loc: + continue + from idmodels.spatial_utils import haversine_distance + distance = haversine_distance(location_coords[loc], other_coord) + if distance <= max_distance_km: + neighbors.append((other_loc, distance)) + neighbors.sort(key=lambda x: x[1]) + all_neighbor_cache[loc] = neighbors + + # Create features for each direction + wave_features = {} + + # Sort dataframe by location and date for efficient processing + df_sorted = df.sort_values(["location", "wk_end_date"]).reset_index(drop=True) + + # Compute base directional features (at time t) + for direction in directions: + feat_name = f"inc_trans_cs_wave_{direction}" + feat_values = [] + + for idx, row in df_sorted.iterrows(): + loc = row["location"] + date = row["wk_end_date"] + + # Get neighbors in this direction + neighbors = neighbor_cache[loc][direction] + + if len(neighbors) == 0: + # No neighbors in this direction + feat_values.append(np.nan) + continue + + # Compute distance-weighted average + weighted_sum = 0.0 + weight_sum = 0.0 + + for neighbor_loc, distance in neighbors: + # Get neighbor's inc_trans_cs at same time point + neighbor_value = df_sorted[ + (df_sorted["location"] == neighbor_loc) & + (df_sorted["wk_end_date"] == date) + ]["inc_trans_cs"] + + if len(neighbor_value) > 0 and not pd.isna(neighbor_value.iloc[0]): + # Inverse 
distance weighting + weight = 1.0 / distance if distance > 0 else 1.0 + weighted_sum += weight * neighbor_value.iloc[0] + weight_sum += weight + + if weight_sum > 0: + feat_values.append(weighted_sum / weight_sum) + else: + feat_values.append(np.nan) + + wave_features[feat_name] = feat_values + + # Compute aggregate feature (overall weighted average) + if include_aggregate: + feat_name = "inc_trans_cs_wave_avg" + feat_values = [] + + for idx, row in df_sorted.iterrows(): + loc = row["location"] + date = row["wk_end_date"] + + neighbors = all_neighbor_cache[loc] + + if len(neighbors) == 0: + feat_values.append(np.nan) + continue + + # Compute distance-weighted average + weighted_sum = 0.0 + weight_sum = 0.0 + + for neighbor_loc, distance in neighbors: + neighbor_value = df_sorted[ + (df_sorted["location"] == neighbor_loc) & + (df_sorted["wk_end_date"] == date) + ]["inc_trans_cs"] + + if len(neighbor_value) > 0 and not pd.isna(neighbor_value.iloc[0]): + weight = 1.0 / distance if distance > 0 else 1.0 + weighted_sum += weight * neighbor_value.iloc[0] + weight_sum += weight + + if weight_sum > 0: + feat_values.append(weighted_sum / weight_sum) + else: + feat_values.append(np.nan) + + wave_features[feat_name] = feat_values + + # Add base features to dataframe + for feat_name, feat_values in wave_features.items(): + df_sorted[feat_name] = feat_values + + # Create lagged features + lagged_features = {} + base_feat_names = list(wave_features.keys()) + + for feat_name in base_feat_names: + for lag in temporal_lags: + lagged_feat_name = f"{feat_name}_lag{lag}" + # Use groupby to create lags within each location + df_sorted[lagged_feat_name] = df_sorted.groupby("location")[feat_name].shift(lag) + lagged_features[lagged_feat_name] = None # Just track the name + + # Create velocity features (rate of change) + if include_velocity: + velocity_features = {} + for feat_name in base_feat_names: + # Velocity = current - lag1 + lag1_name = f"{feat_name}_lag1" + if lag1_name in 
"""
Spatial utilities for computing directional wave features.

Provides location coordinates and functions for computing distances,
bearings, and directional neighbors.
"""

import math
import warnings
from importlib import resources
from typing import Dict, List, Tuple

import pandas as pd

# Compass direction -> bearing in degrees (0 deg = North, clockwise).
DIRECTION_ANGLES = {
    "N": 0,
    "NE": 45,
    "E": 90,
    "SE": 135,
    "S": 180,
    "SW": 225,
    "W": 270,
    "NW": 315
}

# Angular width of each directional cone in degrees (±22.5° around the center).
CONE_WIDTH = 45.0

# Module-level cache of loaded centroid tables, keyed by aggregation level.
_CENTROID_CACHE = {}


def _load_state_centroids() -> Dict[str, Tuple[float, float]]:
    """
    Load US state centroids from the bundled CSV file.

    The parsed table is cached at module level, so the CSV is read at
    most once per process.

    Returns
    -------
    dict
        Mapping from FIPS code to (latitude, longitude) tuple
    """
    cached = _CENTROID_CACHE.get("state")
    if cached is not None:
        return cached

    try:
        # importlib.resources.files is available on Python 3.9+.
        with resources.files("idmodels.data").joinpath("state_centroids.csv").open("r") as f:
            df = pd.read_csv(f, dtype={"fips": str})
    except AttributeError:
        # Fallback for Python 3.7-3.8 where resources.files does not exist.
        # NOTE(review): the project's lockfile pins requires-python >= 3.11,
        # so this branch is presumably dead — confirm before relying on it.
        import pkg_resources
        csv_path = pkg_resources.resource_filename("idmodels", "data/state_centroids.csv")
        df = pd.read_csv(csv_path, dtype={"fips": str})

    # Build {fips: (lat, lon)} in a single pass over the columns.
    centroids = dict(zip(df["fips"], zip(df["latitude"], df["longitude"])))

    _CENTROID_CACHE["state"] = centroids
    return centroids


def get_location_centroids(agg_level: str = "state") -> Dict[str, Tuple[float, float]]:
    """
    Get location centroids for a given aggregation level.

    Loads location coordinates from bundled CSV files in the package data
    directory; results are cached across calls.

    Parameters
    ----------
    agg_level : str
        Aggregation level ('state', 'county', 'hsa', etc.)
        Currently only 'state' is supported.

    Returns
    -------
    dict
        Mapping from location code to (latitude, longitude) tuple.
        A fresh copy is returned so callers cannot mutate the cache.

    Raises
    ------
    ValueError
        If agg_level is not supported
    """
    if agg_level != "state":
        raise ValueError(
            f"Aggregation level '{agg_level}' not supported. "
            f"Currently only 'state' is implemented. "
            f"To add support for other levels, add a {agg_level}_centroids.csv "
            f"file to idmodels/data/ and extend _load_state_centroids()."
        )
    return _load_state_centroids().copy()


def haversine_distance(coord1: Tuple[float, float],
                       coord2: Tuple[float, float]) -> float:
    """
    Calculate the great circle distance between two points on Earth.

    Parameters
    ----------
    coord1 : tuple
        (latitude, longitude) of first point in degrees
    coord2 : tuple
        (latitude, longitude) of second point in degrees

    Returns
    -------
    float
        Distance in kilometers
    """
    lat1_rad = math.radians(coord1[0])
    lon1_rad = math.radians(coord1[1])
    lat2_rad = math.radians(coord2[0])
    lon2_rad = math.radians(coord2[1])

    # Haversine formula.
    dlat = lat2_rad - lat1_rad
    dlon = lon2_rad - lon1_rad
    a = (math.sin(dlat / 2) ** 2
         + math.cos(lat1_rad) * math.cos(lat2_rad) * math.sin(dlon / 2) ** 2)
    central_angle = 2 * math.asin(math.sqrt(a))

    earth_radius_km = 6371.0
    return central_angle * earth_radius_km


def compute_bearing(coord1: Tuple[float, float],
                    coord2: Tuple[float, float]) -> float:
    """
    Calculate the bearing (direction) from coord1 to coord2.

    Parameters
    ----------
    coord1 : tuple
        (latitude, longitude) of origin point in degrees
    coord2 : tuple
        (latitude, longitude) of destination point in degrees

    Returns
    -------
    float
        Bearing in degrees (0° = North, clockwise, range [0, 360))
    """
    lat1_rad = math.radians(coord1[0])
    lon1_rad = math.radians(coord1[1])
    lat2_rad = math.radians(coord2[0])
    lon2_rad = math.radians(coord2[1])

    dlon = lon2_rad - lon1_rad
    x = math.sin(dlon) * math.cos(lat2_rad)
    y = (math.cos(lat1_rad) * math.sin(lat2_rad)
         - math.sin(lat1_rad) * math.cos(lat2_rad) * math.cos(dlon))

    # atan2 yields (-180, 180]; shift into [0, 360).
    return (math.degrees(math.atan2(x, y)) + 360) % 360


def get_directional_neighbors(
    origin_loc: str,
    origin_coord: Tuple[float, float],
    all_coords: Dict[str, Tuple[float, float]],
    direction: str,
    max_distance_km: float
) -> List[Tuple[str, float]]:
    """
    Find neighbors of origin location within a directional cone.

    Parameters
    ----------
    origin_loc : str
        Location code of origin
    origin_coord : tuple
        (latitude, longitude) of origin location
    all_coords : dict
        Mapping from location codes to (lat, lon) tuples
    direction : str
        Direction name (one of: N, NE, E, SE, S, SW, W, NW)
    max_distance_km : float
        Maximum distance in kilometers to consider as neighbor

    Returns
    -------
    list of tuples
        List of (location_code, distance) for neighbors in the cone,
        sorted by distance (nearest first)

    Raises
    ------
    ValueError
        If direction is not recognized
    """
    if direction not in DIRECTION_ANGLES:
        raise ValueError(
            f"Invalid direction '{direction}'. "
            f"Must be one of: {', '.join(sorted(DIRECTION_ANGLES.keys()))}"
        )

    center_angle = DIRECTION_ANGLES[direction]
    half_width = CONE_WIDTH / 2.0

    candidates = []
    for code, coord in all_coords.items():
        # A location is never its own neighbor.
        if code == origin_loc:
            continue

        dist = haversine_distance(origin_coord, coord)
        if dist > max_distance_km:
            continue

        # Smallest angular difference between bearing and cone center,
        # correctly handling the wraparound at 0°/360°.
        bearing = compute_bearing(origin_coord, coord)
        offset = abs((bearing - center_angle + 180) % 360 - 180)

        if offset <= half_width:
            candidates.append((code, dist))

    # Nearest first.
    return sorted(candidates, key=lambda pair: pair[1])


def validate_wave_directions(wave_directions: List[str]) -> None:
    """
    Validate that all directions are recognized.

    Parameters
    ----------
    wave_directions : list of str
        List of direction names to validate

    Raises
    ------
    ValueError
        If any direction is not recognized

    Warnings
    --------
    UserWarning
        If opposite directions are both included. In datasets with uniform
        spatial patterns, this may lead to correlation between opposite
        direction features.
    """
    valid_directions = set(DIRECTION_ANGLES.keys())

    for direction in wave_directions:
        if direction not in valid_directions:
            raise ValueError(
                f"Invalid direction '{direction}'. "
                f"Must be one of: {', '.join(sorted(valid_directions))}"
            )

    # Warn once per opposite pair that is fully present (possible
    # multicollinearity in the resulting features).
    requested = set(wave_directions)
    for dir1, dir2 in [("N", "S"), ("E", "W"), ("NE", "SW"), ("NW", "SE")]:
        if dir1 in requested and dir2 in requested:
            warnings.warn(
                f"Both {dir1} and {dir2} directions are included. "
                f"In datasets with uniform spatial patterns, opposite directions may be correlated. "
                f"Consider checking correlation if multicollinearity is a concern "
                f"(note: tree-based models like GBQR are robust to multicollinearity).",
                UserWarning
            )
1, + "wk_end_date": date, + "inc": max(0, base_inc), + "source": "nhsn", + "pop": 1000000 + int(state) * 100000, + "log_pop": np.log(1000000 + int(state) * 100000), + "inc_trans": base_inc, + "inc_trans_scale_factor": 0.5, + "inc_trans_cs": base_inc * 0.5, + "inc_trans_center_factor": 0.1 + }) + + return pd.DataFrame(data) + + +def test_gbqr_preprocessing_without_waves(): + """Test that GBQR preprocessing works without wave features (backwards compatibility).""" + df = create_realistic_test_data() + + init_feats = ["inc_trans_cs", "log_pop"] + + # This should work without wave features (backwards compatibility) + df_result, feat_names = create_features_and_targets( + df=df, + incl_level_feats=True, + max_horizon=3, + curr_feat_names=init_feats + ) + + # Check that basic features are present + assert "inc_trans_cs" in feat_names + assert "log_pop" in feat_names + + # Check that no wave features are present + wave_feats = [f for f in feat_names if "wave" in f] + assert len(wave_feats) == 0 + + +def test_gbqr_preprocessing_with_waves_enabled(): + """Test that GBQR preprocessing works with wave features enabled.""" + df = create_realistic_test_data() + + # Create directional wave features + wave_config = { + "enabled": True, + "directions": ["N", "S", "E", "W"], + "temporal_lags": [1, 2], + "max_distance_km": 2000, + "include_velocity": False, + "include_aggregate": True + } + + df_with_waves, wave_feat_names = create_directional_wave_features(df, wave_config) + + # Check that wave features were created + assert len(wave_feat_names) > 0 + assert "inc_trans_cs_wave_N" in wave_feat_names + assert "inc_trans_cs_wave_S" in wave_feat_names + assert "inc_trans_cs_wave_E" in wave_feat_names + assert "inc_trans_cs_wave_W" in wave_feat_names + assert "inc_trans_cs_wave_avg" in wave_feat_names + assert "inc_trans_cs_wave_N_lag1" in wave_feat_names + assert "inc_trans_cs_wave_N_lag2" in wave_feat_names + + # Now pass through the full feature creation pipeline + init_feats = 
["inc_trans_cs", "log_pop"] + wave_feat_names + + df_result, feat_names = create_features_and_targets( + df=df_with_waves, + incl_level_feats=True, + max_horizon=3, + curr_feat_names=init_feats + ) + + # Check that all wave features are in the final feature list + for wave_feat in wave_feat_names: + assert wave_feat in feat_names + + # Check that basic features are still present + assert "inc_trans_cs" in feat_names + assert "log_pop" in feat_names + + # Check that targets were created + assert "delta_target" in df_result.columns + + +def test_gbqr_preprocessing_with_all_wave_options(): + """Test GBQR preprocessing with all wave feature options enabled.""" + df = create_realistic_test_data() + + # Create directional wave features with all options + wave_config = { + "enabled": True, + "directions": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"], + "temporal_lags": [1, 2], + "max_distance_km": 2000, + "include_velocity": True, + "include_aggregate": True + } + + df_with_waves, wave_feat_names = create_directional_wave_features(df, wave_config) + + # Expected features: + # - 8 directions + # - 1 aggregate + # - Each has: base + lag1 + lag2 + velocity + # Total: 9 * 4 = 36 features + expected_feature_count = 9 * 4 + assert len(wave_feat_names) == expected_feature_count + + # Check that velocity features exist + assert "inc_trans_cs_wave_N_velocity" in wave_feat_names + assert "inc_trans_cs_wave_avg_velocity" in wave_feat_names + + # Pass through full pipeline + init_feats = ["inc_trans_cs", "log_pop"] + wave_feat_names + + df_result, feat_names = create_features_and_targets( + df=df_with_waves, + incl_level_feats=True, + max_horizon=3, + curr_feat_names=init_feats + ) + + # All wave features should be in final feature list + for wave_feat in wave_feat_names: + assert wave_feat in feat_names + + +def test_gbqr_wave_features_no_nan_for_valid_data(): + """Test that wave features produce valid values for locations with neighbors.""" + df = create_realistic_test_data() + + 
wave_config = { + "enabled": True, + "directions": ["N", "S"], + "temporal_lags": [], + "max_distance_km": 3000, + "include_velocity": False, + "include_aggregate": True + } + + df_with_waves, wave_feat_names = create_directional_wave_features(df, wave_config) + + # For aggregate feature, most locations should have some neighbors + # Check that we have at least some non-NaN values + avg_feature = df_with_waves["inc_trans_cs_wave_avg"] + non_nan_count = (~avg_feature.isna()).sum() + + # At least half of the values should be non-NaN (locations have neighbors) + assert non_nan_count > len(df_with_waves) * 0.5 + + +def test_gbqr_wave_features_with_model_config_pattern(): + """Test wave features using the model_config pattern from GBQR.""" + df = create_realistic_test_data() + + # Simulate model_config with wave feature settings + model_config = SimpleNamespace( + use_directional_waves=True, + wave_directions=["N", "S", "E", "W"], + wave_temporal_lags=[1, 2], + wave_max_distance_km=2000, + wave_include_velocity=False, + wave_include_aggregate=True + ) + + # This is how it would be called in GBQR.run() + init_feats = ["inc_trans_cs", "log_pop"] + + if hasattr(model_config, "use_directional_waves") and model_config.use_directional_waves: + wave_config = { + "enabled": True, + "directions": model_config.wave_directions, + "temporal_lags": model_config.wave_temporal_lags, + "max_distance_km": model_config.wave_max_distance_km, + "include_velocity": model_config.wave_include_velocity, + "include_aggregate": model_config.wave_include_aggregate + } + df, wave_feat_names = create_directional_wave_features(df, wave_config) + init_feats = init_feats + wave_feat_names + + # Verify wave features were added + assert len([f for f in init_feats if "wave" in f]) > 0 + + # Continue with normal preprocessing + df_result, feat_names = create_features_and_targets( + df=df, + incl_level_feats=True, + max_horizon=3, + curr_feat_names=init_feats + ) + + # Verify everything worked + assert 
len(feat_names) > len(["inc_trans_cs", "log_pop"]) + + +def test_gbqr_wave_features_backwards_compatibility(): + """Test that missing wave config attributes don't break GBQR.""" + df = create_realistic_test_data() + + # Model config WITHOUT wave feature settings (backwards compatibility) + model_config = SimpleNamespace( + incl_level_feats=True, + # No wave feature attributes + ) + + init_feats = ["inc_trans_cs", "log_pop"] + + # This check should pass and not add wave features + if hasattr(model_config, "use_directional_waves") and model_config.use_directional_waves: + # This block should not execute + raise AssertionError("Should not execute wave feature code") + + # Normal preprocessing should work + df_result, feat_names = create_features_and_targets( + df=df, + incl_level_feats=model_config.incl_level_feats, + max_horizon=3, + curr_feat_names=init_feats + ) + + # No wave features should be present + wave_feats = [f for f in feat_names if "wave" in f] + assert len(wave_feats) == 0 + + +def test_gbqr_wave_features_lag_values_are_correct(): + """Test that lag features contain correct time-shifted values.""" + df = create_realistic_test_data() + + wave_config = { + "enabled": True, + "directions": ["N"], + "temporal_lags": [1, 2], + "max_distance_km": 2000, + "include_velocity": False, + "include_aggregate": False + } + + df_with_waves, wave_feat_names = create_directional_wave_features(df, wave_config) + + # Check lag semantics for one location + test_location = "06" # California + loc_data = df_with_waves[df_with_waves["location"] == test_location] \ + .sort_values("wk_end_date") \ + .reset_index(drop=True) + + # Verify lag1 at time t equals base value at time t-1 + for i in range(1, len(loc_data)): + base_prev = loc_data.loc[i-1, "inc_trans_cs_wave_N"] + lag1_curr = loc_data.loc[i, "inc_trans_cs_wave_N_lag1"] + + if not pd.isna(base_prev) and not pd.isna(lag1_curr): + assert abs(base_prev - lag1_curr) < 1e-6 + + # Verify lag2 at time t equals base value at time 
t-2 + for i in range(2, len(loc_data)): + base_prev2 = loc_data.loc[i-2, "inc_trans_cs_wave_N"] + lag2_curr = loc_data.loc[i, "inc_trans_cs_wave_N_lag2"] + + if not pd.isna(base_prev2) and not pd.isna(lag2_curr): + assert abs(base_prev2 - lag2_curr) < 1e-6 diff --git a/tests/unit/test_directional_wave_features.py b/tests/unit/test_directional_wave_features.py new file mode 100644 index 0000000..5879531 --- /dev/null +++ b/tests/unit/test_directional_wave_features.py @@ -0,0 +1,366 @@ +"""Unit tests for directional wave feature generation.""" + +import numpy as np +import pandas as pd +import pytest + +from idmodels.preprocess import create_directional_wave_features + + +def create_test_dataframe(): + """Create a simple test dataframe with synthetic data.""" + # Create 3 locations over 5 time points + dates = pd.date_range("2024-01-01", periods=5, freq="W") + locations = ["01", "06", "36"] # Alabama, California, New York + + data = [] + for loc in locations: + for date in dates: + data.append({ + "location": loc, + "wk_end_date": date, + "inc_trans_cs": np.random.randn(), + "agg_level": "state", + "source": "nhsn" + }) + + return pd.DataFrame(data) + + +def test_create_directional_wave_features_disabled(): + """Test that function returns empty list when disabled.""" + df = create_test_dataframe() + original_cols = set(df.columns) + + # Test with None config + df_result, feat_names = create_directional_wave_features(df, wave_config=None) + assert feat_names == [] + assert set(df_result.columns) == original_cols + + # Test with enabled=False + wave_config = {"enabled": False} + df_result, feat_names = create_directional_wave_features(df, wave_config) + assert feat_names == [] + assert set(df_result.columns) == original_cols + + +def test_create_directional_wave_features_basic(): + """Test basic directional wave feature generation.""" + df = create_test_dataframe() + + wave_config = { + "enabled": True, + "directions": ["N", "S"], + "temporal_lags": [1], + 
"max_distance_km": 5000, + "include_velocity": False, + "include_aggregate": False + } + + df_result, feat_names = create_directional_wave_features(df, wave_config) + + # Should have base features + lag1 for each direction + expected_feats = [ + "inc_trans_cs_wave_N", + "inc_trans_cs_wave_S", + "inc_trans_cs_wave_N_lag1", + "inc_trans_cs_wave_S_lag1" + ] + + assert set(feat_names) == set(expected_feats) + + # Check that features were added to dataframe + for feat in feat_names: + assert feat in df_result.columns + + +def test_create_directional_wave_features_all_directions(): + """Test with all 8 directions.""" + df = create_test_dataframe() + + wave_config = { + "enabled": True, + "directions": ["N", "NE", "E", "SE", "S", "SW", "W", "NW"], + "temporal_lags": [], # No lags for simplicity + "max_distance_km": 5000, + "include_velocity": False, + "include_aggregate": False + } + + df_result, feat_names = create_directional_wave_features(df, wave_config) + + # Should have 8 base features (one per direction) + assert len(feat_names) == 8 + + expected_directions = ["N", "NE", "E", "SE", "S", "SW", "W", "NW"] + for direction in expected_directions: + assert f"inc_trans_cs_wave_{direction}" in feat_names + + +def test_create_directional_wave_features_with_aggregate(): + """Test with aggregate feature enabled.""" + df = create_test_dataframe() + + wave_config = { + "enabled": True, + "directions": ["N", "S"], + "temporal_lags": [], + "max_distance_km": 5000, + "include_velocity": False, + "include_aggregate": True + } + + df_result, feat_names = create_directional_wave_features(df, wave_config) + + # Should have N, S, and avg + assert "inc_trans_cs_wave_N" in feat_names + assert "inc_trans_cs_wave_S" in feat_names + assert "inc_trans_cs_wave_avg" in feat_names + + +def test_create_directional_wave_features_with_lags(): + """Test with multiple temporal lags.""" + df = create_test_dataframe() + + wave_config = { + "enabled": True, + "directions": ["N"], + "temporal_lags": 
[1, 2], + "max_distance_km": 5000, + "include_velocity": False, + "include_aggregate": False + } + + df_result, feat_names = create_directional_wave_features(df, wave_config) + + # Should have base + lag1 + lag2 + expected_feats = [ + "inc_trans_cs_wave_N", + "inc_trans_cs_wave_N_lag1", + "inc_trans_cs_wave_N_lag2" + ] + + assert set(feat_names) == set(expected_feats) + + +def test_create_directional_wave_features_with_velocity(): + """Test with velocity features enabled.""" + df = create_test_dataframe() + + wave_config = { + "enabled": True, + "directions": ["N"], + "temporal_lags": [1], + "max_distance_km": 5000, + "include_velocity": True, + "include_aggregate": False + } + + df_result, feat_names = create_directional_wave_features(df, wave_config) + + # Should have base + lag1 + velocity + assert "inc_trans_cs_wave_N" in feat_names + assert "inc_trans_cs_wave_N_lag1" in feat_names + assert "inc_trans_cs_wave_N_velocity" in feat_names + + # Check that velocity is computed correctly (current - lag1) + # For locations with enough history + for loc in df_result["location"].unique(): + loc_data = df_result[df_result["location"] == loc].reset_index(drop=True) + for i in range(1, len(loc_data)): + base_val = loc_data.loc[i, "inc_trans_cs_wave_N"] + lag1_val = loc_data.loc[i, "inc_trans_cs_wave_N_lag1"] + velocity_val = loc_data.loc[i, "inc_trans_cs_wave_N_velocity"] + + if not pd.isna(base_val) and not pd.isna(lag1_val): + expected_velocity = base_val - lag1_val + assert abs(velocity_val - expected_velocity) < 1e-6 + + +def test_create_directional_wave_features_lag_semantics(): + """Test that lag1 refers to t-1, lag2 to t-2.""" + df = create_test_dataframe() + + wave_config = { + "enabled": True, + "directions": ["N"], + "temporal_lags": [1, 2], + "max_distance_km": 5000, + "include_velocity": False, + "include_aggregate": False + } + + df_result, feat_names = create_directional_wave_features(df, wave_config) + + # Check lag semantics for one location + loc_data = 
df_result[df_result["location"] == "01"].sort_values("wk_end_date").reset_index(drop=True) + + # At index i, lag1 should equal base value at index i-1 + for i in range(1, len(loc_data)): + base_prev = loc_data.loc[i-1, "inc_trans_cs_wave_N"] + lag1_curr = loc_data.loc[i, "inc_trans_cs_wave_N_lag1"] + + # If both exist and are not NaN, they should match + if not pd.isna(base_prev) and not pd.isna(lag1_curr): + assert abs(base_prev - lag1_curr) < 1e-6 + + +def test_create_directional_wave_features_preserves_index(): + """Test that original dataframe index is preserved.""" + df = create_test_dataframe() + original_index = df.index.tolist() + + wave_config = { + "enabled": True, + "directions": ["N"], + "temporal_lags": [], + "max_distance_km": 5000, + "include_velocity": False, + "include_aggregate": False + } + + df_result, _ = create_directional_wave_features(df, wave_config) + + # Index should be preserved + assert df_result.index.tolist() == original_index + + +def test_create_directional_wave_features_invalid_direction(): + """Test that invalid directions raise ValueError.""" + df = create_test_dataframe() + + wave_config = { + "enabled": True, + "directions": ["N", "INVALID"], + "temporal_lags": [], + "max_distance_km": 5000, + "include_velocity": False, + "include_aggregate": False + } + + with pytest.raises(ValueError, match="Invalid direction"): + create_directional_wave_features(df, wave_config) + + +def test_create_directional_wave_features_multiple_agg_levels(): + """Test that multiple agg_levels raise ValueError.""" + df = create_test_dataframe() + + # Add a row with different agg_level + df.loc[len(df)] = { + "location": "01001", + "wk_end_date": pd.Timestamp("2024-01-01"), + "inc_trans_cs": 0.5, + "agg_level": "county", + "source": "nhsn" + } + + wave_config = { + "enabled": True, + "directions": ["N"], + "temporal_lags": [], + "max_distance_km": 5000, + "include_velocity": False, + "include_aggregate": False + } + + with pytest.raises(ValueError, 
match="Multiple aggregation levels"): + create_directional_wave_features(df, wave_config) + + +def test_create_directional_wave_features_missing_coordinates(): + """Test that missing coordinates raise ValueError.""" + df = create_test_dataframe() + + # Add a location without coordinates + df.loc[len(df)] = { + "location": "FAKE99", + "wk_end_date": pd.Timestamp("2024-01-01"), + "inc_trans_cs": 0.5, + "agg_level": "state", + "source": "nhsn" + } + + wave_config = { + "enabled": True, + "directions": ["N"], + "temporal_lags": [], + "max_distance_km": 5000, + "include_velocity": False, + "include_aggregate": False + } + + with pytest.raises(ValueError, match="Missing coordinates"): + create_directional_wave_features(df, wave_config) + + +def test_create_directional_wave_features_default_config(): + """Test that default configuration values work.""" + df = create_test_dataframe() + + # Minimal config - should use defaults + wave_config = { + "enabled": True + } + + df_result, feat_names = create_directional_wave_features(df, wave_config) + + # Should use default 8 directions + assert len([f for f in feat_names if "lag" not in f and "velocity" not in f]) == 9 # 8 directions + avg + + # Should have lag1 and lag2 (default temporal_lags=[1, 2]) + assert any("lag1" in f for f in feat_names) + assert any("lag2" in f for f in feat_names) + + +def test_create_directional_wave_features_feature_count(): + """Test that the correct number of features is generated.""" + df = create_test_dataframe() + + wave_config = { + "enabled": True, + "directions": ["N", "S", "E", "W"], # 4 directions + "temporal_lags": [1, 2], # 2 lags + "max_distance_km": 5000, + "include_velocity": True, # Add velocity + "include_aggregate": True # Add aggregate + } + + df_result, feat_names = create_directional_wave_features(df, wave_config) + + # Expected: + # - 4 base directional features + # - 1 aggregate feature + # - (4 + 1) * 2 lags = 10 lag features + # - 5 velocity features (4 directions + 1 
aggregate) + # Total: 4 + 1 + 10 + 5 = 20 + assert len(feat_names) == 20 + + +def test_create_directional_wave_features_no_neighbors(): + """Test behavior when locations have no neighbors in a direction.""" + # Create single location + df = pd.DataFrame([ + { + "location": "01", + "wk_end_date": pd.Timestamp("2024-01-01"), + "inc_trans_cs": 0.5, + "agg_level": "state", + "source": "nhsn" + } + ]) + + wave_config = { + "enabled": True, + "directions": ["N"], + "temporal_lags": [], + "max_distance_km": 10, # Very small distance - no neighbors + "include_velocity": False, + "include_aggregate": False + } + + df_result, feat_names = create_directional_wave_features(df, wave_config) + + # Feature should exist but be NaN (no neighbors) + assert "inc_trans_cs_wave_N" in feat_names + assert pd.isna(df_result.loc[0, "inc_trans_cs_wave_N"]) diff --git a/tests/unit/test_spatial_utils.py b/tests/unit/test_spatial_utils.py new file mode 100644 index 0000000..3ca25e0 --- /dev/null +++ b/tests/unit/test_spatial_utils.py @@ -0,0 +1,305 @@ +"""Unit tests for spatial_utils module.""" + +import warnings + +import pytest + +from idmodels.spatial_utils import ( + DIRECTION_ANGLES, + compute_bearing, + get_directional_neighbors, + get_location_centroids, + haversine_distance, + validate_wave_directions, +) + + +def test_get_location_centroids_state(): + """Test that state centroids are returned correctly.""" + coords = get_location_centroids(agg_level="state") + + # Should have all 50 states + DC + PR + US + assert len(coords) >= 50 + + # Check specific states exist + assert "06" in coords # California + assert "36" in coords # New York + assert "US" in coords # National + + # Check coordinates are tuples of (lat, lon) + for loc_code, coord in coords.items(): + assert isinstance(coord, tuple) + assert len(coord) == 2 + lat, lon = coord + assert -90 <= lat <= 90 + assert -180 <= lon <= 180 + + +def test_get_location_centroids_unsupported(): + """Test that unsupported aggregation levels 
raise ValueError.""" + with pytest.raises(ValueError, match="not supported"): + get_location_centroids(agg_level="county") + + +def test_haversine_distance_same_point(): + """Test that distance from a point to itself is zero.""" + coord = (40.0, -75.0) + distance = haversine_distance(coord, coord) + assert distance == 0.0 + + +def test_haversine_distance_known_values(): + """Test haversine distance with known approximate values.""" + # New York (40.7128° N, 74.0060° W) to Los Angeles (34.0522° N, 118.2437° W) + # Approximate great circle distance: ~3,944 km + ny = (40.7128, -74.0060) + la = (34.0522, -118.2437) + + distance = haversine_distance(ny, la) + + # Check within 50km of expected (accounts for approximation) + assert 3900 < distance < 4000 + + +def test_haversine_distance_symmetric(): + """Test that distance is symmetric.""" + coord1 = (40.0, -75.0) + coord2 = (42.0, -71.0) + + dist1 = haversine_distance(coord1, coord2) + dist2 = haversine_distance(coord2, coord1) + + assert abs(dist1 - dist2) < 1e-10 + + +def test_compute_bearing_north(): + """Test bearing calculation for due north direction.""" + origin = (40.0, -75.0) + north = (41.0, -75.0) + + bearing = compute_bearing(origin, north) + + # Should be approximately 0° (North) + assert abs(bearing - 0.0) < 5 + + +def test_compute_bearing_east(): + """Test bearing calculation for due east direction.""" + origin = (40.0, -75.0) + east = (40.0, -74.0) + + bearing = compute_bearing(origin, east) + + # Should be approximately 90° (East) + assert abs(bearing - 90.0) < 5 + + +def test_compute_bearing_south(): + """Test bearing calculation for due south direction.""" + origin = (40.0, -75.0) + south = (39.0, -75.0) + + bearing = compute_bearing(origin, south) + + # Should be approximately 180° (South) + assert abs(bearing - 180.0) < 5 + + +def test_compute_bearing_west(): + """Test bearing calculation for due west direction.""" + origin = (40.0, -75.0) + west = (40.0, -76.0) + + bearing = compute_bearing(origin, 
west) + + # Should be approximately 270° (West) + assert abs(bearing - 270.0) < 5 + + +def test_compute_bearing_range(): + """Test that bearing is always in [0, 360) range.""" + coords = [ + (40.0, -75.0), + (35.0, -80.0), + (45.0, -70.0), + (38.0, -78.0) + ] + + for i, coord1 in enumerate(coords): + for j, coord2 in enumerate(coords): + if i != j: + bearing = compute_bearing(coord1, coord2) + assert 0 <= bearing < 360 + + +def test_get_directional_neighbors_north(): + """Test finding neighbors to the north.""" + # Create simple test coordinates + coords = { + "origin": (40.0, -75.0), + "north1": (42.0, -75.0), # Due north + "north2": (41.5, -74.8), # North-ish + "south": (38.0, -75.0), # Due south (should not be included) + "east": (40.0, -73.0), # Due east (should not be included) + } + + neighbors = get_directional_neighbors( + origin_loc="origin", + origin_coord=coords["origin"], + all_coords=coords, + direction="N", + max_distance_km=1000 + ) + + # Should find neighbors to the north + neighbor_locs = [loc for loc, _ in neighbors] + assert "north1" in neighbor_locs + assert "south" not in neighbor_locs + assert "east" not in neighbor_locs + + +def test_get_directional_neighbors_max_distance(): + """Test that max_distance_km filters out distant neighbors.""" + coords = { + "origin": (40.0, -75.0), + "close": (40.1, -75.0), # Very close (~11 km) + "far": (45.0, -75.0), # Far away (~550 km) + } + + # With large max distance, should find both + neighbors_large = get_directional_neighbors( + origin_loc="origin", + origin_coord=coords["origin"], + all_coords=coords, + direction="N", + max_distance_km=1000 + ) + assert len(neighbors_large) == 2 + + # With small max distance, should only find close one + neighbors_small = get_directional_neighbors( + origin_loc="origin", + origin_coord=coords["origin"], + all_coords=coords, + direction="N", + max_distance_km=100 + ) + assert len(neighbors_small) == 1 + assert neighbors_small[0][0] == "close" + + +def 
test_get_directional_neighbors_sorted_by_distance(): + """Test that neighbors are sorted by distance (nearest first).""" + coords = { + "origin": (40.0, -75.0), + "close": (40.5, -75.0), + "medium": (41.0, -75.0), + "far": (42.0, -75.0), + } + + neighbors = get_directional_neighbors( + origin_loc="origin", + origin_coord=coords["origin"], + all_coords=coords, + direction="N", + max_distance_km=1000 + ) + + # Extract distances + distances = [dist for _, dist in neighbors] + + # Should be sorted in ascending order + assert distances == sorted(distances) + + +def test_get_directional_neighbors_invalid_direction(): + """Test that invalid direction raises ValueError.""" + coords = {"origin": (40.0, -75.0)} + + with pytest.raises(ValueError, match="Invalid direction"): + get_directional_neighbors( + origin_loc="origin", + origin_coord=coords["origin"], + all_coords=coords, + direction="INVALID", + max_distance_km=1000 + ) + + +def test_validate_wave_directions_valid(): + """Test validation with valid directions.""" + # Should not raise + validate_wave_directions(["N", "S", "E", "W"]) + validate_wave_directions(["N", "NE", "E", "SE", "S", "SW", "W", "NW"]) + validate_wave_directions(["NE", "SW"]) + + +def test_validate_wave_directions_invalid(): + """Test validation with invalid directions.""" + with pytest.raises(ValueError, match="Invalid direction"): + validate_wave_directions(["N", "INVALID", "S"]) + + +def test_validate_wave_directions_opposite_warning(): + """Test that opposite directions trigger a warning.""" + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + validate_wave_directions(["N", "S"]) + + # Should have generated a warning about opposite directions + assert len(w) >= 1 + assert "opposite" in str(w[0].message).lower() + + +def test_direction_angles_coverage(): + """Test that all 8 directions are defined.""" + expected_directions = {"N", "NE", "E", "SE", "S", "SW", "W", "NW"} + assert set(DIRECTION_ANGLES.keys()) == 
expected_directions + + +def test_direction_angles_values(): + """Test that direction angles are correct.""" + assert DIRECTION_ANGLES["N"] == 0 + assert DIRECTION_ANGLES["NE"] == 45 + assert DIRECTION_ANGLES["E"] == 90 + assert DIRECTION_ANGLES["SE"] == 135 + assert DIRECTION_ANGLES["S"] == 180 + assert DIRECTION_ANGLES["SW"] == 225 + assert DIRECTION_ANGLES["W"] == 270 + assert DIRECTION_ANGLES["NW"] == 315 + + +def test_get_directional_neighbors_with_real_states(): + """Test directional neighbors using real state centroids.""" + coords = get_location_centroids("state") + + # Pennsylvania (42) should have neighbors in all directions + pa_coord = coords["42"] + + # Check NE direction - should include NY (36) + # NY is actually northeast of PA (bearing ~38°), not due north + ne_neighbors = get_directional_neighbors( + origin_loc="42", + origin_coord=pa_coord, + all_coords=coords, + direction="NE", + max_distance_km=500 + ) + + ne_locs = [loc for loc, _ in ne_neighbors] + assert "36" in ne_locs # New York is northeast of Pennsylvania + + # Check South direction - should include MD (24) or WV (54) + south_neighbors = get_directional_neighbors( + origin_loc="42", + origin_coord=pa_coord, + all_coords=coords, + direction="S", + max_distance_km=500 + ) + + south_locs = [loc for loc, _ in south_neighbors] + # At least one southern neighbor + assert len(south_locs) > 0 + # NY should not be in south neighbors + assert "36" not in south_locs diff --git a/uv.lock b/uv.lock index 5a13c7b..08be397 100644 --- a/uv.lock +++ b/uv.lock @@ -2,7 +2,8 @@ version = 1 revision = 3 requires-python = ">=3.11" resolution-markers = [ - "python_full_version >= '3.13'", + "python_full_version >= '3.14'", + "python_full_version == '3.13.*'", "python_full_version == '3.12.*'", "python_full_version < '3.12'", "python_version < '0'",