@@ -429,6 +429,14 @@ def edges(
429
429
for parent in parents
430
430
]
431
431
432
+ def nodes (self , plates : list [Plate ] | None = None ) -> list [NodeInfo ]:
433
+ """Get all nodes in the model graph."""
434
+ plates = plates or self .get_plates ()
435
+ nodes = []
436
+ for plate in plates :
437
+ nodes .extend (plate .variables )
438
+ return nodes
439
+
432
440
433
441
def make_graph (
434
442
name : str ,
@@ -785,3 +793,131 @@ def model_to_graphviz(
785
793
if include_dim_lengths
786
794
else create_plate_label_without_dim_length ,
787
795
)
796
+
797
+
798
+ def _build_mermaid_node (node : NodeInfo ) -> list [str ]:
799
+ var = node .var
800
+ node_type = node .node_type
801
+ if node_type == NodeType .DATA :
802
+ return [
803
+ f"{ var .name } [{ var .name } ~ Data]" ,
804
+ f"{ var .name } @{{ shape: db }}" ,
805
+ ]
806
+ elif node_type == NodeType .OBSERVED_RV :
807
+ return [
808
+ f"{ var .name } ([{ var .name } ~ { random_variable_symbol (var )} ])" ,
809
+ f"{ var .name } @{{ shape: rounded }}" ,
810
+ f"style { var .name } fill:#757575" ,
811
+ ]
812
+
813
+ elif node_type == NodeType .FREE_RV :
814
+ return [
815
+ f"{ var .name } ([{ var .name } ~ { random_variable_symbol (var )} ])" ,
816
+ f"{ var .name } @{{ shape: rounded }}" ,
817
+ ]
818
+ elif node_type == NodeType .DETERMINISTIC :
819
+ return [
820
+ f"{ var .name } ([{ var .name } ~ Deterministic])" ,
821
+ f"{ var .name } @{{ shape: rect }}" ,
822
+ ]
823
+ elif node_type == NodeType .POTENTIAL :
824
+ return [
825
+ f"{ var .name } ([{ var .name } ~ Potential])" ,
826
+ f"{ var .name } @{{ shape: diam }}" ,
827
+ f"style { var .name } fill:#f0f0f0" ,
828
+ ]
829
+
830
+ return []
831
+
832
+
833
+ def _build_mermaid_nodes (nodes ) -> list [str ]:
834
+ node_lines = []
835
+ for node in nodes :
836
+ node_lines .extend (_build_mermaid_node (node ))
837
+
838
+ return node_lines
839
+
840
+
841
+ def _build_mermaid_edges (edges ) -> list [str ]:
842
+ """Return a list of Mermaid edge definitions."""
843
+ edge_lines = []
844
+ for child , parent in edges :
845
+ child_id = str (child ).replace (":" , "_" )
846
+ parent_id = str (parent ).replace (":" , "_" )
847
+ edge_lines .append (f"{ parent_id } --> { child_id } " )
848
+ return edge_lines
849
+
850
+
851
+ def _build_mermaid_plates (plates , include_dim_lengths ) -> list [str ]:
852
+ plate_lines = []
853
+ for plate in plates :
854
+ if not plate .dim_info :
855
+ continue
856
+
857
+ plate_label_func = (
858
+ create_plate_label_with_dim_length
859
+ if include_dim_lengths
860
+ else create_plate_label_without_dim_length
861
+ )
862
+ plate_label = plate_label_func (plate .dim_info )
863
+ plate_name = f'subgraph "{ plate_label } "'
864
+ plate_lines .append (plate_name )
865
+ for var in plate .variables :
866
+ plate_lines .append (f" { var .var .name } " )
867
+ plate_lines .append ("end" )
868
+
869
+ return plate_lines
870
+
871
+
872
+ def model_to_mermaid (model = None , * , var_names = None , include_dim_lengths : bool = True ) -> str :
873
+ """Produce a Mermaid diagram string from a PyMC model.
874
+
875
+ Parameters
876
+ ----------
877
+ model : pm.Model
878
+ The model to plot. Not required when called from inside a modelcontext.
879
+ var_names : iterable of variable names, optional
880
+ Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
881
+ include_dim_lengths : bool
882
+ Include the dim lengths in the plate label. Default is True.
883
+
884
+ Returns
885
+ -------
886
+ str
887
+ Mermaid diagram string representing the model graph.
888
+
889
+ Examples
890
+ --------
891
+ Visualize a simple PyMC model
892
+
893
+ .. code-block:: python
894
+
895
+ import pymc as pm
896
+
897
+ with pm.Model() as model:
898
+ mu = pm.Normal("mu", mu=0, sigma=1)
899
+ sigma = pm.HalfNormal("sigma", sigma=1)
900
+
901
+ pm.Normal("obs", mu=mu, sigma=sigma, observed=[1, 2, 3])
902
+
903
+ print(pm.model_to_mermaid(model))
904
+
905
+
906
+ """
907
+ model = pm .modelcontext (model )
908
+ graph = ModelGraph (model )
909
+ plates = sorted (graph .get_plates (var_names = var_names ), key = lambda plate : hash (plate .dim_info ))
910
+ edges = sorted (graph .edges (var_names = var_names ))
911
+ nodes = sorted (graph .nodes (plates = plates ), key = lambda node : cast (str , node .var .name ))
912
+
913
+ return "\n " .join (
914
+ [
915
+ "graph TD" ,
916
+ "%% Nodes:" ,
917
+ * _build_mermaid_nodes (nodes ),
918
+ "\n %% Edges:" ,
919
+ * _build_mermaid_edges (edges ),
920
+ "\n %% Plates:" ,
921
+ * _build_mermaid_plates (plates , include_dim_lengths = include_dim_lengths ),
922
+ ]
923
+ )
0 commit comments