Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions GUI/controllers/JobScheduler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from __future__ import annotations

"""Simple queued job scheduler using QThreadPool."""

import logging
from typing import Any, Dict, List, Optional, Tuple

from PyQt5.QtCore import QObject, QRunnable, QThreadPool, pyqtSignal

from GUI.models import JobDB


class _RunnableWrapper(QRunnable):
"""Internal wrapper that notifies scheduler when done."""

def __init__(self, job_id: int, runnable: QRunnable, scheduler: "JobScheduler") -> None:
super().__init__()
self.job_id = job_id
self.runnable = runnable
self.scheduler = scheduler

def run(self) -> None: # pragma: no cover - executed in worker thread
try:
self.runnable.run()
self.scheduler._job_done(self.job_id, True)
except Exception as exc: # pragma: no cover - defensive
logging.error("Job %s failed: %s", self.job_id, exc)
self.scheduler._job_done(self.job_id, False)


class JobScheduler(QObject):
"""Queue jobs and run them one at a time."""

job_started = pyqtSignal(int)
job_finished = pyqtSignal(int, bool)

def __init__(self, pool: Optional[QThreadPool] = None) -> None:
super().__init__()
self._pool = pool or QThreadPool.globalInstance()
self._queue: List[Tuple[int, QRunnable]] = []
self._current: Optional[Tuple[int, QRunnable]] = None
self._paused_jobs: set[int] = set()

# ------------------------------------------------------------------
def schedule_job(self, name: str, runnable: QRunnable, config: Optional[Dict[str, Any]] = None) -> int:
"""Add *runnable* to the queue and start if idle."""
job_id = JobDB.add_job(name, config)
self._queue.append((job_id, runnable))
if self._current is None:
self._start_next()
return job_id

def _start_next(self) -> None:
if not self._queue:
return
job_id, runnable = self._queue.pop(0)
self._current = (job_id, runnable)
JobDB.update_status(job_id, "running")
self.job_started.emit(job_id)
wrapper = _RunnableWrapper(job_id, runnable, self)
self._pool.start(wrapper)

def _job_done(self, job_id: int, ok: bool) -> None:
if job_id in self._paused_jobs:
self._paused_jobs.remove(job_id)
else:
JobDB.update_status(job_id, "completed" if ok else "failed")
self.job_finished.emit(job_id, ok)
self._current = None
self._start_next()

# ------------------------------------------------------------------
def pause_current(self) -> None:
if self._current is None:
return
job_id, runnable = self._current
if hasattr(runnable, "cancel"):
runnable.cancel() # type: ignore[attr-defined]
JobDB.update_status(job_id, "paused")
self._paused_jobs.add(job_id)

def resume_job(self, job_id: int, runnable: QRunnable) -> None:
"""Resume a paused job by re-queuing a new *runnable*."""
JobDB.update_status(job_id, "queued")
self._queue.insert(0, (job_id, runnable))
if self._current is None:
self._start_next()

def cancel_current(self) -> None:
if self._current is None:
return
job_id, runnable = self._current
if hasattr(runnable, "cancel"):
runnable.cancel() # type: ignore[attr-defined]
JobDB.update_status(job_id, "canceled")

def move_to_front(self, job_id: int) -> None:
"""Move a queued job to the front if present."""
for i, item in enumerate(self._queue):
if item[0] == job_id:
self._queue.insert(0, self._queue.pop(i))
break

29 changes: 14 additions & 15 deletions GUI/controllers/MainController.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ class MainController(QObject):
AUTOSAVE_IDLE_MS = 10_000
_UNASSESSED_LABELS = {-1}

def __init__(self, model: Optional[BaseImageDataModel], view: ClusteredCropsView, tasks_widget=None):
def __init__(self, model: Optional[BaseImageDataModel], view: ClusteredCropsView, tasks_widget=None, scheduler=None):
super().__init__()
self.image_data_model = model
self.view = view
self.tasks_widget = tasks_widget
self.scheduler = scheduler

