Skip to content

Commit 08c5df0

Browse files
authored
test: replace fixtures with dynamic mock data
1 parent d18276b commit 08c5df0

File tree

9 files changed

+88
-53
lines changed

9 files changed

+88
-53
lines changed
-68.8 KB
Binary file not shown.
-516 KB
Binary file not shown.
-181 KB
Binary file not shown.
-2.94 MB
Binary file not shown.
-1.31 MB
Binary file not shown.
-21.3 KB
Binary file not shown.
-120 KB
Binary file not shown.
-90.8 KB
Binary file not shown.

tests/end_to_end/test_report.py

Lines changed: 88 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -12,43 +12,55 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import os
1615
import uuid
1716
from pathlib import Path
1817

1918
import pandas as pd
19+
import numpy as np
2020

21-
from mostlyai.qa.report_from_statistics import report_from_statistics
22-
from mostlyai.qa.report import report
21+
from mostlyai import qa
2322

24-
baseball_path = Path(os.path.realpath(__file__)).parent / "fixtures" / "baseball"
25-
census_path = Path(os.path.realpath(__file__)).parent / "fixtures" / "census"
23+
24+
def mock_data(n):
    """Return a DataFrame with n rows covering the column types the QA report exercises.

    Columns: nullable int, nullable float, categorical string, nullable bool,
    date, and short free-text. The first four use pyarrow-backed dtypes; `date`
    is numpy datetime64 and `text` is plain object.

    Values are drawn via the global NumPy RNG, so output is nondeterministic
    unless the caller seeds np.random first.
    """
    df = pd.DataFrame(
        {
            # ints 0..9 plus np.nan as a missing marker; Int64[pyarrow] keeps NAs nullable
            "int": pd.Series(np.random.choice(list(range(10)) + [np.nan], n), dtype="Int64[pyarrow]"),
            # 10 fixed-at-call-time uniform floats plus np.nan, sampled per row
            "float": pd.Series(
                np.random.choice(list(np.random.uniform(size=10)) + [np.nan], n), dtype="float64[pyarrow]"
            ),
            # two categories plus missing values
            "cat": pd.Series(np.random.choice(["f", "m", np.nan], n), dtype="string[pyarrow]"),
            "bool": pd.Series(np.random.choice([True, False, np.nan], n), dtype="bool[pyarrow]"),
            # each row is either today's date or NaT
            "date": pd.Series(np.random.choice([np.datetime64("today", "D"), np.nan], n), dtype="datetime64[ns]"),
            # 4-char UUID prefixes stand in for free text
            "text": pd.Series([str(uuid.uuid4())[:4] for _ in range(n)], dtype="object"),
        }
    )
    return df
2638

2739

