35 changes: 35 additions & 0 deletions .github/workflows/build.yml
@@ -0,0 +1,35 @@
+name: build orcai
+
+on:
+  push:
+    branches: [ "main" ]
+  pull_request:
+    branches: [ "main" ]
+
+jobs:
+  uv-example:
+    name: python
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v6
+
+      - name: "Set up Python"
+        uses: actions/setup-python@v6
+        with:
+          python-version-file: "pyproject.toml"
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v8.0.0
+        with:
+          enable-cache: true
+          version: "0.11.3"
+
+      - name: Install the project
+        run: uv sync --locked --all-extras --dev
+
+      - name: Lint
+        run: uv run ruff check
+
+      - name: Run tests
+        run: uv run pytest
8 changes: 8 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,12 @@
# Changelog

+## [3.0.1]
+
+### Added
+
+- Improved reporting of TensorFlow devices
+- CI build workflow
+
## [3.0.0]

### Added
@@ -743,4 +750,5 @@
[2.2.0]:https://github.com/ethz-tb/orcAI/releases/tag/v2.2.0
[2.3.0]:https://github.com/ethz-tb/orcAI/releases/tag/v2.3.0
[3.0.0]:https://github.com/ethz-tb/orcAI/releases/tag/v3.0.0
+[3.0.1]:https://github.com/ethz-tb/orcAI/releases/tag/v3.0.1

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "uv_build"

[project]
name = "orcai"
version = "3.0.0"
version = "3.0.1"
authors = [
{ name = "Chérine Baumgartner", email = "cherine.baumgartner@env.ethz.ch" },
{ name = "Sebastian Bonhoeffer", email = "sebastian.bonhoeffer@env.ethz.ch" },
21 changes: 12 additions & 9 deletions src/orcai/auxiliary.py
@@ -326,16 +326,19 @@ def print_tf_device_info(
        )  # suppress tensorflow logging (ERROR and worse only)

        physical_devices = tf.config.list_physical_devices("GPU")
-        devices_info = [
-            tf.config.experimental.get_device_details(i) for i in physical_devices
-        ]
-
-        devices_string = ", ".join(
-            [
-                f"{dev.name.replace('physical_device:', '')}: {info['device_name']}"
-                for dev, info in zip(physical_devices, devices_info)
-            ]
-        )
+
+        if len(physical_devices) == 0:
+            devices_string = "No GPU devices found. Using CPU."
+        else:
+            devices_info = [
+                tf.config.experimental.get_device_details(i) for i in physical_devices
+            ]
+            devices_string = ", ".join(
+                [
+                    f"{dev.name.replace('physical_device:', '')}: {info['device_name']}"
+                    for dev, info in zip(physical_devices, devices_info)
+                ]
+            )

        self.info(
            f"Available TensorFlow devices: {devices_string}",
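For context, the guarded device report introduced above can be exercised standalone; a minimal sketch, with `print` standing in for the messenger's `self.info`:

```python
import tensorflow as tf

# Same pattern as the new branch: fall back to a fixed message on CPU-only hosts.
physical_devices = tf.config.list_physical_devices("GPU")
if len(physical_devices) == 0:
    devices_string = "No GPU devices found. Using CPU."
else:
    devices_info = [
        tf.config.experimental.get_device_details(d) for d in physical_devices
    ]
    devices_string = ", ".join(
        f"{dev.name.replace('physical_device:', '')}: {info['device_name']}"
        for dev, info in zip(physical_devices, devices_info)
    )
print(f"Available TensorFlow devices: {devices_string}")
```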
1 change: 1 addition & 0 deletions src/orcai/predict.py
@@ -309,6 +309,7 @@ def compute_aggregated_predictions(

    # Step 2: Model predictions for all snippets
    msgr.info("Prediction of snippets")
+    msgr.print_tf_device_info(severity=3)
    snippets = snippets[..., np.newaxis]  # Shape: (num_snippets, 736, 171, 1)
    predictions = model.predict(
        snippets,
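The context shows why the device report lands here: the snippet batch is reshaped and handed to `model.predict` immediately after. A small NumPy sketch of that reshape (the batch size of 3 is hypothetical; the 736×171 snippet shape comes from the inline comment above):

```python
import numpy as np

# Add a trailing channels axis so the batch matches the CNN's expected input.
snippets = np.zeros((3, 736, 171))
snippets = snippets[..., np.newaxis]
print(snippets.shape)  # (3, 736, 171, 1)
```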
9 changes: 8 additions & 1 deletion tests/test_json_encoder.py
@@ -85,6 +85,13 @@ def test_unsupported_type_raises(self):

    def test_standard_types_unchanged(self):
        """Standard JSON-serializable types pass through unchanged."""
-        data = {"int": 1, "float": 2.5, "str": "hello", "list": [1, 2], "bool": True, "none": None}
+        data = {
+            "int": 1,
+            "float": 2.5,
+            "str": "hello",
+            "list": [1, 2],
+            "bool": True,
+            "none": None,
+        }
        result = encode(data)
        assert result == data
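For reference, the invariant this test pins down can be checked against the standard library alone; a minimal sketch, independent of the project's `encode` helper:

```python
import json

# Standard JSON types survive a dumps/loads round trip unchanged.
data = {"int": 1, "float": 2.5, "str": "hello", "list": [1, 2], "bool": True, "none": None}
assert json.loads(json.dumps(data)) == data
```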
39 changes: 23 additions & 16 deletions tests/test_predict.py
@@ -115,8 +115,8 @@ def test_multiple_calls(self):
        probs = np.array([0.1, 0.5, 0.9, 0.3])
        result = _calulate_mean_probabilities(probs, [0, 2], [2, 4])
        assert len(result) == 2
-        assert result[0] == pytest.approx(0.3) # mean(probs[0:2]) = mean(0.1, 0.5)
-        assert result[1] == pytest.approx(0.6) # mean(0.9, 0.3)
+        assert result[0] == pytest.approx(0.3)  # mean(probs[0:2]) = mean(0.1, 0.5)
+        assert result[1] == pytest.approx(0.6)  # mean(0.9, 0.3)


# ---------------------------------------------------------------------------
@@ -178,14 +178,10 @@ def test_custom_threshold(self):
        preds = np.zeros((10, 1))
        preds[3:6, 0] = 0.6
        # threshold=0.7 → nothing detected
-        _, _, labels_high, _ = compute_binary_predictions(
-            preds, ["BR"], threshold=0.7
-        )
+        _, _, labels_high, _ = compute_binary_predictions(preds, ["BR"], threshold=0.7)
        assert len(labels_high) == 0
        # threshold=0.5 → detected
-        _, _, labels_low, _ = compute_binary_predictions(
-            preds, ["BR"], threshold=0.5
-        )
+        _, _, labels_low, _ = compute_binary_predictions(preds, ["BR"], threshold=0.5)
        assert "BR" in labels_low

@@ -289,7 +285,9 @@ def test_different_channel(self, tmp_path):
class TestFilterPredictions:
    """Tests for filter_predictions."""

-    def test_keeps_all_within_limits(self, predicted_labels_df, call_duration_limits_dict):
+    def test_keeps_all_within_limits(
+        self, predicted_labels_df, call_duration_limits_dict
+    ):
        """All calls within limits are kept."""
        # predicted_labels_df has durations 5, 5, 5, 5 (stop-start), delta_t=1
        # BR limits [2,8], BUZZ limits [3,20], WHISTLE limits [1,10] → all kept
@@ -310,7 +308,9 @@ def test_removes_too_short(self, call_duration_limits_dict):
"mean_p": [0.9],
}
)
result = filter_predictions(df, delta_t=1.0, call_duration_limits=call_duration_limits_dict)
result = filter_predictions(
df, delta_t=1.0, call_duration_limits=call_duration_limits_dict
)
assert len(result) == 0

def test_removes_too_long(self, call_duration_limits_dict):
@@ -323,20 +323,28 @@ def test_removes_too_long(self, call_duration_limits_dict):
"mean_p": [0.9],
}
)
result = filter_predictions(df, delta_t=1.0, call_duration_limits=call_duration_limits_dict)
result = filter_predictions(
df, delta_t=1.0, call_duration_limits=call_duration_limits_dict
)
assert len(result) == 0

def test_empty_input_returns_empty(self, call_duration_limits_dict):
"""Empty input DataFrame is returned unchanged."""
df = pd.DataFrame(columns=["start", "stop", "label", "mean_p"])
result = filter_predictions(df, delta_t=1.0, call_duration_limits=call_duration_limits_dict)
result = filter_predictions(
df, delta_t=1.0, call_duration_limits=call_duration_limits_dict
)
assert result.empty

def test_output_columns_preserved(self, predicted_labels_df, call_duration_limits_dict):
def test_output_columns_preserved(
self, predicted_labels_df, call_duration_limits_dict
):
"""Output has the same columns as input (filter_predictions modifies df in-place)."""
original_cols = list(predicted_labels_df.columns)
result = filter_predictions(
predicted_labels_df, delta_t=1.0, call_duration_limits=call_duration_limits_dict
predicted_labels_df,
delta_t=1.0,
call_duration_limits=call_duration_limits_dict,
)
assert list(result.columns) == original_cols

Expand All @@ -362,8 +370,7 @@ class TestFilterPredictionsFile:
    def _write_predictions_file(self, path: Path, rows: list[tuple]) -> None:
        """Write a tab-separated predictions file."""
        lines = "\n".join(
-            f"{start}\t{stop}\t{label}\t{p}\tsource"
-            for start, stop, label, p in rows
+            f"{start}\t{stop}\t{label}\t{p}\tsource" for start, stop, label, p in rows
        )
        path.write_text(lines)

52 changes: 40 additions & 12 deletions tests/test_snippets.py
@@ -38,7 +38,9 @@ def test_total_equals_sum_of_types(self, snippet_table_df, label_calls):
"""Total column equals sum of train + val + test."""
stats = _compute_snippet_stats(snippet_table_df, for_calls=label_calls)
computed_total = stats[["train", "val", "test"]].sum(axis=1)
pd.testing.assert_series_equal(stats["total"], computed_total, check_names=False)
pd.testing.assert_series_equal(
stats["total"], computed_total, check_names=False
)

def test_rows_are_call_names(self, snippet_table_df, label_calls):
"""One row per call in for_calls."""
@@ -78,7 +80,9 @@ def test_output_columns_preserved(self, snippet_table_df, orcai_parameter_snippets):
        result = _filter_snippet_table(snippet_table_df, orcai_parameter_snippets)
        assert set(result.columns) == set(snippet_table_df.columns)

-    def test_fraction_removal_zero_keeps_all(self, snippet_table_df, orcai_parameter_snippets):
+    def test_fraction_removal_zero_keeps_all(
+        self, snippet_table_df, orcai_parameter_snippets
+    ):
        """fraction_removal=0 keeps all no-label snippets."""
        params = {**orcai_parameter_snippets}
        params["snippets"] = {**params["snippets"], "fraction_removal": 0.0}
@@ -91,13 +95,17 @@ def test_index_reset(self, snippet_table_df, orcai_parameter_snippets):
        assert list(result.index) == list(range(len(result)))

    @pytest.mark.parametrize("seed", [0, 42, 123])
-    def test_deterministic_with_same_seed(self, snippet_table_df, orcai_parameter_snippets, seed):
+    def test_deterministic_with_same_seed(
+        self, snippet_table_df, orcai_parameter_snippets, seed
+    ):
        """Same rng seed produces identical results."""
        rng1 = np.random.default_rng(seed)
        rng2 = np.random.default_rng(seed)
        r1 = _filter_snippet_table(snippet_table_df, orcai_parameter_snippets, rng=rng1)
        r2 = _filter_snippet_table(snippet_table_df, orcai_parameter_snippets, rng=rng2)
-        pd.testing.assert_frame_equal(r1.reset_index(drop=True), r2.reset_index(drop=True))
+        pd.testing.assert_frame_equal(
+            r1.reset_index(drop=True), r2.reset_index(drop=True)
+        )


# ---------------------------------------------------------------------------
@@ -142,11 +150,15 @@ def _build_recording_dir(
class TestMakeSnippetTable:
    """Tests for _make_snippet_table."""

-    def test_success_returns_dataframe(self, tmp_path, label_calls, orcai_parameter_snippets):
+    def test_success_returns_dataframe(
+        self, tmp_path, label_calls, orcai_parameter_snippets
+    ):
        """Returns a DataFrame when directory structure is complete."""
        rec_dir = tmp_path / "test_rec"
        _build_recording_dir(rec_dir, label_calls)
-        snippet_table, _, _, _, status = _make_snippet_table(rec_dir, orcai_parameter_snippets)
+        snippet_table, _, _, _, status = _make_snippet_table(
+            rec_dir, orcai_parameter_snippets
+        )
        assert status == "success"
        assert isinstance(snippet_table, pd.DataFrame)

@@ -155,10 +167,18 @@ def test_output_columns(self, tmp_path, label_calls, orcai_parameter_snippets):
rec_dir = tmp_path / "test_rec"
_build_recording_dir(rec_dir, label_calls)
snippet_table, *_ = _make_snippet_table(rec_dir, orcai_parameter_snippets)
for col in ["recording", "recording_data_dir", "data_type", "row_start", "row_stop"]:
for col in [
"recording",
"recording_data_dir",
"data_type",
"row_start",
"row_stop",
]:
assert col in snippet_table.columns

def test_missing_spectrogram_raises(self, tmp_path, label_calls, orcai_parameter_snippets):
def test_missing_spectrogram_raises(
self, tmp_path, label_calls, orcai_parameter_snippets
):
"""FileNotFoundError raised when times.json is missing."""
rec_dir = tmp_path / "no_spec"
rec_dir.mkdir()
@@ -167,7 +187,9 @@ def test_missing_spectrogram_raises(self, tmp_path, label_calls, orcai_parameter_snippets):
        with pytest.raises(FileNotFoundError):
            _make_snippet_table(rec_dir, orcai_parameter_snippets)

-    def test_missing_label_file_returns_none(self, tmp_path, label_calls, orcai_parameter_snippets):
+    def test_missing_label_file_returns_none(
+        self, tmp_path, label_calls, orcai_parameter_snippets
+    ):
        """Returns None snippet table when labels.zarr is missing."""
        rec_dir = tmp_path / "no_labels"
        rec_dir.mkdir()
@@ -176,15 +198,21 @@ def test_missing_label_file_returns_none(self, tmp_path, label_calls, orcai_parameter_snippets):
(spec_dir / "times.json").write_text(
json.dumps({"min": 0.0, "max": 500.0, "length": 500})
)
snippet_table, _, _, _, status = _make_snippet_table(rec_dir, orcai_parameter_snippets)
snippet_table, _, _, _, status = _make_snippet_table(
rec_dir, orcai_parameter_snippets
)
assert snippet_table is None
assert status == "missing label files"

def test_recording_too_short_returns_none(self, tmp_path, label_calls, orcai_parameter_snippets):
def test_recording_too_short_returns_none(
self, tmp_path, label_calls, orcai_parameter_snippets
):
"""Returns None when recording is shorter than segment_duration."""
rec_dir = tmp_path / "short_rec"
# Recording of 5s, segment_duration=10 → n_segments=0
_build_recording_dir(rec_dir, label_calls, n_time=50, recording_duration=5.0)
snippet_table, _, _, _, status = _make_snippet_table(rec_dir, orcai_parameter_snippets)
snippet_table, _, _, _, status = _make_snippet_table(
rec_dir, orcai_parameter_snippets
)
assert snippet_table is None
assert status == "shorter than segment_duration"
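The short-recording case comes down to integer division; a minimal sketch, assuming the segment count is computed as floor(recording_duration / segment_duration) (the exact formula inside `_make_snippet_table` is not shown in this diff):

```python
# Values taken from the test above; the formula is an assumption.
recording_duration = 5.0  # seconds
segment_duration = 10.0   # seconds
n_segments = int(recording_duration // segment_duration)
print(n_segments)  # 0 → no snippets, so the snippet table is None
```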