Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions GUI/controllers/JobScheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from __future__ import annotations

"""Simple queued job scheduler using QThreadPool."""

import logging
from typing import Any, Dict, List, Optional, Tuple

from PyQt5.QtCore import QObject, QRunnable, QThreadPool, pyqtSignal

from GUI.models import JobDB


class _RunnableWrapper(QRunnable):
"""Internal wrapper that notifies scheduler when done."""

def __init__(self, job_id: int, runnable: QRunnable, scheduler: "JobScheduler") -> None:
super().__init__()
self.job_id = job_id
self.runnable = runnable
self.scheduler = scheduler

def run(self) -> None: # pragma: no cover - executed in worker thread
try:
self.runnable.run()
self.scheduler._job_done(self.job_id, True)
except Exception as exc: # pragma: no cover - defensive
logging.error("Job %s failed: %s", self.job_id, exc)
self.scheduler._job_done(self.job_id, False)


class JobScheduler(QObject):
"""Queue jobs and run them one at a time."""

job_started = pyqtSignal(int)
job_finished = pyqtSignal(int, bool)

def __init__(self, pool: Optional[QThreadPool] = None) -> None:
super().__init__()
self._pool = pool or QThreadPool.globalInstance()
self._queue: List[Tuple[int, QRunnable]] = []
self._current: Optional[Tuple[int, QRunnable]] = None

# ------------------------------------------------------------------
def schedule_job(self, name: str, runnable: QRunnable, config: Optional[Dict[str, Any]] = None) -> int:
"""Add *runnable* to the queue and start if idle."""
job_id = JobDB.add_job(name, config)
self._queue.append((job_id, runnable))
if self._current is None:
self._start_next()
return job_id

def _start_next(self) -> None:
if not self._queue:
return
job_id, runnable = self._queue.pop(0)
self._current = (job_id, runnable)
JobDB.update_status(job_id, "running")
self.job_started.emit(job_id)
wrapper = _RunnableWrapper(job_id, runnable, self)
self._pool.start(wrapper)

def _job_done(self, job_id: int, ok: bool) -> None:
JobDB.update_status(job_id, "completed" if ok else "failed")
self.job_finished.emit(job_id, ok)
self._current = None
self._start_next()

# ------------------------------------------------------------------
def pause_current(self) -> None:
if self._current is None:
return
job_id, runnable = self._current
if hasattr(runnable, "pause"):
runnable.pause() # type: ignore[attr-defined]
JobDB.update_status(job_id, "paused")

def resume_current(self) -> None:
if self._current is None:
return
job_id, runnable = self._current
if hasattr(runnable, "resume_task"):
runnable.resume_task() # type: ignore[attr-defined]
JobDB.update_status(job_id, "running")

def move_to_front(self, job_id: int) -> None:
"""Move a queued job to the front if present."""
for i, item in enumerate(self._queue):
if item[0] == job_id:
self._queue.insert(0, self._queue.pop(i))
break

def cancel_job(self, job_id: int) -> None:
"""Cancel a running or queued job."""
if self._current and self._current[0] == job_id:
_, runnable = self._current
if hasattr(runnable, "cancel"):
runnable.cancel() # type: ignore[attr-defined]
JobDB.update_status(job_id, "cancelled")
self._current = None
self._start_next()
return
for i, item in enumerate(self._queue):
if item[0] == job_id:
self._queue.pop(i)
JobDB.update_status(job_id, "cancelled")
break

28 changes: 13 additions & 15 deletions GUI/controllers/MainController.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ class MainController(QObject):
AUTOSAVE_IDLE_MS = 10_000
_UNASSESSED_LABELS = {-1}

def __init__(self, model: Optional[BaseImageDataModel], view: ClusteredCropsView, tasks_widget=None):
def __init__(self, model: Optional[BaseImageDataModel], view: ClusteredCropsView, tasks_widget=None, scheduler=None):
super().__init__()
self.image_data_model = model
self.view = view
self.tasks_widget = tasks_widget
self.scheduler = scheduler

