diff --git a/exorcist/tests/test_taskstatusdb.py b/exorcist/tests/test_taskstatusdb.py index 5c69650..de4ac8d 100644 --- a/exorcist/tests/test_taskstatusdb.py +++ b/exorcist/tests/test_taskstatusdb.py @@ -269,8 +269,110 @@ def test_add_task_network(self, fresh_db, diamond_taskid_network): assert len(deps) == len(expected_deps) def test_task_row_update_statement(self, loaded_db): - # TODO: I'm going to do this in a future PR - ... + expected_sql = ( + "UPDATE tasks SET status=:status, last_modified=:last_modified " + "WHERE tasks.taskid = :taskid_1" + ) + + with patch_datetime_now(): + stmt = loaded_db._task_row_update_statement( + "foo", + TaskStatus.BLOCKED + ) + + assert str(stmt) == expected_sql + + with loaded_db.engine.begin() as conn: + res = conn.execute(stmt) + + assert res.rowcount == 1 + tasks, deps = get_tasks_and_deps(loaded_db) + expected_tasks = { + ("foo", TaskStatus.BLOCKED.value, _DEFAULT_DATETIME, 0, 3), + ("bar", TaskStatus.BLOCKED.value, None, 0, 3) + } + assert set(tasks) == expected_tasks + + def test_task_row_update_statement_old_status(self, loaded_db): + expected_sql = ( + "UPDATE tasks SET status=:status, last_modified=:last_modified " + "WHERE tasks.taskid = :taskid_1 " + "AND tasks.status = :status_1" + ) + + with patch_datetime_now(): + stmt = loaded_db._task_row_update_statement( + "foo", + TaskStatus.IN_PROGRESS, + old_status=TaskStatus.AVAILABLE + ) + + assert str(stmt) == expected_sql + + with loaded_db.engine.begin() as conn: + res = conn.execute(stmt) + + assert res.rowcount == 1 + tasks, deps = get_tasks_and_deps(loaded_db) + expected_tasks = { + ("foo", TaskStatus.IN_PROGRESS.value, _DEFAULT_DATETIME, 0, 3), + ("bar", TaskStatus.BLOCKED.value, None, 0, 3) + } + assert set(tasks) == expected_tasks + + + def test_task_row_update_statement_is_checkout(self, loaded_db): + expected_sql = ( + "UPDATE tasks SET status=:status, last_modified=:last_modified, " + "tries=(tasks.tries + :tries_1) " + "WHERE tasks.taskid = :taskid_1" + ) + + with patch_datetime_now(): + stmt = loaded_db._task_row_update_statement( + "foo", + TaskStatus.IN_PROGRESS, + is_checkout=True + ) + + assert str(stmt) == expected_sql + + with loaded_db.engine.begin() as conn: + res = conn.execute(stmt) + + assert res.rowcount == 1 + tasks, deps = get_tasks_and_deps(loaded_db) + expected_tasks = { + ("foo", TaskStatus.IN_PROGRESS.value, _DEFAULT_DATETIME, 1, 3), + ("bar", TaskStatus.BLOCKED.value, None, 0, 3) + } + assert set(tasks) == expected_tasks + + def test_task_row_update_statement_max_tries(self, loaded_db): + expected_sql = ( + "UPDATE tasks SET status=:status, last_modified=:last_modified, " + "max_tries=:max_tries " + "WHERE tasks.taskid = :taskid_1" + ) + + with patch_datetime_now(): + stmt = loaded_db._task_row_update_statement( + "foo", + TaskStatus.AVAILABLE, + max_tries=10 + ) + + assert str(stmt) == expected_sql + + with loaded_db.engine.begin() as conn: + res = conn.execute(stmt) + + assert res.rowcount == 1 + tasks, deps = get_tasks_and_deps(loaded_db) + expected_tasks = { + ("foo", TaskStatus.AVAILABLE.value, _DEFAULT_DATETIME, 1, 10), + ("bar", TaskStatus.BLOCKED.value, None, 0, 3) + } def test_check_out_task(self, loaded_db): taskid = loaded_db.check_out_task()