Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions deepethogram/feature_extractor/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,6 @@
"and test dataloaders.",
)

plt.switch_backend("agg")

log = logging.getLogger(__name__)


Expand All @@ -60,6 +58,7 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module:
nn.Module
Trained feature extractor
"""
plt.switch_backend("agg")
cfg = projects.setup_run(cfg)

log.info("args: {}".format(" ".join(sys.argv)))
Expand Down
3 changes: 1 addition & 2 deletions deepethogram/flow_generator/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@

flow_generators = utils.get_models_from_module(models, get_function=False)

plt.switch_backend("agg")

log = logging.getLogger(__name__)


Expand All @@ -50,6 +48,7 @@ def flow_generator_train(cfg: DictConfig) -> nn.Module:
nn.Module
Trained flow generator
"""
plt.switch_backend("agg")
cfg = projects.setup_run(cfg)
log.info("args: {}".format(" ".join(sys.argv)))
# only two custom overwrites of the configuration file
Expand Down
7 changes: 6 additions & 1 deletion deepethogram/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,12 @@ def add_label_to_project(path_to_labels: Union[str, os.PathLike], path_to_video)
if os.path.isfile(label_dst):
warnings.warn("Label already exists in destination {}, overwriting...".format(label_dst))

df = pd.read_csv(path_to_labels, index_col=0)
df = pd.read_csv(path_to_labels)
# Drop unnamed index column if present (DEG-generated CSVs have one)
first_col = df.columns[0]
if first_col == "" or str(first_col).startswith("Unnamed"):
Comment thread
jbohnslav marked this conversation as resolved.
df = df.drop(columns=[first_col])

if "none" in list(df.columns):
df = df.rename(columns={"none": "background"})
if "background" not in list(df.columns):
Expand Down
3 changes: 1 addition & 2 deletions deepethogram/sequence/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@

log = logging.getLogger(__name__)

plt.switch_backend("agg")


def sequence_train(cfg: DictConfig) -> nn.Module:
"""Trains sequence models from a configuration.
Expand All @@ -35,6 +33,7 @@ def sequence_train(cfg: DictConfig) -> nn.Module:
nn.Module
Trained sequence model
"""
plt.switch_backend("agg")
cfg = projects.setup_run(cfg)
log.info("args: {}".format(" ".join(sys.argv)))

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "deepethogram"
version = "0.3.0"
version = "0.4.0"
description = "Temporal action detection for biology"
readme = "README.md"
authors = [
Expand Down
75 changes: 75 additions & 0 deletions tests/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,80 @@ def test_add_external_label():
projects.add_label_to_project(labelfile, videofile)


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_add_label_deg_style_csv(tmp_path):
"""Test add_label_to_project with DEG-generated CSV (has unnamed index column)."""
make_project_from_archive()
mousedir = os.path.join(project_path, "DATA", "mouse06")
videofile = os.path.join(mousedir, "mouse06.h5")

# Create a DEG-style CSV with unnamed numeric index
csv_path = tmp_path / "labels_with_index.csv"
csv_path.write_text(
",background,behavior1,behavior2\n"
"0,1,0,0\n"
"1,0,1,0\n"
"2,0,0,1\n"
)

result = projects.add_label_to_project(str(csv_path), videofile)
df = pd.read_csv(result, index_col=0)
assert "background" in df.columns
assert "behavior1" in df.columns
assert "behavior2" in df.columns
assert df.shape[1] == 3 # background + 2 behaviors


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_add_label_external_csv_no_index(tmp_path):
"""Test add_label_to_project with external CSV (no index column, no background).

Regression test for GitHub issue #116: the old code used index_col=0 which
silently ate the first data column when no index column was present.
"""
make_project_from_archive()
mousedir = os.path.join(project_path, "DATA", "mouse06")
videofile = os.path.join(mousedir, "mouse06.h5")

# Create a user-provided CSV without index or background column
csv_path = tmp_path / "labels_no_index.csv"
csv_path.write_text(
"behavior1,behavior2\n"
"0,0\n"
"1,0\n"
"0,1\n"
)

result = projects.add_label_to_project(str(csv_path), videofile)
df = pd.read_csv(result, index_col=0)
assert "background" in df.columns, "background column should be auto-inserted"
assert "behavior1" in df.columns, "behavior1 should NOT be eaten by index_col"
assert "behavior2" in df.columns
assert df.shape[1] == 3 # background + behavior1 + behavior2


@pytest.mark.filterwarnings("ignore::UserWarning")
def test_add_label_external_csv_with_background_no_index(tmp_path):
"""Test external CSV that has background but no index column."""
make_project_from_archive()
mousedir = os.path.join(project_path, "DATA", "mouse06")
videofile = os.path.join(mousedir, "mouse06.h5")

csv_path = tmp_path / "labels_bg_no_index.csv"
csv_path.write_text(
"background,behavior1,behavior2\n"
"1,0,0\n"
"0,1,0\n"
"0,0,1\n"
)

result = projects.add_label_to_project(str(csv_path), videofile)
df = pd.read_csv(result, index_col=0)
assert "background" in df.columns
assert "behavior1" in df.columns
assert "behavior2" in df.columns
assert df.shape[1] == 3


if __name__ == "__main__":
test_add_external_label()
2 changes: 1 addition & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading