Skip to content

Commit 7de630d

Browse files
committed
add test
1 parent b54569b commit 7de630d

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#
2+
# Copyright (c) 2021, NVIDIA CORPORATION.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
#
16+
import pytest
17+
from testbook import testbook
18+
19+
from tests.conftest import REPO_ROOT
20+
21+
22+
@testbook(REPO_ROOT / "examples/pytorch/01-Getting-started.ipynb", execute=False)
23+
@pytest.mark.notebook
24+
def test_example_01_getting_started(tb):
25+
tb.inject(
26+
"""
27+
from unittest.mock import patch
28+
from merlin.datasets.synthetic import generate_data
29+
mock_train, mock_valid = generate_data(
30+
input="movielens-1m",
31+
num_rows=1000,
32+
set_sizes=(0.8, 0.2)
33+
)
34+
p1 = patch(
35+
"merlin.datasets.entertainment.get_movielens",
36+
return_value=[mock_train, mock_valid]
37+
)
38+
p1.start()
39+
"""
40+
)
41+
tb.execute()
42+
metrics = tb.ref("metrics")
43+
assert set(metrics[0].keys()) == set(
44+
[
45+
"val_loss",
46+
"val_binary_accuracy",
47+
"val_binary_auroc",
48+
"val_binary_precision",
49+
"val_binary_recall"
50+
]
51+
)

0 commit comments

Comments
 (0)