-
Notifications
You must be signed in to change notification settings - Fork 146
Add function to mermaid diagram #1490
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
Seeing that tests are failing because of no pydot. Might be able to mock that behavior... |
|
Rather install and test properly. Elemwise isn't a great name though, we have to check the function that extracts the name It should use |
| return "\n".join(mermaid_lines) | ||
|
|
||
|
|
||
| def _color_to_hex(color_name): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is this a function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mermaid needs hexcolors. the pydotformatter has strings names for colors
Do you like the route of using pydot? Don't think it would be too hard to implement a custom formatter. We just need a light graph representation.
Sure. Shall I change for the formatter then or go a different route? |
|
I think pydot is overkill but I wouldn't mock if you're using it. Your original function was fine except it will iterate some edges repeatedly. I'm sure you can repurpose FunctionGraph or something from |
Can you explain the case that is missing or example that will fail
I will explore. Is dprint logic general enough to support? |
Don't know if it will fail, but a graph like this, you may end up navigating x -> exp_x twice: x = pt.scalar("x")
y = pt.exp(x)
out = y + y * 2
Not sure what you mean. I meant that in those modules you have utilities to iterate over a graph. It's unlikely you have to invent something new for that goal. |
|
The d3viz based on x = pt.scalar("x")
y = pt.exp(x)
out = y + y * 2looks like (I constructed this from hand): graph TD
A["x"]
style A fill:#32CD32
B["Elemwise"]
B@{ shape: rounded }
style B fill:#FF00FF
C["dscalar"]
style C fill:#1E90FF
A --> B
B --> C
Is that as expected? If not, the PydotFormatter is not correct |
|
Recent change has it looking like graph TD
%% Nodes:
n1["Composite"]
n1@{ shape: rounded }
n2["x"]
n2@{ shape: rect }
style n2 fill:#32CD32
n3["out"]
n3@{ shape: rect }
style n3 fill:#1E90FF
%% Edges:
n2 --> n1
n1 --> n3
|
|
Seems like you are compiling/rewriting the graph, otherwise the Composite wouldn't be introduced |
|
Well the d3 stuff only works for |
|
We can make d3 work easily for variables as well no? Just need a FunctionGraph? Also don't narrow anything on Tensorvariables but try to work with any Variables (for instance XTrensorVariables but even stuff like Slice and RNGs) |
Description
I explored using variables directly with is definitely doable. Some helpful functions where:
Using the pydot formatter which already exists in pytensor, can provide some version of this already.
Some examples of this:
graph TD %% Nodes: n1["DimShuffle"] n1@{ shape: rounded } n2["noise"] n2@{ shape: rect } style n2 fill:#32CD32 n4["DimShuffle"] n4@{ shape: rounded } n5["alpha"] n5@{ shape: rect } style n5 fill:#32CD32 n7["Shape_i"] n7@{ shape: rounded } style n7 fill:#00FFFF n8["X"] n8@{ shape: rect } style n8 fill:#32CD32 n8["X"] n8@{ shape: rect } style n8 fill:#32CD32 n10["AllocEmpty"] n10@{ shape: rounded } n12["CGemv"] n12@{ shape: rounded } n13["1.0"] n13@{ shape: rect } style n13 fill:#00FF7F n14["beta"] n14@{ shape: rect } style n14 fill:#32CD32 n15["0.0"] n15@{ shape: rect } style n15 fill:#00FF7F n17["Elemwise"] n17@{ shape: rounded } n18["y"] n18@{ shape: rect } style n18 fill:#1E90FF %% Edges: n2 --> n1 n5 --> n4 n8 --> n7 n7 --> n10 n10 --> n12 n13 --> n12 n8 --> n12 n14 --> n12 n15 --> n12 n12 --> n17 n4 --> n17 n1 --> n17 n17 --> n18graph TD %% Nodes: n1["OpFromGraph"] n1@{ shape: rounded } n2["x"] n2@{ shape: rect } style n2 fill:#32CD32 n3["y"] n3@{ shape: rect } style n3 fill:#32CD32 n4["z"] n4@{ shape: rect } style n4 fill:#32CD32 n4["z"] n4@{ shape: rect } style n4 fill:#32CD32 n6["Elemwise"] n6@{ shape: rounded } n7["dscalar"] n7@{ shape: rect } style n7 fill:#1E90FF %% Edges: n2 --> n1 n3 --> n1 n4 --> n1 n1 --> n6 n4 --> n6 n6 --> n7Related Issue
Checklist
Type of change