1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- import os
1615import uuid
1716from pathlib import Path
1817
1918import pandas as pd
19+ import numpy as np
2020
21- from mostlyai .qa .report_from_statistics import report_from_statistics
22- from mostlyai .qa .report import report
21+ from mostlyai import qa
2322
24- baseball_path = Path (os .path .realpath (__file__ )).parent / "fixtures" / "baseball"
25- census_path = Path (os .path .realpath (__file__ )).parent / "fixtures" / "census"
23+
24+ def mock_data (n ):
25+ df = pd .DataFrame (
26+ {
27+ "int" : pd .Series (np .random .choice (list (range (10 )) + [np .nan ], n ), dtype = "Int64[pyarrow]" ),
28+ "float" : pd .Series (
29+ np .random .choice (list (np .random .uniform (size = 10 )) + [np .nan ], n ), dtype = "float64[pyarrow]"
30+ ),
31+ "cat" : pd .Series (np .random .choice (["f" , "m" , np .nan ], n ), dtype = "string[pyarrow]" ),
32+ "bool" : pd .Series (np .random .choice ([True , False , np .nan ], n ), dtype = "bool[pyarrow]" ),
33+ "date" : pd .Series (np .random .choice ([np .datetime64 ("today" , "D" ), np .nan ], n ), dtype = "datetime64[ns]" ),
34+ "text" : pd .Series ([str (uuid .uuid4 ())[:4 ] for _ in range (n )], dtype = "object" ),
35+ }
36+ )
37+ return df
2638
2739
2840def test_report_flat (tmp_path ):
2941 statistics_path = tmp_path / "statistics"
30- syn_tgt_data = pd . read_parquet ( census_path / "census-synthetic.parquet" )
31- trn_tgt_data = pd . read_parquet ( census_path / "census-training.parquet" )
32- hol_tgt_data = pd . read_parquet ( census_path / "census-holdout.parquet" )
33- report_path , metrics = report (
42+ syn_tgt_data = mock_data ( 220 )
43+ trn_tgt_data = mock_data ( 180 )
44+ hol_tgt_data = mock_data ( 140 )
45+ report_path , metrics = qa . report (
3446 syn_tgt_data = syn_tgt_data ,
3547 trn_tgt_data = trn_tgt_data ,
3648 hol_tgt_data = hol_tgt_data ,
3749 statistics_path = statistics_path ,
38- max_sample_size_accuracy = 2000 ,
39- max_sample_size_embeddings = 200 ,
50+ max_sample_size_accuracy = 120 ,
51+ max_sample_size_embeddings = 80 ,
4052 )
4153
4254 assert report_path .exists ()
4355
4456 accuracy = metrics .accuracy
45- assert 0.5 <= accuracy .overall <= 1.0
46- assert 0.5 <= accuracy .univariate <= 1.0
47- assert 0.5 <= accuracy .bivariate <= 1.0
57+ assert 0.3 <= accuracy .overall <= 1.0
58+ assert 0.3 <= accuracy .univariate <= 1.0
59+ assert 0.3 <= accuracy .bivariate <= 1.0
4860 assert accuracy .coherence is None
49- assert 0.8 <= accuracy .overall_max <= 1.0
50- assert 0.8 <= accuracy .univariate_max <= 1.0
51- assert 0.8 <= accuracy .bivariate_max <= 1.0
61+ assert 0.3 <= accuracy .overall_max <= 1.0
62+ assert 0.3 <= accuracy .univariate_max <= 1.0
63+ assert 0.3 <= accuracy .bivariate_max <= 1.0
5264
5365 similarity = metrics .similarity
5466 assert 0.8 <= similarity .cosine_similarity_training_synthetic <= 1.0
@@ -63,11 +75,11 @@ def test_report_flat(tmp_path):
6375 assert 0 <= distances .dcr_holdout <= 1.0
6476 assert 0 <= distances .dcr_share <= 1.0
6577
66- report_path = report_from_statistics (
78+ report_path = qa . report_from_statistics (
6779 syn_tgt_data = syn_tgt_data ,
6880 statistics_path = statistics_path ,
69- max_sample_size_accuracy = 2000 ,
70- max_sample_size_embeddings = 200 ,
81+ max_sample_size_accuracy = 110 ,
82+ max_sample_size_embeddings = 70 ,
7183 )
7284
7385 assert report_path .exists ()
@@ -77,44 +89,51 @@ def test_report_sequential(tmp_path):
7789 statistics_path = tmp_path / "statistics"
7890 report_path = Path (tmp_path / "my-report.html" )
7991
80- trn_tgt_data = pd .read_parquet (baseball_path / "seasons-training.parquet" )
81- hol_tgt_data = pd .read_parquet (baseball_path / "seasons-holdout.parquet" )
82- syn_tgt_data = pd .read_parquet (baseball_path / "seasons-synthetic.parquet" )
83- trn_ctx_data = pd .read_parquet (baseball_path / "players-training.parquet" )
84- hol_ctx_data = pd .read_parquet (baseball_path / "players-holdout.parquet" )
85- syn_ctx_data = pd .concat ([trn_ctx_data , hol_ctx_data ])
86-
87- report_path , metrics = report (
92+ # generate mock context data
93+ syn_ctx_data = mock_data (220 ).reset_index (names = "id" )
94+ trn_ctx_data = mock_data (180 ).reset_index (names = "id" )
95+ hol_ctx_data = mock_data (140 ).reset_index (names = "id" )
96+
97+ # generate mock sequential target data
98+ syn_tgt_data = mock_data (220 * 3 )
99+ syn_tgt_data ["ctx_id" ] = np .random .choice (syn_ctx_data ["id" ], 220 * 3 )
100+ trn_tgt_data = mock_data (180 * 4 )
101+ trn_tgt_data ["ctx_id" ] = np .random .choice (trn_ctx_data ["id" ], 180 * 4 )
102+ hol_tgt_data = mock_data (140 * 4 )
103+ hol_tgt_data ["ctx_id" ] = np .random .choice (hol_ctx_data ["id" ], 140 * 4 )
104+
105+ # generate report
106+ report_path , metrics = qa .report (
88107 syn_tgt_data = syn_tgt_data ,
89108 trn_tgt_data = trn_tgt_data ,
90109 hol_tgt_data = hol_tgt_data ,
91110 syn_ctx_data = syn_ctx_data ,
92111 trn_ctx_data = trn_ctx_data ,
93112 hol_ctx_data = hol_ctx_data ,
94113 ctx_primary_key = "id" ,
95- tgt_context_key = "players_id " ,
114+ tgt_context_key = "ctx_id " ,
96115 report_path = report_path ,
97116 statistics_path = statistics_path ,
98- max_sample_size_accuracy = 2000 ,
99- max_sample_size_embeddings = 200 ,
117+ max_sample_size_accuracy = 120 ,
118+ max_sample_size_embeddings = 80 ,
100119 )
101120
102121 assert report_path .exists ()
103122
104123 accuracy = metrics .accuracy
105- assert 0.8 <= accuracy .overall <= 1.0
106- assert 0.8 <= accuracy .univariate <= 1.0
107- assert 0.8 <= accuracy .bivariate <= 1.0
108- assert 0.8 <= accuracy .coherence <= 1.0
109- assert 0.8 <= accuracy .overall_max <= 1.0
110- assert 0.8 <= accuracy .univariate_max <= 1.0
111- assert 0.8 <= accuracy .bivariate_max <= 1.0
112- assert 0.8 <= accuracy .coherence_max <= 1.0
124+ assert 0.3 <= accuracy .overall <= 1.0
125+ assert 0.3 <= accuracy .univariate <= 1.0
126+ assert 0.3 <= accuracy .bivariate <= 1.0
127+ assert 0.3 <= accuracy .coherence <= 1.0
128+ assert 0.3 <= accuracy .overall_max <= 1.0
129+ assert 0.3 <= accuracy .univariate_max <= 1.0
130+ assert 0.3 <= accuracy .bivariate_max <= 1.0
131+ assert 0.3 <= accuracy .coherence_max <= 1.0
113132
114133 similarity = metrics .similarity
115- assert 0.8 <= similarity .cosine_similarity_training_synthetic <= 1.0
134+ assert 0.3 <= similarity .cosine_similarity_training_synthetic <= 1.0
116135 assert 0.0 <= similarity .discriminator_auc_training_synthetic <= 1.0
117- assert 0.8 <= similarity .cosine_similarity_training_holdout <= 1.0
136+ assert 0.3 <= similarity .cosine_similarity_training_holdout <= 1.0
118137 assert 0.0 <= similarity .discriminator_auc_training_holdout <= 1.0
119138
120139 distances = metrics .distances
@@ -124,13 +143,13 @@ def test_report_sequential(tmp_path):
124143 assert 0 <= distances .dcr_holdout <= 1.0
125144 assert 0 <= distances .dcr_share <= 1.0
126145
127- report_path = report_from_statistics (
146+ report_path = qa . report_from_statistics (
128147 syn_tgt_data = syn_tgt_data ,
129148 syn_ctx_data = syn_ctx_data ,
130149 ctx_primary_key = "id" ,
131- tgt_context_key = "players_id " ,
132- max_sample_size_accuracy = 3000 ,
133- max_sample_size_embeddings = 300 ,
150+ tgt_context_key = "ctx_id " ,
151+ max_sample_size_accuracy = 130 ,
152+ max_sample_size_embeddings = 90 ,
134153 statistics_path = statistics_path ,
135154 )
136155
@@ -144,7 +163,7 @@ def test_report_flat_rare(tmp_path):
144163 syn_tgt_data = pd .DataFrame ({"x" : ["_RARE_" for _ in range (100 )]})
145164 trn_tgt_data = pd .DataFrame ({"x" : [str (uuid .uuid4 ()) for _ in range (100 )]})
146165 hol_tgt_data = pd .DataFrame ({"x" : [str (uuid .uuid4 ()) for _ in range (100 )]})
147- _ , metrics = report (
166+ _ , metrics = qa . report (
148167 syn_tgt_data = syn_tgt_data ,
149168 trn_tgt_data = trn_tgt_data ,
150169 hol_tgt_data = hol_tgt_data ,
@@ -155,7 +174,7 @@ def test_report_flat_rare(tmp_path):
155174
156175 # test case where rare values are not protected, and we leak trn into synthetic
157176 syn_tgt_data = pd .DataFrame ({"x" : trn_tgt_data ["x" ].sample (100 , replace = True )})
158- _ , metrics = report (
177+ _ , metrics = qa . report (
159178 syn_tgt_data = syn_tgt_data ,
160179 trn_tgt_data = trn_tgt_data ,
161180 hol_tgt_data = hol_tgt_data ,
@@ -168,7 +187,7 @@ def test_report_flat_rare(tmp_path):
168187def test_report_flat_early_exit (tmp_path ):
169188 # test early exit for dfs with <100 rows
170189 df = pd .DataFrame ({"col" : list (range (99 ))})
171- _ , metrics = report (syn_tgt_data = df , trn_tgt_data = df , hol_tgt_data = df )
190+ _ , metrics = qa . report (syn_tgt_data = df , trn_tgt_data = df , hol_tgt_data = df )
172191 assert metrics is None
173192
174193
@@ -196,7 +215,7 @@ def make_dfs(
196215 syn_ctx_data = trn_ctx_data = hol_ctx_data = ctx_df
197216 syn_tgt_data = trn_tgt_data = hol_tgt_data = tgt_df
198217 early_term = df_dict .pop ("early_term" )
199- _ , metrics = report (
218+ _ , metrics = qa . report (
200219 syn_tgt_data = syn_tgt_data ,
201220 trn_tgt_data = trn_tgt_data ,
202221 hol_tgt_data = hol_tgt_data ,
@@ -211,7 +230,7 @@ def make_dfs(
211230
212231def test_report_few_holdout_records (tmp_path ):
213232 tgt = pd .DataFrame ({"id" : list (range (100 )), "col" : ["a" ] * 100 })
214- _ , metrics = report (
233+ _ , metrics = qa . report (
215234 syn_tgt_data = tgt ,
216235 trn_tgt_data = tgt ,
217236 hol_tgt_data = tgt [:10 ],
@@ -223,7 +242,7 @@ def test_report_sequential_few_records(tmp_path):
223242 # ensure that we don't crash in case of dominant zero-seq-length
224243 ctx = pd .DataFrame ({"id" : list (range (1000 ))})
225244 tgt = pd .DataFrame ({"id" : [1 , 2 , 3 , 4 , 5 ] * 100 , "col" : ["a" ] * 500 })
226- _ , metrics = report (
245+ _ , metrics = qa . report (
227246 syn_tgt_data = tgt ,
228247 trn_tgt_data = tgt ,
229248 hol_tgt_data = tgt ,
@@ -245,14 +264,30 @@ def test_odd_column_names(tmp_path):
245264 "3" : values ,
246265 }
247266 )
248- path , metrics = report (
267+ path , metrics = qa . report (
249268 syn_tgt_data = df ,
250269 trn_tgt_data = df ,
251270 statistics_path = tmp_path / "stats" ,
252271 )
253272 assert metrics is not None
254- path = report_from_statistics (
273+ path = qa . report_from_statistics (
255274 syn_tgt_data = df ,
256275 statistics_path = tmp_path / "stats" ,
257276 )
258277 assert path is not None
278+
279+
280+ def test_missing (tmp_path ):
281+ df1 = mock_data (100 )
282+ df2 = df1 .copy ()
283+ df2 .loc [:, :] = np .nan
284+ _ , metrics = qa .report (
285+ syn_tgt_data = df1 ,
286+ trn_tgt_data = df2 ,
287+ )
288+ assert metrics is not None
289+ _ , metrics = qa .report (
290+ syn_tgt_data = df2 ,
291+ trn_tgt_data = df1 ,
292+ )
293+ assert metrics is not None
0 commit comments