Skip to content

Commit f771cdd

Browse files
committed
fix: add setter/getter for stratigraphic column
ensures it can never be a dict
1 parent c50afdb commit f771cdd

File tree

1 file changed

+37
-22
lines changed

1 file changed

+37
-22
lines changed

LoopStructural/modelling/core/geological_model.py

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import numpy as np
88
import pandas as pd
9-
from typing import List, Optional
9+
from typing import List, Optional, Union, Dict
1010
import pathlib
1111
from ...modelling.features.fault import FaultSegment
1212

@@ -123,8 +123,7 @@ def __init__(self, *args):
123123
self.feature_name_index = {}
124124
self._data = pd.DataFrame() # None
125125

126-
self.stratigraphic_column = StratigraphicColumn()
127-
126+
self._stratigraphic_column = StratigraphicColumn()
128127

129128
self.tol = 1e-10 * np.max(self.bounding_box.maximum - self.bounding_box.origin)
130129
self._dtm = None
@@ -187,7 +186,6 @@ def prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
187186
].astype(float)
188187
return data
189188

190-
191189
if "type" in data:
192190
logger.warning("'type' is deprecated replace with 'feature_name' \n")
193191
data.rename(columns={"type": "feature_name"}, inplace=True)
@@ -409,7 +407,6 @@ def fault_names(self):
409407
"""
410408
return [f.name for f in self.faults]
411409

412-
413410
def to_file(self, file):
414411
"""Save a model to a pickle file requires dill
415412
@@ -506,10 +503,34 @@ def data(self, data: pd.DataFrame):
506503
self._data = data.copy()
507504
# self._data[['X','Y','Z']] = self.bounding_box.project(self._data[['X','Y','Z']].to_numpy())
508505

509-
510506
def set_model_data(self, data):
511507
logger.warning("deprecated method. Model data can now be set using the data attribute")
512508
self.data = data.copy()
509+
@property
510+
def stratigraphic_column(self):
511+
"""Get the stratigraphic column of the model
512+
513+
Returns
514+
-------
515+
StratigraphicColumn
516+
the stratigraphic column of the model
517+
"""
518+
return self._stratigraphic_column
519+
@stratigraphic_column.setter
520+
def stratigraphic_column(self, stratigraphic_column: Union[StratigraphicColumn,Dict]):
521+
"""Set the stratigraphic column of the model
522+
523+
Parameters
524+
----------
525+
stratigraphic_column : StratigraphicColumn
526+
the stratigraphic column to set
527+
"""
528+
if isinstance(stratigraphic_column, dict):
529+
self.set_stratigraphic_column(stratigraphic_column)
530+
return
531+
elif not isinstance(stratigraphic_column, StratigraphicColumn):
532+
raise ValueError("stratigraphic_column must be a StratigraphicColumn object")
533+
self._stratigraphic_column = stratigraphic_column
513534

514535
def set_stratigraphic_column(self, stratigraphic_column, cmap="tab20"):
515536
"""
@@ -1400,7 +1421,6 @@ def rescale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray:
14001421

14011422
return self.bounding_box.reproject(points, inplace=inplace)
14021423

1403-
14041424
# TODO move scale to bounding box/transformer
14051425
def scale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray:
14061426
"""Take points in UTM coordinates and reproject
@@ -1419,7 +1439,6 @@ def scale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray:
14191439
"""
14201440
return self.bounding_box.project(np.array(points).astype(float), inplace=inplace)
14211441

1422-
14231442
def regular_grid(self, *, nsteps=None, shuffle=True, rescale=False, order="C"):
14241443
"""
14251444
Return a regular grid within the model bounding box
@@ -1494,22 +1513,18 @@ def evaluate_model(self, xyz: np.ndarray, *, scale: bool = True) -> np.ndarray:
14941513
if self.stratigraphic_column is None:
14951514
logger.warning("No stratigraphic column defined")
14961515
return strat_id
1497-
for group in reversed(self.stratigraphic_column.keys()):
1498-
if group == "faults":
1499-
continue
1500-
feature_id = self.feature_name_index.get(group, -1)
1516+
1517+
s_id = 0
1518+
for g in reversed(self.stratigraphic_column.get_groups()):
1519+
feature_id = self.feature_name_index.get(g.name, -1)
15011520
if feature_id >= 0:
1502-
feature = self.features[feature_id]
1503-
vals = feature.evaluate_value(xyz)
1504-
for series in self.stratigraphic_column[group].values():
1505-
strat_id[
1506-
np.logical_and(
1507-
vals < series.get("max", feature.max()),
1508-
vals > series.get("min", feature.min()),
1509-
)
1510-
] = series["id"]
1521+
vals = self.features[feature_id].evaluate_value(xyz)
1522+
for u in g.units:
1523+
strat_id[np.logical_and(vals < u.max, vals > u.min)] = s_id
1524+
s_id += 1
15111525
if feature_id == -1:
1512-
logger.error(f"Model does not contain {group}")
1526+
logger.error(f"Model does not contain {g.name}")
1527+
15131528
return strat_id
15141529

15151530
def evaluate_model_gradient(self, points: np.ndarray, *, scale: bool = True) -> np.ndarray:

0 commit comments

Comments
 (0)