2840
def test_report_flat(tmp_path):
2941
statistics_path = tmp_path / "statistics"
30-
syn_tgt_data = pd.read_parquet(census_path / "census-synthetic.parquet")
31-
trn_tgt_data = pd.read_parquet(census_path / "census-training.parquet")
32-
hol_tgt_data = pd.read_parquet(census_path / "census-holdout.parquet")
33-
report_path, metrics = report(
42+
syn_tgt_data = mock_data(220)
43+
trn_tgt_data = mock_data(180)
44+
hol_tgt_data = mock_data(140)
45+
report_path, metrics = qa.report(
3446
syn_tgt_data=syn_tgt_data,
3547
trn_tgt_data=trn_tgt_data,
3648
hol_tgt_data=hol_tgt_data,
3749
statistics_path=statistics_path,
38-
max_sample_size_accuracy=2000,
39-
max_sample_size_embeddings=200,
50+
max_sample_size_accuracy=120,
51+
max_sample_size_embeddings=80,
4052
)
4153

4254
assert report_path.exists()
4355

4456
accuracy = metrics.accuracy
45-
assert 0.5 <= accuracy.overall <= 1.0
46-
assert 0.5 <= accuracy.univariate <= 1.0
47-
assert 0.5 <= accuracy.bivariate <= 1.0
57+
assert 0.3 <= accuracy.overall <= 1.0
58+
assert 0.3 <= accuracy.univariate <= 1.0
59+
assert 0.3 <= accuracy.bivariate <= 1.0
4860
assert accuracy.coherence is None
49-
assert 0.8 <= accuracy.overall_max <= 1.0
50-
assert 0.8 <= accuracy.univariate_max <= 1.0
51-
assert 0.8 <= accuracy.bivariate_max <= 1.0
61+
assert 0.3 <= accuracy.overall_max <= 1.0
62+
assert 0.3 <= accuracy.univariate_max <= 1.0
63+
assert 0.3 <= accuracy.bivariate_max <= 1.0
5264

5365
similarity = metrics.similarity
5466
assert 0.8 <= similarity.cosine_similarity_training_synthetic <= 1.0
@@ -63,11 +75,11 @@ def test_report_flat(tmp_path):
6375
assert 0 <= distances.dcr_holdout <= 1.0
6476
assert 0 <= distances.dcr_share <= 1.0
6577

66-
report_path = report_from_statistics(
78+
report_path = qa.report_from_statistics(
6779
syn_tgt_data=syn_tgt_data,
6880
statistics_path=statistics_path,
69-
max_sample_size_accuracy=2000,
70-
max_sample_size_embeddings=200,
81+
max_sample_size_accuracy=110,
82+
max_sample_size_embeddings=70,
7183
)
7284

7385
assert report_path.exists()
@@ -77,44 +89,51 @@ def test_report_sequential(tmp_path):
7789
statistics_path = tmp_path / "statistics"
7890
report_path = Path(tmp_path / "my-report.html")
7991

80-
trn_tgt_data = pd.read_parquet(baseball_path / "seasons-training.parquet")
81-
hol_tgt_data = pd.read_parquet(baseball_path / "seasons-holdout.parquet")
82-
syn_tgt_data = pd.read_parquet(baseball_path / "seasons-synthetic.parquet")
83-
trn_ctx_data = pd.read_parquet(baseball_path / "players-training.parquet")
84-
hol_ctx_data = pd.read_parquet(baseball_path / "players-holdout.parquet")
85-
syn_ctx_data = pd.concat([trn_ctx_data, hol_ctx_data])
86-
87-
report_path, metrics = report(
92+
# generate mock context data
93+
syn_ctx_data = mock_data(220).reset_index(names="id")
94+
trn_ctx_data = mock_data(180).reset_index(names="id")
95+
hol_ctx_data = mock_data(140).reset_index(names="id")
96+
97+
# generate mock sequential target data
98+
syn_tgt_data = mock_data(220 * 3)
99+
syn_tgt_data["ctx_id"] = np.random.choice(syn_ctx_data["id"], 220 * 3)
100+
trn_tgt_data = mock_data(180 * 4)
101+
trn_tgt_data["ctx_id"] = np.random.choice(trn_ctx_data["id"], 180 * 4)
102+
hol_tgt_data = mock_data(140 * 4)
103+
hol_tgt_data["ctx_id"] = np.random.choice(hol_ctx_data["id"], 140 * 4)
104+
105+
# generate report
106+
report_path, metrics = qa.report(
88107
syn_tgt_data=syn_tgt_data,
89108
trn_tgt_data=trn_tgt_data,
90109
hol_tgt_data=hol_tgt_data,
91110
syn_ctx_data=syn_ctx_data,
92111
trn_ctx_data=trn_ctx_data,
93112
hol_ctx_data=hol_ctx_data,
94113
ctx_primary_key="id",
95-
tgt_context_key="players_id",
114+
tgt_context_key="ctx_id",
96115
report_path=report_path,
97116
statistics_path=statistics_path,
98-
max_sample_size_accuracy=2000,
99-
max_sample_size_embeddings=200,
117+
max_sample_size_accuracy=120,
118+
max_sample_size_embeddings=80,
100119
)
101120

102121
assert report_path.exists()
103122

104123
accuracy = metrics.accuracy
105-
assert 0.8 <= accuracy.overall <= 1.0
106-
assert 0.8 <= accuracy.univariate <= 1.0
107-
assert 0.8 <= accuracy.bivariate <= 1.0
108-
assert 0.8 <= accuracy.coherence <= 1.0
109-
assert 0.8 <= accuracy.overall_max <= 1.0
110-
assert 0.8 <= accuracy.univariate_max <= 1.0
111-
assert 0.8 <= accuracy.bivariate_max <= 1.0
112-
assert 0.8 <= accuracy.coherence_max <= 1.0
124+
assert 0.3 <= accuracy.overall <= 1.0
125+
assert 0.3 <= accuracy.univariate <= 1.0
126+
assert 0.3 <= accuracy.bivariate <= 1.0
127+
assert 0.3 <= accuracy.coherence <= 1.0
128+
assert 0.3 <= accuracy.overall_max <= 1.0
129+
assert 0.3 <= accuracy.univariate_max <= 1.0
130+
assert 0.3 <= accuracy.bivariate_max <= 1.0
131+
assert 0.3 <= accuracy.coherence_max <= 1.0
113132

114133
similarity = metrics.similarity
115-
assert 0.8 <= similarity.cosine_similarity_training_synthetic <= 1.0
134+
assert 0.3 <= similarity.cosine_similarity_training_synthetic <= 1.0
116135
assert 0.0 <= similarity.discriminator_auc_training_synthetic <= 1.0
117-
assert 0.8 <= similarity.cosine_similarity_training_holdout <= 1.0
136+
assert 0.3 <= similarity.cosine_similarity_training_holdout <= 1.0
118137
assert 0.0 <= similarity.discriminator_auc_training_holdout <= 1.0
119138

120139
distances = metrics.distances
@@ -124,13 +143,13 @@ def test_report_sequential(tmp_path):
124143
assert 0 <= distances.dcr_holdout <= 1.0
125144
assert 0 <= distances.dcr_share <= 1.0
126145

127-
report_path = report_from_statistics(
146+
report_path = qa.report_from_statistics(
128147
syn_tgt_data=syn_tgt_data,
129148
syn_ctx_data=syn_ctx_data,
130149
ctx_primary_key="id",
131-
tgt_context_key="players_id",
132-
max_sample_size_accuracy=3000,
133-
max_sample_size_embeddings=300,
150+
tgt_context_key="ctx_id",
151+
max_sample_size_accuracy=130,
152+
max_sample_size_embeddings=90,
134153
statistics_path=statistics_path,
135154
)
136155

@@ -144,7 +163,7 @@ def test_report_flat_rare(tmp_path):
144163
syn_tgt_data = pd.DataFrame({"x": ["_RARE_" for _ in range(100)]})
145164
trn_tgt_data = pd.DataFrame({"x": [str(uuid.uuid4()) for _ in range(100)]})
146165
hol_tgt_data = pd.DataFrame({"x": [str(uuid.uuid4()) for _ in range(100)]})
147-
_, metrics = report(
166+
_, metrics = qa.report(
148167
syn_tgt_data=syn_tgt_data,
149168
trn_tgt_data=trn_tgt_data,
150169
hol_tgt_data=hol_tgt_data,
@@ -155,7 +174,7 @@ def test_report_flat_rare(tmp_path):
155174

156175
# test case where rare values are not protected, and we leak trn into synthetic
157176
syn_tgt_data = pd.DataFrame({"x": trn_tgt_data["x"].sample(100, replace=True)})
158-
_, metrics = report(
177+
_, metrics = qa.report(
159178
syn_tgt_data=syn_tgt_data,
160179
trn_tgt_data=trn_tgt_data,
161180
hol_tgt_data=hol_tgt_data,
@@ -168,7 +187,7 @@ def test_report_flat_rare(tmp_path):
168187
def test_report_flat_early_exit(tmp_path):
    """Report must exit early (metrics is None) when inputs have fewer than 100 rows."""
    tiny = pd.DataFrame({"col": list(range(99))})
    _, metrics = qa.report(syn_tgt_data=tiny, trn_tgt_data=tiny, hol_tgt_data=tiny)
    assert metrics is None
173192

174193

@@ -196,7 +215,7 @@ def make_dfs(
196215
syn_ctx_data = trn_ctx_data = hol_ctx_data = ctx_df
197216
syn_tgt_data = trn_tgt_data = hol_tgt_data = tgt_df
198217
early_term = df_dict.pop("early_term")
199-
_, metrics = report(
218+
_, metrics = qa.report(
200219
syn_tgt_data=syn_tgt_data,
201220
trn_tgt_data=trn_tgt_data,
202221
hol_tgt_data=hol_tgt_data,
@@ -211,7 +230,7 @@ def make_dfs(
211230

212231
def test_report_few_holdout_records(tmp_path):
213232
tgt = pd.DataFrame({"id": list(range(100)), "col": ["a"] * 100})
214-
_, metrics = report(
233+
_, metrics = qa.report(
215234
syn_tgt_data=tgt,
216235
trn_tgt_data=tgt,
217236
hol_tgt_data=tgt[:10],
@@ -223,7 +242,7 @@ def test_report_sequential_few_records(tmp_path):
223242
# ensure that we don't crash in case of dominant zero-seq-length
224243
ctx = pd.DataFrame({"id": list(range(1000))})
225244
tgt = pd.DataFrame({"id": [1, 2, 3, 4, 5] * 100, "col": ["a"] * 500})
226-
_, metrics = report(
245+
_, metrics = qa.report(
227246
syn_tgt_data=tgt,
228247
trn_tgt_data=tgt,
229248
hol_tgt_data=tgt,
@@ -245,14 +264,30 @@ def test_odd_column_names(tmp_path):
245264
"3": values,
246265
}
247266
)
248-
path, metrics = report(
267+
path, metrics = qa.report(
249268
syn_tgt_data=df,
250269
trn_tgt_data=df,
251270
statistics_path=tmp_path / "stats",
252271
)
253272
assert metrics is not None
254-
path = report_from_statistics(
273+
path = qa.report_from_statistics(
255274
syn_tgt_data=df,
256275
statistics_path=tmp_path / "stats",
257276
)
258277
assert path is not None
278+
279+
280+
def test_missing(tmp_path):
    """Report must produce metrics even when one side of the comparison is all-missing."""
    observed = mock_data(100)
    # clone the frame, then blank out every cell so only NAs remain
    all_missing = observed.copy()
    all_missing.loc[:, :] = np.nan

    # all-missing training data
    _, metrics = qa.report(
        syn_tgt_data=observed,
        trn_tgt_data=all_missing,
    )
    assert metrics is not None

    # all-missing synthetic data
    _, metrics = qa.report(
        syn_tgt_data=all_missing,
        trn_tgt_data=observed,
    )
    assert metrics is not None

0 commit comments

Comments
 (0)