From e63d29a47fae346c187d54111bb93f5ccb0e7a36 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Tue, 11 Jul 2023 16:53:38 -0500 Subject: [PATCH 1/2] tests for _task_row_update_statement --- exorcist/taskdb.py | 5 +- exorcist/tests/test_taskstatusdb.py | 106 +++++++++++++++++++++++++++- 2 files changed, 105 insertions(+), 6 deletions(-) diff --git a/exorcist/taskdb.py b/exorcist/taskdb.py index 1df279f..c0c4dc0 100644 --- a/exorcist/taskdb.py +++ b/exorcist/taskdb.py @@ -481,13 +481,10 @@ def check_out_task(self): return task_row.taskid - def mark_task_aborted_incomplete(self, taskid: str): - ... - def _mark_task_completed_failure(self, taskid: str): status_statement = sqla.case( ( - self.tasks_table.c.tries == self.tasks_table.c.max_tries, + self.tasks_table.c.tries >= self.tasks_table.c.max_tries, TaskStatus.TOO_MANY_RETRIES.value ), else_=TaskStatus.AVAILABLE.value diff --git a/exorcist/tests/test_taskstatusdb.py b/exorcist/tests/test_taskstatusdb.py index 5235af6..96c7ee9 100644 --- a/exorcist/tests/test_taskstatusdb.py +++ b/exorcist/tests/test_taskstatusdb.py @@ -245,8 +245,110 @@ def test_add_task_network(self, fresh_db, diamond_taskid_network): assert len(deps) == len(expected_deps) def test_task_row_update_statement(self, loaded_db): - # TODO: I'm going to do this in a future PR - ... + expected_sql = ( + "UPDATE tasks SET status=:status, last_modified=:last_modified " + "WHERE tasks.taskid = :taskid_1" + ) + + with patch_datetime_now(): + stmt = loaded_db._task_row_update_statement( + "foo", + TaskStatus.BLOCKED + ) + + assert str(stmt) == expected_sql + + with loaded_db.engine.begin() as conn: + res = conn.execute(stmt) + + assert res.rowcount == 1 + tasks, deps = get_tasks_and_deps(loaded_db) + expected_tasks = { + ("foo", TaskStatus.BLOCKED.value, _DEFAULT_DATETIME, 0, 3), + ("bar", TaskStatus.BLOCKED.value, None, 0, 3) + } + assert set(tasks) == expected_tasks + + def test_task_row_update_statement_old_status(self, loaded_db): + expected_sql = ( + "UPDATE tasks SET status=:status, last_modified=:last_modified " + "WHERE tasks.taskid = :taskid_1 " + "AND tasks.status = :status_1" + ) + + with patch_datetime_now(): + stmt = loaded_db._task_row_update_statement( + "foo", + TaskStatus.IN_PROGRESS, + old_status=TaskStatus.AVAILABLE + ) + + assert str(stmt) == expected_sql + + with loaded_db.engine.begin() as conn: + res = conn.execute(stmt) + + assert res.rowcount == 1 + tasks, deps = get_tasks_and_deps(loaded_db) + expected_tasks = { + ("foo", TaskStatus.IN_PROGRESS.value, _DEFAULT_DATETIME, 0, 3), + ("bar", TaskStatus.BLOCKED.value, None, 0, 3) + } + assert set(tasks) == expected_tasks + + + def test_task_row_update_statement_is_checkout(self, loaded_db): + expected_sql = ( + "UPDATE tasks SET status=:status, last_modified=:last_modified, " + "tries=(tasks.tries + :tries_1) " + "WHERE tasks.taskid = :taskid_1" + ) + + with patch_datetime_now(): + stmt = loaded_db._task_row_update_statement( + "foo", + TaskStatus.IN_PROGRESS, + is_checkout=True + ) + + assert str(stmt) == expected_sql + + with loaded_db.engine.begin() as conn: + res = conn.execute(stmt) + + assert res.rowcount == 1 + tasks, deps = get_tasks_and_deps(loaded_db) + expected_tasks = { + ("foo", TaskStatus.IN_PROGRESS.value, _DEFAULT_DATETIME, 1, 3), + ("bar", TaskStatus.BLOCKED.value, None, 0, 3) + } + assert set(tasks) == expected_tasks + + def test_task_row_update_statement_max_tries(self, loaded_db): + expected_sql = ( + "UPDATE tasks SET status=:status, last_modified=:last_modified, " + "max_tries=:max_tries " + "WHERE tasks.taskid = :taskid_1" + ) + + with patch_datetime_now(): + stmt = loaded_db._task_row_update_statement( + "foo", + TaskStatus.AVAILABLE, + max_tries=10 + ) + + assert str(stmt) == expected_sql + + with loaded_db.engine.begin() as conn: + res = conn.execute(stmt) + + assert res.rowcount == 1 + tasks, deps = get_tasks_and_deps(loaded_db) + expected_tasks = { + ("foo", TaskStatus.AVAILABLE.value, _DEFAULT_DATETIME, 1, 10), + ("bar", TaskStatus.BLOCKED.value, None, 0, 3) + } def test_check_out_task(self, loaded_db): taskid = loaded_db.check_out_task() From 1090896a5ba2af2ace05a54c98f301d3e8af0580 Mon Sep 17 00:00:00 2001 From: "David W.H. Swenson" Date: Tue, 11 Jul 2023 16:54:12 -0500 Subject: [PATCH 2/2] undo accidental change --- exorcist/taskdb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/exorcist/taskdb.py b/exorcist/taskdb.py index c0c4dc0..2e309b8 100644 --- a/exorcist/taskdb.py +++ b/exorcist/taskdb.py @@ -484,7 +484,7 @@ def check_out_task(self): def _mark_task_completed_failure(self, taskid: str): status_statement = sqla.case( ( - self.tasks_table.c.tries >= self.tasks_table.c.max_tries, + self.tasks_table.c.tries == self.tasks_table.c.max_tries, TaskStatus.TOO_MANY_RETRIES.value ), else_=TaskStatus.AVAILABLE.value