Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,7 @@ jobs:
python Scripts/generateshader.py ../resources/Materials/Examples/StandardSurface --target mdl
python Scripts/generateshader.py ../resources/Materials/Examples/StandardSurface --target msl
python Scripts/generateshader.py ../resources/Materials/Examples/StandardSurface --target slang
python Scripts/generateshader.py ../resources/Materials/Examples/StandardSurface --graph --graphValues
working-directory: python

- name: Shader Validation Tests (Windows)
Expand Down
13 changes: 13 additions & 0 deletions python/Scripts/generateshader.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def main():
parser.add_argument('--validatorArgs', dest='validatorArgs', nargs='?', const=' ', type=str, help='Optional arguments for code validator.')
parser.add_argument('--vulkanGlsl', dest='vulkanCompliantGlsl', default=False, type=bool, help='Set to True to generate Vulkan-compliant GLSL when using the genglsl target.')
parser.add_argument('--shaderInterfaceType', dest='shaderInterfaceType', default=0, type=int, help='Set the type of shader interface to be generated')
parser.add_argument('--graph', dest='graph', action='store_true', help='Set to True to generate a Mermaid graph of the shader graph for each shader.')
parser.add_argument('--graphValues', dest='graphValues', action='store_true', help='Set to True to output input values for Mermaid graphs.')
parser.add_argument(dest='inputFilename', help='Path to input document or folder containing input documents.')
opts = parser.parse_args()

Expand Down Expand Up @@ -156,6 +158,17 @@ def main():
elemName = mx.createValidName(elemName)
shader = shadergen.generate(elemName, elem, context)
if shader:
# Generate a Mermaid graph of the shader graph if requested
if opts.graph:
graphValues = opts.graphValues == True
mermaidGraph = shader.createMermaidGraph(graphValues)
mermaidGraph = "```mermaid\n" + mermaidGraph + "\n```"
filename = pathPrefix + "/" + shader.getName() + "." + gentarget + ".md"
print('--- Wrote Mermaid graph to: ' + filename)
file = open(filename, 'w+')
file.write(mermaidGraph)
file.close()

# Use extension of .vert and .frag as it's type is
# recognized by glslangValidator
if gentarget in ['glsl', 'essl', 'vulkan', 'msl', 'wgsl']:
Expand Down
1 change: 1 addition & 0 deletions source/JsMaterialX/JsMaterialXGenShader/JsShader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ EMSCRIPTEN_BINDINGS(Shader)
.smart_ptr<std::shared_ptr<mx::Shader>>("ShaderPtr")
.function("getSourceCode", &mx::Shader::getSourceCode)
.function("getStage", PTR_RETURN_OVERLOAD(mx::ShaderStage& (mx::Shader::*)(const std::string&), &mx::Shader::getStage), ems::allow_raw_pointers())
.function("createMermaidGraph", &mx::Shader::createMermaidGraph)
;
}
5 changes: 5 additions & 0 deletions source/MaterialXGenShader/Shader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,4 +67,9 @@ ShaderStagePtr Shader::createStage(const string& name, ConstSyntaxPtr syntax)
return s;
}

string Shader::createMermaidGraph(bool showInputValues) const
{
return _graph ? _graph->createMermaidGraph(showInputValues) : string();
}

MATERIALX_NAMESPACE_END
5 changes: 5 additions & 0 deletions source/MaterialXGenShader/Shader.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ class MX_GENSHADER_API Shader
/// Return the shader source code for a given shader stage.
const string& getSourceCode(const string& stage = Stage::PIXEL) const { return getStage(stage).getSourceCode(); }

/// Generate a Mermaid graph representing the shader graph and its nodes.
/// @param showInputValues If true, the graph will include the values of any unconnected input sockets
/// @return Mermaid graph as a string.
string createMermaidGraph(bool showInputValues) const;

protected:
/// Create a new stage in the shader.
ShaderStagePtr createStage(const string& name, ConstSyntaxPtr syntax);
Expand Down
52 changes: 52 additions & 0 deletions source/MaterialXGenShader/ShaderGraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1308,6 +1308,58 @@ void ShaderGraph::populateUnitTransformMap(UnitSystemPtr unitSystem, ShaderPort*
}
}

string ShaderGraph::createMermaidGraph(bool showInputValues) const
{
std::ostringstream oss;
oss << "graph LR\n";

// Emit nodes
for (const ShaderNode* node : getNodes())
{
const string& id = node->getUniqueId();
oss << " " << id << "[\"" << id << "\"]\n";
}

for (const ShaderNode* node : getNodes())
{
const string& nodeId = node->getUniqueId();

// Emit connections
for (const ShaderOutput* output : node->getOutputs())
{
const ShaderInputVec& connections = output->getConnections();
for (const ShaderInput* input : connections)
{
if (input && input->getNode())
{
const string connector = " --\"" + output->getName() + " --> " + input->getName() + "\"--> ";
oss << " " << nodeId << connector << input->getNode()->getUniqueId() << "\n";
}
}
}

// Optionally emit input values
if (showInputValues)
{
for (const ShaderInput* input : node->getInputs())
{
if (input)
{
const string& valueString = input->getValueString();
if (!valueString.empty())
{
oss << " " << node->getUniqueId() << "/" << input->getName();
oss << "[\"" << input->getName() << " = " << valueString << "\"]";
oss << " --> " << nodeId << "\n";
}
}
}
}
}

return oss.str();
}

namespace
{
static const ShaderGraphEdgeIterator NULL_EDGE_ITERATOR(nullptr);
Expand Down
5 changes: 5 additions & 0 deletions source/MaterialXGenShader/ShaderGraph.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ class MX_GENSHADER_API ShaderGraph : public ShaderNode
/// Get a vector of all nodes in order
const vector<ShaderNode*>& getNodes() const { return _nodeOrder; }

/// Generate a Mermaid graph representing this shader graph.
/// @param showInputValues If true, the graph will include the values of any unconnected input sockets
/// @return Mermaid graph as a string.
string createMermaidGraph(bool showInputValues) const;

/// Get number of input sockets
size_t numInputSockets() const { return numOutputs(); }

Expand Down
3 changes: 2 additions & 1 deletion source/PyMaterialX/PyMaterialXGenShader/PyShader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ void bindPyShader(py::module& mod)
.def("hasAttribute", &mx::Shader::hasAttribute)
.def("getAttribute", &mx::Shader::getAttribute)
.def("setAttribute", static_cast<void (mx::Shader::*)(const std::string&)>(&mx::Shader::setAttribute))
.def("setAttribute", static_cast<void (mx::Shader::*)(const std::string&, mx::ValuePtr)>(&mx::Shader::setAttribute));
.def("setAttribute", static_cast<void (mx::Shader::*)(const std::string&, mx::ValuePtr)>(&mx::Shader::setAttribute))
.def("createMermaidGraph", &mx::Shader::createMermaidGraph);
}