Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 2 additions & 14 deletions event_bus.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,12 @@
class EventBus:
def __init__(self):
self._listeners: defaultdict[Type[BaseEvent], list[Callable[[Any], None]]] = defaultdict(list)
self._dispatch_wrapper: Callable[[Callable], None] | None = None

def register_dispatch_wrapper(self, fn: Callable[[Callable], None]):
"""Set a wrapper for dispatching listeners (e.g., Textual's app.call_from_thread)."""
self._dispatch_wrapper = fn

def subscribe(self, event_type: Type[BaseEvent], listener: Callable[[Any], None]):
"""Registers a listener for a specific event type."""
self._listeners[event_type].append(listener)

def publish(self, event: BaseEvent):
"""Publishes an event to all registered listeners."""

def _dispatch():
for listener in self._listeners[type(event)]:
listener(event)

if self._dispatch_wrapper:
self._dispatch_wrapper(_dispatch)
else:
_dispatch()
for listener in self._listeners[type(event)]:
listener(event)
3 changes: 3 additions & 0 deletions module_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
run_state: RunState,
event_bus: EventBus,
stop_event: threading.Event | None = None,
enter_pause_event: threading.Event | None = None,
):
self.codeplainAPI = codeplainAPI
self.filename = filename
Expand All @@ -39,6 +40,7 @@ def __init__(
self.run_state = run_state
self.event_bus = event_bus
self.stop_event = stop_event
self.enter_pause_event = enter_pause_event

def _ensure_module_folders_exist(self, module_name: str, first_render_frid: str) -> tuple[str, str]:
"""
Expand Down Expand Up @@ -181,6 +183,7 @@ def _build_render_context_for_module(
event_bus=self.event_bus,
test_script_timeout=self.args.test_script_timeout,
stop_event=self.stop_event,
enter_pause_event=self.enter_pause_event,
)

def _render_module(
Expand Down
3 changes: 3 additions & 0 deletions plain2code.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def render(args, run_state: RunState, event_bus: EventBus): # noqa: C901
_check_connection(codeplainAPI)

stop_event = threading.Event()
enter_pause_event = threading.Event()
Comment thread
pedjaradenkovic marked this conversation as resolved.
signal.signal(signal.SIGTERM, lambda _signum, _frame: stop_event.set())

module_renderer = ModuleRenderer(
Expand All @@ -214,6 +215,7 @@ def render(args, run_state: RunState, event_bus: EventBus): # noqa: C901
run_state,
event_bus,
stop_event=stop_event,
enter_pause_event=enter_pause_event,
)

render_error: list[Exception] = []
Expand Down Expand Up @@ -245,6 +247,7 @@ def run_render():
conformance_tests_script=args.conformance_tests_script,
prepare_environment_script=args.prepare_environment_script,
state_machine_version=system_config.client_version,
enter_pause_event=enter_pause_event,
css_path="styles.css",
)
app.run()
Expand Down
5 changes: 5 additions & 0 deletions plain2code_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,8 @@ class RenderModuleCompleted(BaseEvent):
@dataclass
class RenderModuleStarted(BaseEvent):
module_name: str


@dataclass
class RenderPaused(BaseEvent):
pass
12 changes: 11 additions & 1 deletion render_machine/code_renderer.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import time
from copy import deepcopy

from transitions.extensions.diagrams import HierarchicalGraphMachine

from plain2code_events import RenderModuleCompleted, RenderModuleStarted, RenderStateUpdated
from plain2code_events import RenderModuleCompleted, RenderModuleStarted, RenderPaused, RenderStateUpdated
from render_machine.render_context import RenderContext
from render_machine.state_machine_config import StateMachineConfig, States

PAUSE_POLL_INTERVAL_SECONDS = 1


class CodeRenderer:
"""Main code renderer class that orchestrates the code generation workflow using a hierarchical state machine."""
Expand Down Expand Up @@ -35,7 +38,14 @@ def run(self):
self.render_context.event_bus.publish(RenderModuleStarted(module_name=self.render_context.module_name))
previous_action_payload = None
previous_state = None

while True:
if self.render_context.enter_pause_event.is_set():
self.render_context.event_bus.publish(RenderPaused())

while self.render_context.enter_pause_event.is_set():
time.sleep(PAUSE_POLL_INTERVAL_SECONDS)

self.render_context.event_bus.publish(
RenderStateUpdated(
state=self.render_context.state,
Expand Down
2 changes: 2 additions & 0 deletions render_machine/render_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__(
event_bus: EventBus,
test_script_timeout: Optional[int] = None,
stop_event: Optional[threading.Event] = None,
enter_pause_event: Optional[threading.Event] = None,
):
self.codeplain_api: CodeplainAPI = codeplain_api
self.memory_manager = memory_manager
Expand All @@ -77,6 +78,7 @@ def __init__(
self.run_state = run_state
self.event_bus = event_bus
self.stop_event = stop_event
self.enter_pause_event = enter_pause_event
self.script_execution_history = ScriptExecutionHistory()
self.starting_frid = None
self.test_script_timeout = test_script_timeout
Expand Down
85 changes: 59 additions & 26 deletions tui/components.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import time
from enum import Enum
from typing import Optional
from typing import Literal, Optional

from textual.containers import Horizontal, Vertical, VerticalScroll
from textual.message import Message
from textual.timer import Timer
from textual.widgets import Button, Static

from .models import Substate
Expand All @@ -13,23 +13,39 @@
class CustomFooter(Horizontal):
"""A custom footer with keyboard shortcuts and render ID."""

FOOTER_TEXT = "ctrl+c: copy * ctrl+d: quit * ctrl+l: toggle logs"
FOOTER_BASE_TEXT = "ctrl+c: copy * ctrl+d: quit * ctrl+l: toggle logs"
FOOTER_RENDERING_TEXT = FOOTER_BASE_TEXT + " * ctrl+p: pause"
RENDER_PAUSING_TEXT = FOOTER_BASE_TEXT + " * pausing ..."
RENDER_PAUSED_TEXT = FOOTER_BASE_TEXT + " * ctrl+p: resume"
RENDER_FINISHED_TEXT = "enter: exit * ctrl+c: copy * ctrl+l: toggle logs"

def __init__(self, render_id: str = "", **kwargs):
super().__init__(**kwargs)
self.render_id = render_id

def compose(self):
self._footer_text_widget = Static(self.FOOTER_TEXT, classes="custom-footer-text")
self._footer_text_widget = Static(self.FOOTER_RENDERING_TEXT, classes="custom-footer-text")
yield self._footer_text_widget
if self.render_id:
yield Static(f"render id: {self.render_id} ", classes="custom-footer-render-id")

def show_render_finished(self) -> None:
"""Update footer text to show render-finished keybindings."""
def update_footer_state(self, state: Literal["rendering", "pausing", "paused", "finished"]) -> None:
self.remove_class("footer-state-default")
self.remove_class("footer-state-paused")

if self._footer_text_widget is not None:
self._footer_text_widget.update(self.RENDER_FINISHED_TEXT)
if state == "rendering":
self._footer_text_widget.update(self.FOOTER_RENDERING_TEXT)
self.add_class("footer-state-default")
elif state == "pausing":
self._footer_text_widget.update(self.RENDER_PAUSING_TEXT)
self.add_class("footer-state-paused")
elif state == "paused":
self._footer_text_widget.update(self.RENDER_PAUSED_TEXT)
self.add_class("footer-state-paused")
elif state == "finished":
self._footer_text_widget.update(self.RENDER_FINISHED_TEXT)
self.add_class("footer-state-default")


class ScriptOutputType(str, Enum):
Expand Down Expand Up @@ -90,23 +106,40 @@ class TUIComponents(str, Enum):
class SubstateLine(Horizontal):
"""A single substate row with an attached timer."""

def __init__(self, text: str, indent: str, **kwargs):
def __init__(self, text: str, indent: str, progress_status: str, **kwargs):
super().__init__(**kwargs)
self.text = text
self.indent = indent
self.start_time = time.monotonic()
self._progress_status = progress_status
self._line_widget: Static | None = None
self._timer: Timer | None = None
self._seconds_elapsed = 0

def compose(self):
self._line_widget = Static(self._format_line(), classes="substate-line-text")
yield self._line_widget

def on_mount(self) -> None:
self._refresh_timer()
self.set_interval(1, self._refresh_timer)
self._timer = self.set_interval(1, self._add_second)
if self._progress_status == ProgressItem.PAUSED:
self._timer.pause()

def set_progress_status(self, progress_status: str) -> None:
self._progress_status = progress_status
if self._timer is None:
return
if progress_status == ProgressItem.PAUSED:
self._timer.pause()
else:
self._timer.resume()

def _add_second(self) -> None:
self._seconds_elapsed += 1
self._refresh_timer()

def _format_timer(self) -> str:
elapsed = int(time.monotonic() - self.start_time)
elapsed = int(self._seconds_elapsed)
if elapsed < 60:
return f"{elapsed}s"
minutes = elapsed // 60
Expand Down Expand Up @@ -135,10 +168,13 @@ class ProgressItem(Vertical):
PROCESSING = "PROCESSING"
COMPLETED = "COMPLETED"
STOPPED = "STOPPED"
PAUSED = "PAUSED"
PAUSING = "PAUSING"

def __init__(self, initial_text: str, **kwargs):
super().__init__(**kwargs)
self.initial_text = initial_text
self.current_status = self.PENDING

def compose(self):
# Main row with status and description
Expand All @@ -156,11 +192,16 @@ def _get_status_text(self, status: str) -> str:
return "◉ processing"
elif status == self.STOPPED:
return "◼ stopped"
elif status == self.PAUSING:
return "◉ pausing"
elif status == self.PAUSED:
return "⏸ paused"
else:
return "○ pending"

async def update_status(self, status: str):
# TODO: Move to plain2code_tui.py
self.current_status = status
try:
# Get the main row container
main_row = self.query_one(f"#{self.id}-main-row", Horizontal)
Expand All @@ -173,21 +214,20 @@ async def update_status(self, status: str):
pass

# Add appropriate widget based on status
if status == self.PROCESSING:
if status == self.PROCESSING or status == self.PAUSING:
# Use spinner for processing state
spinner = Spinner(text="processing", classes=f"status {status}")
spinner = Spinner(
text="processing" if status == self.PROCESSING else "pausing", classes=f"status {status}"
)
await main_row.mount(spinner, before=0)
else:
# Use static text for pending/completed
status_widget = Static(self._get_status_text(status), classes=f"status {status}")
await main_row.mount(status_widget, before=0)

except Exception:
pass
for line in self.query(SubstateLine):
line.set_progress_status(status)

def update_text(self, text: str):
try:
self.query_one(".description", Static).update(text)
except Exception:
pass

Expand Down Expand Up @@ -224,7 +264,7 @@ async def _render_substates_recursive(self, container: Vertical, substates: list

for substate in substates:
# Render the current substate
substate_widget = SubstateLine(substate.text, indent, classes="substate-row")
substate_widget = SubstateLine(substate.text, indent, self.current_status, classes="substate-row")
await container.mount(substate_widget)

# Recursively render children if they exist
Expand Down Expand Up @@ -376,13 +416,6 @@ def update_functionality_text(self, text: str) -> None:
except Exception:
pass

def update_fr_status(self, status: str) -> None:
try:
widget = self.query_one(f"#{TUIComponents.FRID_PROGRESS_RENDER_FR.value}", ProgressItem)
self.call_later(widget.update_status, status)
except Exception:
pass

def on_mount(self) -> None:
self.border_title = "FRID Progress"

Expand Down
Loading
Loading