@@ -63,6 +63,35 @@ def sklearn_model(train_data):
6363    return  model 
6464
6565
66+ @pytest .fixture  
67+ def  sklearn_pipeline (train_data ):
68+     from  sklearn .pipeline  import  Pipeline 
69+     from  sklearn .ensemble  import  GradientBoostingClassifier 
70+     from  sklearn .preprocessing  import  StandardScaler 
71+     from  sklearn .impute  import  SimpleImputer 
72+     from  sklearn .compose  import  ColumnTransformer 
73+ 
74+     X , y  =  train_data 
75+ 
76+     numeric_transformer  =  Pipeline ([
77+         ('imputer' , SimpleImputer (strategy = 'median' )),
78+         ('scaler' , StandardScaler ())
79+     ])
80+ 
81+     preprocessor  =  ColumnTransformer ([
82+         ('num' , numeric_transformer , X .columns )
83+     ])
84+ 
85+     pipe  =  Pipeline ([
86+         ('preprocess' , preprocessor ),
87+         ('classifier' , GradientBoostingClassifier ())
88+     ])
89+ 
90+     pipe .fit (X , y )
91+ 
92+     return  pipe 
93+ 
94+ 
6695@pytest .fixture  
6796def  pickle_file (tmpdir_factory , sklearn_model ):
6897    """Returns the path to a file containing a pickled Scikit-Learn model """ 
@@ -215,6 +244,17 @@ def test_from_python_file(python_file):
215244    assert  isinstance (p , PyMAS )
216245
217246
247+ def  test_with_sklearn_pipeline (train_data , sklearn_pipeline ):
248+     from  sasctl .utils .pymas  import  PyMAS , from_pickle 
249+ 
250+     X , y  =  train_data 
251+     p  =  from_pickle (pickle .dumps (sklearn_pipeline ),
252+                     func_name = 'predict' ,
253+                     input_types = X )
254+ 
255+     assert  isinstance (p , PyMAS )
256+     assert  len (p .variables ) >  4   # 4 input features in Iris data set 
257+ 
218258@pytest .mark .usefixtures ('session' ) 
219259def  test_publish_and_execute (tmpdir ):
220260    import  pickle 
0 commit comments