# ---------- core sub‑controllers ---------------------------------
self.annotation_generator = LocalMaximaPointAnnotationGenerator()
Expand All @@ -52,6 +53,10 @@ def __init__(self, model: Optional[BaseImageDataModel], view: ClusteredCropsView
self.io = ProjectIOService(data_anchor=Path(model.data_path) if model else None)
self._export_usecase = ExportAnnotationsUseCase()
self.threadpool = QThreadPool.globalInstance()
if self.scheduler is None:
from GUI.controllers.JobScheduler import JobScheduler

self.scheduler = JobScheduler(pool=self.threadpool)

self.cluster_selector = make_selector("greedy", self.clustering_controller)

Expand All @@ -67,7 +72,6 @@ def __init__(self, model: Optional[BaseImageDataModel], view: ClusteredCropsView
self._idle_timer = QTimer(singleShot=True)
self._idle_timer.timeout.connect(self._autosave_if_dirty)

self._mc_widget = None
self._mc_worker = None

self._connect_signals()
Expand Down Expand Up @@ -691,23 +695,18 @@ def run_mc_banker(self, config: dict) -> None:

worker = MCBankerWorker(config, resume=resume)
self._mc_worker = worker
widget = getattr(self.tasks_widget, "mc_widget", None)
if widget is not None:
widget.start(str(output_file), total)
worker.signals.progress.connect(widget.update_progress)
widget.request_pause.connect(worker.pause)
widget.request_resume.connect(worker.resume_task)
widget.request_cancel.connect(worker.cancel)

worker.signals.finished.connect(self._on_mc_banker_finished)
self._mc_widget = widget
self.threadpool.start(worker)
if self.scheduler is not None:
job_id = self.scheduler.schedule_job("MC Banker", worker, config)
if getattr(self.tasks_widget, "add_job", None):
self.tasks_widget.add_job(job_id, "MC Banker", config, worker)
else:
self.threadpool.start(worker)
if getattr(self.tasks_widget, "add_job", None):
self.tasks_widget.add_job(-1, "MC Banker", config, worker)

@pyqtSlot(bool)
def _on_mc_banker_finished(self, success: bool) -> None:
if getattr(self, "_mc_widget", None):
self._mc_widget.finish()
self._mc_widget = None
self._mc_worker = None
msg = "HDF5 file generated." if success else "HDF5 generation failed."
QMessageBox.information(self.view, "MC Inference", msg)
Expand Down
122 changes: 122 additions & 0 deletions GUI/models/JobDB.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
from __future__ import annotations

"""Simple SQLite-backed job tracking for the GUI scheduler."""

import json
import sqlite3
from pathlib import Path
from typing import Any, Dict, List, Optional

_DB_PATH = Path.home() / ".attentionunet" / "jobs.db"


def _ensure_db() -> None:
"""Create the jobs table if needed."""
_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
with sqlite3.connect(_DB_PATH) as conn:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS jobs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT,
config TEXT,
status TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
started_at TIMESTAMP,
finished_at TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
)
"""
)
# columns added in later versions
cols = {row[1] for row in conn.execute("PRAGMA table_info(jobs)")}
if "started_at" not in cols:
conn.execute("ALTER TABLE jobs ADD COLUMN started_at TIMESTAMP")
if "finished_at" not in cols:
conn.execute("ALTER TABLE jobs ADD COLUMN finished_at TIMESTAMP")
if "updated_at" not in cols:
conn.execute(
"ALTER TABLE jobs ADD COLUMN updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP"
)
conn.commit()


def add_job(name: str, config: Optional[Dict[str, Any]] = None) -> int:
"""Insert a new job and return its ID."""
_ensure_db()
cfg = json.dumps(config or {})
with sqlite3.connect(_DB_PATH) as conn:
cur = conn.execute(
"INSERT INTO jobs (name, config, status) VALUES (?, ?, ?)",
(name, cfg, "queued"),
)
conn.commit()
return cur.lastrowid


def update_status(job_id: int, status: str) -> None:
"""Update the status for *job_id* and timestamps."""
_ensure_db()
with sqlite3.connect(_DB_PATH) as conn:
if status == "running":
conn.execute(
"UPDATE jobs SET status = ?, started_at = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
(status, job_id),
)
elif status in {"completed", "failed", "canceled"}:
conn.execute(
"UPDATE jobs SET status = ?, finished_at = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
(status, job_id),
)
else:
conn.execute(
"UPDATE jobs SET status = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?",
(status, job_id),
)
conn.commit()


def list_jobs(limit: int = 50) -> List[Dict[str, Any]]:
"""Return recently recorded jobs."""
_ensure_db()
with sqlite3.connect(_DB_PATH) as conn:
cur = conn.execute(
"SELECT id, name, config, status, created_at, started_at, finished_at FROM jobs ORDER BY id DESC LIMIT ?",
(limit,),
)
return [
{
"id": row[0],
"name": row[1],
"config": json.loads(row[2]) if row[2] else {},
"status": row[3],
"created_at": row[4],
"started_at": row[5],
"finished_at": row[6],
}
for row in cur.fetchall()
]


def get_job(job_id: int) -> Optional[Dict[str, Any]]:
"""Return a single job record or ``None``."""
_ensure_db()
with sqlite3.connect(_DB_PATH) as conn:
cur = conn.execute(
"SELECT id, name, config, status, created_at, started_at, finished_at FROM jobs WHERE id = ?",
(job_id,),
)
row = cur.fetchone()
if row:
return {
"id": row[0],
"name": row[1],
"config": json.loads(row[2]) if row[2] else {},
"status": row[3],
"created_at": row[4],
"started_at": row[5],
"finished_at": row[6],
}
return None


29 changes: 29 additions & 0 deletions GUI/unittests/test_job_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
os.environ.setdefault("QT_QPA_PLATFORM", "offscreen")

from GUI.models import JobDB


def test_add_and_list_jobs(tmp_path, monkeypatch):
db_file = tmp_path / "jobs.db"
monkeypatch.setattr(JobDB, "_DB_PATH", db_file)

jid1 = JobDB.add_job("job1", {"a": 1})
jid2 = JobDB.add_job("job2", {"b": 2})

jobs = JobDB.list_jobs(2)
assert jobs[0]["id"] == jid2
assert jobs[1]["id"] == jid1
assert jobs[0]["status"] == "queued"

JobDB.update_status(jid1, "running")
JobDB.update_status(jid1, "completed")
jobs = JobDB.list_jobs(2)
j1 = next(j for j in jobs if j["id"] == jid1)
assert j1["status"] == "completed"
assert j1["started_at"] is not None
assert j1["finished_at"] is not None

rec = JobDB.get_job(jid1)
assert rec is not None and rec["id"] == jid1

Loading