Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
33 changes: 33 additions & 0 deletions tests/test_create_input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pandas as pd
import pytest

from src.create_input import conv_filename_to_condition, validate_inputs


def test_conv_filename_to_condition_match():
pattern = r"(?P<noise>[^_]+)_(?P<level>\d+)\.wav"
result = conv_filename_to_condition("white_10.wav", pattern)
assert list(result.items()) == [("level", "10"), ("noise", "white")]


def test_conv_filename_to_condition_no_match():
pattern = r"(?P<noise>[^_]+)_(?P<level>\d+)\.wav"
result = conv_filename_to_condition("unexpected.wav", pattern)
assert result == {"Unknown": "NoCondition"}


def test_validate_inputs_acr_minimal():
cfg = {"number_of_gold_clips_per_session": "0"}
df = pd.DataFrame(columns=[
"rating_clips", "math", "pair_a", "pair_b", "trapping_clips", "trapping_ans"
])
validate_inputs(cfg, df, "acr")


def test_validate_inputs_missing_column():
cfg = {"number_of_gold_clips_per_session": "0"}
df = pd.DataFrame(columns=[
"rating_clips", "math", "pair_a", "trapping_clips", "trapping_ans"
])
with pytest.raises(AssertionError):
validate_inputs(cfg, df, "acr")
12 changes: 12 additions & 0 deletions tests/test_result_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from src.result_parser import outliers_modified_z_score, outliers_z_score


def test_outliers_modified_z_score_removes_outlier():
data = [1, 1, 1, 100]
assert outliers_modified_z_score(data) == [1, 1, 1]


def test_outliers_z_score_threshold_high():
data = [10, 10, 10, 1000]
# With default threshold 3.29 the outlier is not removed
assert outliers_z_score(data) == data