# ---------- core sub‑controllers ---------------------------------
self.annotation_generator = LocalMaximaPointAnnotationGenerator()
Expand All @@ -52,6 +53,10 @@ def __init__(self, model: Optional[BaseImageDataModel], view: ClusteredCropsView
self.io = ProjectIOService(data_anchor=Path(model.data_path) if model else None)
self._export_usecase = ExportAnnotationsUseCase()
self.threadpool = QThreadPool.globalInstance()
if self.scheduler is None:
from GUI.controllers.JobScheduler import JobScheduler

self.scheduler = JobScheduler(pool=self.threadpool)

self.cluster_selector = make_selector("greedy", self.clustering_controller)

Expand All @@ -67,7 +72,6 @@ def __init__(self, model: Optional[BaseImageDataModel], view: ClusteredCropsView
self._idle_timer = QTimer(singleShot=True)
self._idle_timer.timeout.connect(self._autosave_if_dirty)

self._mc_widget = None
self._mc_worker = None

self._connect_signals()
Expand Down Expand Up @@ -691,23 +695,17 @@ def run_mc_banker(self, config: dict) -> None:

worker = MCBankerWorker(config, resume=resume)
self._mc_worker = worker
widget = getattr(self.tasks_widget, "mc_widget", None)
if widget is not None:
widget.start(str(output_file), total)
worker.signals.progress.connect(widget.update_progress)
widget.request_pause.connect(worker.pause)
widget.request_resume.connect(worker.resume_task)
widget.request_cancel.connect(worker.cancel)

worker.signals.finished.connect(self._on_mc_banker_finished)
self._mc_widget = widget
self.threadpool.start(worker)

if self.scheduler is not None:
job_id = self.scheduler.schedule_job("MC Banker", worker, config)
if hasattr(self.tasks_widget, "add_active_job"):
self.tasks_widget.add_active_job(job_id, "MC Banker", config, worker)
else:
self.threadpool.start(worker)

@pyqtSlot(bool)
def _on_mc_banker_finished(self, success: bool) -> None:
if getattr(self, "_mc_widget", None):
self._mc_widget.finish()
self._mc_widget = None
self._mc_worker = None
msg = "HDF5 file generated." if success else "HDF5 generation failed."
QMessageBox.information(self.view, "MC Inference", msg)
Expand Down
98 changes: 98 additions & 0 deletions GUI/models/JobDB.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from __future__ import annotations

"""Simple SQLite-backed job tracking for the GUI scheduler."""

import json
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional

_DB_PATH = Path.home() / ".attentionunet" / "jobs.db"


def _ensure_db() -> None:
"""Create the jobs table if needed."""
_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
with sqlite3.connect(_DB_PATH) as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
config TEXT,
status TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
finished_at TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
conn.commit()


def add_job(name: str, config: Optional[Dict[str, Any]] = None) -> int:
"""Insert a new job and return its ID."""
_ensure_db()
cfg = json.dumps(config or {})
with sqlite3.connect(_DB_PATH) as conn:
cur = conn.execute(
"INSERT INTO jobs (name, config, status) VALUES (?, ?, ?)",
(name, cfg, "queued"),
)
conn.commit()
return cur.lastrowid


def update_status(job_id: int, status: str) -> None:
"""Update the status for *job_id*."""
_ensure_db()
with sqlite3.connect(_DB_PATH) as conn:
if status == "running":
conn.execute(
"UPDATE jobs SET status = ?, started_at = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
(status, job_id),
)
elif status in {"completed", "failed", "cancelled"}:
conn.execute(
"UPDATE jobs SET status = ?, finished_at = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
(status, job_id),
)
else:
conn.execute(
"UPDATE jobs SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
(status, job_id),
)
conn.commit()


def list_jobs(limit: int = 50) -> List[Dict[str, Any]]:
"""Return recently recorded jobs."""
_ensure_db()
with sqlite3.connect(_DB_PATH) as conn:
cur = conn.execute(
"SELECT id, name, config, status, started_at, finished_at FROM jobs ORDER BY id DESC LIMIT ?",
(limit,),
)
return [
{
"id": row[0],
"name": row[1],
"config": json.loads(row[2]) if row[2] else {},
"status": row[3],
"started_at": row[4] or "",
"finished_at": row[5] or "",
}
for row in cur.fetchall()
]


def delete_job(job_id: int) -> None:
"""Remove a job record."""
_ensure_db()
with sqlite3.connect(_DB_PATH) as conn:
conn.execute("DELETE FROM jobs WHERE id = ?", (job_id,))
conn.commit()



30 changes: 30 additions & 0 deletions GUI/unittests/test_job_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")

from GUI.models import JobDB


def test_add_and_list_jobs(tmp_path, monkeypatch):
db_file = tmp_path / "jobs.db"
monkeypatch.setattr(JobDB, "_DB_PATH", db_file)

jid1 = JobDB.add_job("job1", {"a": 1})
jid2 = JobDB.add_job("job2", {"b": 2})

jobs = JobDB.list_jobs(2)
assert jobs[0]["id"] == jid2
assert jobs[1]["id"] == jid1
assert jobs[0]["status"] == "queued"

JobDB.update_status(jid1, "running")
jobs = JobDB.list_jobs(2)
running = next(j for j in jobs if j["id"] == jid1)
assert running["status"] == "running"
assert running["started_at"]

JobDB.update_status(jid1, "completed")
jobs = JobDB.list_jobs(2)
finished = next(j for j in jobs if j["id"] == jid1)
assert finished["finished_at"]


85 changes: 85 additions & 0 deletions GUI/unittests/test_job_scheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os
os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")

from typing import Optional

from PyQt5.QtCore import QRunnable

from GUI.controllers.JobScheduler import JobScheduler
from GUI.models import JobDB


class DummyPool:
def __init__(self):
self.runnable: Optional[QRunnable] = None

def start(self, runnable: QRunnable):
# Defer execution until explicitly triggered in the test
self.runnable = runnable


class DummyTask(QRunnable):
def __init__(self):
super().__init__()
self.count = 0

def run(self):
self.count += 1


class PausableTask(DummyTask):
def __init__(self):
super().__init__()
self.paused = False

def pause(self):
self.paused = True

def resume_task(self):
self.paused = False


def test_run_two_jobs_in_order(tmp_path, monkeypatch):
monkeypatch.setattr(JobDB, "_DB_PATH", tmp_path / "jobs.db")
pool = DummyPool()
scheduler = JobScheduler(pool=pool)

t1 = DummyTask()
t2 = DummyTask()
jid1 = scheduler.schedule_job("j1", t1)
# execute first queued job
assert pool.runnable is not None
pool.runnable.run()

jid2 = scheduler.schedule_job("j2", t2)
assert pool.runnable is not None
pool.runnable.run()

jobs = JobDB.list_jobs(2)
statuses = {j["id"]: j["status"] for j in jobs}
assert statuses[jid1] == "completed"
assert statuses[jid2] == "completed"
assert jobs[0]["started_at"] and jobs[0]["finished_at"]
assert t1.count == 1 and t2.count == 1


def test_pause_and_resume(tmp_path, monkeypatch):
monkeypatch.setattr(JobDB, "_DB_PATH", tmp_path / "jobs.db")
pool = DummyPool()
scheduler = JobScheduler(pool=pool)

t = PausableTask()
jid = scheduler.schedule_job("job", t)
scheduler.pause_current()
assert t.paused
scheduler.resume_current()
assert not t.paused

assert pool.runnable is not None
pool.runnable.run()

jobs = JobDB.list_jobs(1)
assert jobs[0]["id"] == jid
assert jobs[0]["status"] == "completed"


Loading