diff --git a/deepethogram/feature_extractor/train.py b/deepethogram/feature_extractor/train.py index b5e0d6f..14c4353 100644 --- a/deepethogram/feature_extractor/train.py +++ b/deepethogram/feature_extractor/train.py @@ -42,8 +42,6 @@ "and test dataloaders.", ) -plt.switch_backend("agg") - log = logging.getLogger(__name__) @@ -60,6 +58,7 @@ def feature_extractor_train(cfg: DictConfig) -> nn.Module: nn.Module Trained feature extractor """ + plt.switch_backend("agg") cfg = projects.setup_run(cfg) log.info("args: {}".format(" ".join(sys.argv))) diff --git a/deepethogram/flow_generator/train.py b/deepethogram/flow_generator/train.py index 609214e..4a46335 100644 --- a/deepethogram/flow_generator/train.py +++ b/deepethogram/flow_generator/train.py @@ -32,8 +32,6 @@ flow_generators = utils.get_models_from_module(models, get_function=False) -plt.switch_backend("agg") - log = logging.getLogger(__name__) @@ -50,6 +48,7 @@ def flow_generator_train(cfg: DictConfig) -> nn.Module: nn.Module Trained flow generator """ + plt.switch_backend("agg") cfg = projects.setup_run(cfg) log.info("args: {}".format(" ".join(sys.argv))) # only two custom overwrites of the configuration file diff --git a/deepethogram/projects.py b/deepethogram/projects.py index 9e54e04..43f864c 100644 --- a/deepethogram/projects.py +++ b/deepethogram/projects.py @@ -171,7 +171,12 @@ def add_label_to_project(path_to_labels: Union[str, os.PathLike], path_to_video) if os.path.isfile(label_dst): warnings.warn("Label already exists in destination {}, overwriting...".format(label_dst)) - df = pd.read_csv(path_to_labels, index_col=0) + df = pd.read_csv(path_to_labels) + # Drop unnamed index column if present (DEG-generated CSVs have one) + first_col = df.columns[0] + if first_col == "" or str(first_col).startswith("Unnamed"): + df = df.drop(columns=[first_col]) + if "none" in list(df.columns): df = df.rename(columns={"none": "background"}) if "background" not in list(df.columns): diff --git a/deepethogram/sequence/train.py b/deepethogram/sequence/train.py index e172993..5fbbb02 100644 --- a/deepethogram/sequence/train.py +++ b/deepethogram/sequence/train.py @@ -19,8 +19,6 @@ log = logging.getLogger(__name__) -plt.switch_backend("agg") - def sequence_train(cfg: DictConfig) -> nn.Module: """Trains sequence models from a configuration. @@ -35,6 +33,7 @@ def sequence_train(cfg: DictConfig) -> nn.Module: nn.Module Trained sequence model """ + plt.switch_backend("agg") cfg = projects.setup_run(cfg) log.info("args: {}".format(" ".join(sys.argv))) diff --git a/pyproject.toml b/pyproject.toml index dc9e477..b18b27a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "deepethogram" -version = "0.3.0" +version = "0.4.0" description = "Temporal action detection for biology" readme = "README.md" authors = [ diff --git a/tests/test_projects.py b/tests/test_projects.py index 4da8c2c..f2c45cb 100644 --- a/tests/test_projects.py +++ b/tests/test_projects.py @@ -115,5 +115,80 @@ def test_add_external_label(): projects.add_label_to_project(labelfile, videofile) +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_add_label_deg_style_csv(tmp_path): + """Test add_label_to_project with DEG-generated CSV (has unnamed index column).""" + make_project_from_archive() + mousedir = os.path.join(project_path, "DATA", "mouse06") + videofile = os.path.join(mousedir, "mouse06.h5") + + # Create a DEG-style CSV with unnamed numeric index + csv_path = tmp_path / "labels_with_index.csv" + csv_path.write_text( + ",background,behavior1,behavior2\n" + "0,1,0,0\n" + "1,0,1,0\n" + "2,0,0,1\n" + ) + + result = projects.add_label_to_project(str(csv_path), videofile) + df = pd.read_csv(result, index_col=0) + assert "background" in df.columns + assert "behavior1" in df.columns + assert "behavior2" in df.columns + assert df.shape[1] == 3 # background + 2 behaviors + + +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_add_label_external_csv_no_index(tmp_path): + """Test add_label_to_project with external CSV (no index column, no background). + + Regression test for GitHub issue #116: the old code used index_col=0 which + silently ate the first data column when no index column was present. + """ + make_project_from_archive() + mousedir = os.path.join(project_path, "DATA", "mouse06") + videofile = os.path.join(mousedir, "mouse06.h5") + + # Create a user-provided CSV without index or background column + csv_path = tmp_path / "labels_no_index.csv" + csv_path.write_text( + "behavior1,behavior2\n" + "0,0\n" + "1,0\n" + "0,1\n" + ) + + result = projects.add_label_to_project(str(csv_path), videofile) + df = pd.read_csv(result, index_col=0) + assert "background" in df.columns, "background column should be auto-inserted" + assert "behavior1" in df.columns, "behavior1 should NOT be eaten by index_col" + assert "behavior2" in df.columns + assert df.shape[1] == 3 # background + behavior1 + behavior2 + + +@pytest.mark.filterwarnings("ignore::UserWarning") +def test_add_label_external_csv_with_background_no_index(tmp_path): + """Test external CSV that has background but no index column.""" + make_project_from_archive() + mousedir = os.path.join(project_path, "DATA", "mouse06") + videofile = os.path.join(mousedir, "mouse06.h5") + + csv_path = tmp_path / "labels_bg_no_index.csv" + csv_path.write_text( + "background,behavior1,behavior2\n" + "1,0,0\n" + "0,1,0\n" + "0,0,1\n" + ) + + result = projects.add_label_to_project(str(csv_path), videofile) + df = pd.read_csv(result, index_col=0) + assert "background" in df.columns + assert "behavior1" in df.columns + assert "behavior2" in df.columns + assert df.shape[1] == 3 + + if __name__ == "__main__": test_add_external_label() diff --git a/uv.lock b/uv.lock index 37f8acc..2c9261d 100644 --- a/uv.lock +++ b/uv.lock @@ -331,7 +331,7 @@ wheels = [ [[package]] name = "deepethogram" -version = "0.3.0" +version = "0.4.0" source = { editable = "." } dependencies = [ { name = "chardet" },