Skip to content

Commit 131eee4

Browse files
committed
feat(py): improve mermaid code generation
1 parent d6772d3 commit 131eee4

File tree

1 file changed

+164
-24
lines changed

1 file changed

+164
-24
lines changed

libs/wrappers/python/rtbot.py

Lines changed: 164 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -250,38 +250,178 @@ def instantiate_prototype(self, instance_id: str, prototype_id: str, parameters:
250250
return self
251251

252252
def to_mermaid(self) -> str:
253-
"""Convert program structure to Mermaid.js flowchart representation."""
254-
# Start with flowchart definition, left-to-right
253+
"""Convert program structure to Mermaid.js flowchart representation with improved layout."""
255254
lines = ["flowchart LR"]
255+
nodes = {}
256+
rank_groups = {}
256257

257-
# Create nodes for each operator
258-
for op in self.operators:
259-
op_id = op["id"]
260-
op_type = op["type"]
258+
def get_short_name(op_type: str, params: dict = None, op_id: str = "") -> str:
259+
"""Get shortened operator name with parameters."""
260+
if isinstance(op_type, dict):
261+
if "prototype" in op_type:
262+
proto_name = op_type["prototype"]
263+
params = op_type.get("parameters", {})
264+
param_str = ", ".join(f"{k}={v}" for k, v in params.items())
265+
return f"{proto_name}\\n({param_str})"
266+
elif "type" in op_type:
267+
return get_short_name(op_type["type"], op_type.get("parameters", {}), op_id)
268+
else:
269+
# If we can't determine the type, use the operator's ID
270+
return op_id.split("::")[-1]
271+
272+
type_map = {
273+
"LogicalAnd": "AND",
274+
"LogicalOr": "OR",
275+
"LogicalNot": "NOT",
276+
"MovingAverage": "MA",
277+
"StandardDeviation": "StdDev",
278+
"GreaterThan": ">",
279+
"LessThan": "<",
280+
"EqualTo": "=",
281+
"ResamplerHermite": "Hermite",
282+
"ResamplerConstant": "Resampler",
283+
"ConstantNumber": "", # Will show just the value
284+
"ConstantBoolean": "", # Will show just the value
285+
"ConstantNumberToBoolean": "→bool",
286+
"ConstantBooleanToNumber": "→num",
287+
}
288+
289+
# Get base name, fallback to operator ID if type unknown
290+
base_name = type_map.get(op_type, op_type)
291+
if base_name not in type_map.values() and op_id:
292+
base_name = op_id.split("::")[-1]
293+
294+
# Add parameters if available
295+
if params:
296+
if op_type == "GreaterThan":
297+
return f"> {params.get('value', '')}"
298+
elif op_type == "LessThan":
299+
return f"< {params.get('value', '')}"
300+
elif op_type == "EqualTo":
301+
return f"= {params.get('value', '')}"
302+
elif op_type == "ConstantNumber":
303+
return str(params.get('value', ''))
304+
elif op_type == "ConstantBoolean":
305+
return str(params.get('value', '')).lower()
306+
elif op_type == "MovingAverage":
307+
return f"MA({params.get('window_size', '')})"
308+
elif op_type == "ResamplerHermite":
309+
return f"Hermite({params.get('interval', '')})"
310+
elif op_type == "StandardDeviation":
311+
return f"StdDev({params.get('window_size', '')})"
312+
else:
313+
# For any other operator type, show all parameters
314+
param_str = ", ".join(f"{k}={v}" for k, v in params.items())
315+
return f"{base_name}({param_str})"
261316

262-
# Special styling for entry operator
317+
return base_name
318+
319+
def add_node(op_id: str, op_type: str, level: int = 0):
320+
styles = []
263321
if op_id == self.entryOperator:
264-
lines.append(f' {op_id}["{op_type}\\n{op_id}"]:::entry')
265-
# Special styling for output operators
266-
elif op_id in self.output:
267-
lines.append(f' {op_id}["{op_type}\\n{op_id}"]:::output')
268-
else:
269-
lines.append(f' {op_id}["{op_type}\\n{op_id}"]')
322+
styles.append("entry")
323+
if op_id in self.output:
324+
styles.append("output")
325+
if isinstance(op_type, dict) and "prototype" in op_type:
326+
styles.append("prototype")
327+
328+
style = ":::" + ",".join(styles) if styles else ""
329+
330+
node_id = op_id.replace("::", "_")
331+
nodes[op_id] = node_id
332+
333+
if level not in rank_groups:
334+
rank_groups[level] = []
335+
rank_groups[level].append(node_id)
336+
337+
# Get parameters and create node label
338+
params = op_type.get("parameters", {}) if isinstance(op_type, dict) else None
339+
op_name = get_short_name(op_type if isinstance(op_type, str) else op_type.get("type", "Unknown"),
340+
params, op_id)
341+
342+
# Use hexagon shape for prototype instances
343+
shape = "{{" if isinstance(op_type, dict) and "prototype" in op_type else "["
344+
end_shape = "}}" if isinstance(op_type, dict) and "prototype" in op_type else "]"
345+
346+
display_name = op_name if op_name else op_id.split("::")[-1]
347+
lines.append(f' {node_id}{shape}"{display_name}"{end_shape}{style}')
348+
349+
# Handle prototypes and pipelines
350+
if isinstance(op_type, dict) and "prototype" in op_type:
351+
proto_name = op_type["prototype"]
352+
proto_def = self.prototypes[proto_name]
353+
# Add prototype internals
354+
for internal_op in proto_def["operators"]:
355+
internal_id = f"{op_id}::{internal_op['id']}"
356+
internal_type = internal_op
357+
if isinstance(internal_op, dict):
358+
# Resolve template parameters
359+
params = {}
360+
for k, v in internal_op.get("parameters", {}).items():
361+
if isinstance(v, str) and v.startswith("${") and v.endswith("}"):
362+
param_name = v[2:-1]
363+
if param_name in op_type.get("parameters", {}):
364+
params[k] = op_type["parameters"][param_name]
365+
else:
366+
params[k] = v
367+
internal_type = {"type": internal_op["type"], "parameters": params}
368+
add_node(internal_id, internal_type, level + 1)
369+
370+
# Add prototype connections
371+
for conn in proto_def["connections"]:
372+
from_id = f"{op_id}::{conn['from']}"
373+
to_id = f"{op_id}::{conn['to']}"
374+
add_connection(from_id, to_id, conn.get("fromPort", "o1"), conn.get("toPort", "i1"))
375+
376+
elif isinstance(op_type, dict) and op_type.get("type") == "Pipeline":
377+
# Add pipeline internals
378+
for internal_op in op_type["operators"]:
379+
internal_id = f"{op_id}::{internal_op['id']}"
380+
add_node(internal_id, internal_op["type"], level + 1)
381+
# Add pipeline connections
382+
for conn in op_type["connections"]:
383+
from_id = f"{op_id}::{conn['from']}"
384+
to_id = f"{op_id}::{conn['to']}"
385+
add_connection(from_id, to_id, conn.get("fromPort", "o1"), conn.get("toPort", "i1"))
386+
387+
def add_connection(from_op: str, to_op: str, from_port: str = "o1", to_port: str = "i1"):
388+
from_node = nodes[from_op]
389+
to_node = nodes[to_op]
390+
lines.append(f' {from_node} -- "{from_port}{to_port}" --> {to_node}')
391+
392+
# First pass: create all nodes with proper levels
393+
for op in self.operators:
394+
add_node(op["id"], op.get("type", op), 0)
270395

271-
# Add connections
396+
# Second pass: add connections
272397
for conn in self.connections:
273-
from_op = conn["from"]
274-
to_op = conn["to"]
275-
from_port = conn.get("fromPort", "o1")
276-
to_port = conn.get("toPort", "i1")
277-
278-
# Add port labels to connection
279-
lines.append(f' {from_op} -- "{from_port}{to_port}" --> {to_op}')
398+
add_connection(conn["from"], conn["to"],
399+
conn.get("fromPort", "o1"),
400+
conn.get("toPort", "i1"))
401+
402+
# Add subgraph rankings to enforce left-to-right layout
403+
max_level = max(rank_groups.keys()) if rank_groups else 0
404+
for level in range(max_level + 1):
405+
if level in rank_groups and rank_groups[level]:
406+
lines.append(f" subgraph level_{level} [\" \"]")
407+
lines.append(" direction LR") # Force left-to-right direction within subgraph
408+
lines.append(" " + " & ".join(rank_groups[level]))
409+
lines.append(" end")
410+
411+
# Add invisible edges to force ordering between levels
412+
for level in range(max_level):
413+
if level in rank_groups and (level + 1) in rank_groups:
414+
if rank_groups[level] and rank_groups[level + 1]:
415+
first_node = rank_groups[level][0]
416+
next_node = rank_groups[level + 1][0]
417+
lines.append(f" {first_node} ~~~ {next_node}") # Invisible edge
280418

281-
# Add class definitions
419+
# Add styling
282420
lines.extend([
283-
" classDef entry fill:#f96",
284-
" classDef output fill:#9cf"
421+
" classDef entry fill:#f96,stroke:#333,stroke-width:2px",
422+
" classDef output fill:#9cf,stroke:#333,stroke-width:2px",
423+
" classDef prototype fill:#f0f0f0,stroke:#666,stroke-width:2px,stroke-dasharray: 5 5",
424+
" classDef default fill:white,stroke:#333,stroke-width:1px"
285425
])
286426

287427
return "\n".join(lines)

0 commit comments

Comments
 (0)