66
77import numpy as np
88import pandas as pd
9- from typing import List , Optional
9+ from typing import List , Optional , Union , Dict
1010import pathlib
1111from ...modelling .features .fault import FaultSegment
1212
@@ -123,8 +123,7 @@ def __init__(self, *args):
123123 self .feature_name_index = {}
124124 self ._data = pd .DataFrame () # None
125125
126- self .stratigraphic_column = StratigraphicColumn ()
127-
126+ self ._stratigraphic_column = StratigraphicColumn ()
128127
129128 self .tol = 1e-10 * np .max (self .bounding_box .maximum - self .bounding_box .origin )
130129 self ._dtm = None
@@ -187,7 +186,6 @@ def prepare_data(self, data: pd.DataFrame) -> pd.DataFrame:
187186 ].astype (float )
188187 return data
189188
190-
191189 if "type" in data :
192190 logger .warning ("'type' is deprecated replace with 'feature_name' \n " )
193191 data .rename (columns = {"type" : "feature_name" }, inplace = True )
@@ -409,7 +407,6 @@ def fault_names(self):
409407 """
410408 return [f .name for f in self .faults ]
411409
412-
413410 def to_file (self , file ):
414411 """Save a model to a pickle file requires dill
415412
@@ -506,10 +503,34 @@ def data(self, data: pd.DataFrame):
506503 self ._data = data .copy ()
507504 # self._data[['X','Y','Z']] = self.bounding_box.project(self._data[['X','Y','Z']].to_numpy())
508505
509-
510506 def set_model_data (self , data ):
511507 logger .warning ("deprecated method. Model data can now be set using the data attribute" )
512508 self .data = data .copy ()
509+ @property
510+ def stratigraphic_column (self ):
511+ """Get the stratigraphic column of the model
512+
513+ Returns
514+ -------
515+ StratigraphicColumn
516+ the stratigraphic column of the model
517+ """
518+ return self ._stratigraphic_column
519+ @stratigraphic_column .setter
520+ def stratigraphic_column (self , stratigraphic_column : Union [StratigraphicColumn ,Dict ]):
521+ """Set the stratigraphic column of the model
522+
523+ Parameters
524+ ----------
525+ stratigraphic_column : StratigraphicColumn
526+ the stratigraphic column to set
527+ """
528+ if isinstance (stratigraphic_column , dict ):
529+ self .set_stratigraphic_column (stratigraphic_column )
530+ return
531+ elif not isinstance (stratigraphic_column , StratigraphicColumn ):
532+ raise ValueError ("stratigraphic_column must be a StratigraphicColumn object" )
533+ self ._stratigraphic_column = stratigraphic_column
513534
514535 def set_stratigraphic_column (self , stratigraphic_column , cmap = "tab20" ):
515536 """
@@ -1400,7 +1421,6 @@ def rescale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray:
14001421
14011422 return self .bounding_box .reproject (points , inplace = inplace )
14021423
1403-
14041424 # TODO move scale to bounding box/transformer
14051425 def scale (self , points : np .ndarray , * , inplace : bool = False ) -> np .ndarray :
14061426 """Take points in UTM coordinates and reproject
@@ -1419,7 +1439,6 @@ def scale(self, points: np.ndarray, *, inplace: bool = False) -> np.ndarray:
14191439 """
14201440 return self .bounding_box .project (np .array (points ).astype (float ), inplace = inplace )
14211441
1422-
14231442 def regular_grid (self , * , nsteps = None , shuffle = True , rescale = False , order = "C" ):
14241443 """
14251444 Return a regular grid within the model bounding box
@@ -1494,22 +1513,18 @@ def evaluate_model(self, xyz: np.ndarray, *, scale: bool = True) -> np.ndarray:
14941513 if self .stratigraphic_column is None :
14951514 logger .warning ("No stratigraphic column defined" )
14961515 return strat_id
1497- for group in reversed ( self . stratigraphic_column . keys ()):
1498- if group == "faults" :
1499- continue
1500- feature_id = self .feature_name_index .get (group , - 1 )
1516+
1517+ s_id = 0
1518+ for g in reversed ( self . stratigraphic_column . get_groups ()):
1519+ feature_id = self .feature_name_index .get (g . name , - 1 )
15011520 if feature_id >= 0 :
1502- feature = self .features [feature_id ]
1503- vals = feature .evaluate_value (xyz )
1504- for series in self .stratigraphic_column [group ].values ():
1505- strat_id [
1506- np .logical_and (
1507- vals < series .get ("max" , feature .max ()),
1508- vals > series .get ("min" , feature .min ()),
1509- )
1510- ] = series ["id" ]
1521+ vals = self .features [feature_id ].evaluate_value (xyz )
1522+ for u in g .units :
1523+ strat_id [np .logical_and (vals < u .max , vals > u .min )] = s_id
1524+ s_id += 1
15111525 if feature_id == - 1 :
1512- logger .error (f"Model does not contain { group } " )
1526+ logger .error (f"Model does not contain { g .name } " )
1527+
15131528 return strat_id
15141529
15151530 def evaluate_model_gradient (self , points : np .ndarray , * , scale : bool = True ) -> np .ndarray :
0 commit comments