Skip to content

Commit b9e8418

Browse files
committed
feat(api): add support for nested prototypes
1 parent 813aa03 commit b9e8418

File tree

2 files changed

+125
-31
lines changed

2 files changed

+125
-31
lines changed

libs/api/include/rtbot/Prototype.h

Lines changed: 52 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -38,47 +38,58 @@ class PrototypeHandler {
3838
return;
3939
}
4040

41-
auto prototypes = parse_prototypes(program_json["prototypes"]);
41+
// Parse and store all prototypes
42+
std::map<std::string, PrototypeDefinition> prototypes = parse_prototypes(program_json["prototypes"]);
43+
4244
json expanded_operators = json::array();
4345
json expanded_connections = json::array();
4446
std::map<std::string, std::pair<std::string, std::string>> instance_mappings;
4547

4648
// Process original operators
4749
for (const auto& op : program_json["operators"]) {
4850
if (op.contains("prototype")) {
49-
const auto& proto = prototypes[op["prototype"].get<std::string>()];
51+
// Handle prototype instance
52+
const auto& proto = prototypes.at(op["prototype"].get<std::string>());
53+
const std::string instance_id = op["id"].get<std::string>();
5054

51-
// Expand prototype instance
52-
expand_prototype_instance(op["id"].get<std::string>(), proto, op.value("parameters", json::object()),
53-
expanded_operators, expanded_connections, instance_mappings);
55+
expand_prototype_instance(instance_id, proto, op.value("parameters", json::object()), expanded_operators,
56+
expanded_connections, instance_mappings, prototypes);
5457
} else {
58+
// Regular operator - add as-is
5559
expanded_operators.push_back(op);
5660
}
5761
}
5862

5963
// Process original connections
6064
for (const auto& conn : program_json["connections"]) {
61-
std::string from_id = conn["from"];
62-
std::string to_id = conn["to"];
63-
64-
json remapped_conn = conn;
65-
66-
// Map prototype instance connections
67-
if (instance_mappings.count(from_id)) {
68-
remapped_conn["from"] = instance_mappings[from_id].second;
69-
remapped_conn["fromPort"] = conn.value("fromPort", "o1");
70-
}
71-
if (instance_mappings.count(to_id)) {
72-
remapped_conn["to"] = instance_mappings[to_id].first;
73-
remapped_conn["toPort"] = conn.value("toPort", "i1");
65+
bool processed = false;
66+
67+
// Check if connection involves prototype instances
68+
for (const auto& [instance_id, mapping] : instance_mappings) {
69+
if (conn["from"] == instance_id) {
70+
json new_conn = conn;
71+
new_conn["from"] = mapping.second; // Use output port
72+
expanded_connections.push_back(new_conn);
73+
processed = true;
74+
}
75+
if (conn["to"] == instance_id) {
76+
json new_conn = conn;
77+
new_conn["to"] = mapping.first; // Use entry port
78+
expanded_connections.push_back(new_conn);
79+
processed = true;
80+
}
7481
}
7582

76-
expanded_connections.push_back(remapped_conn);
83+
// If not involving prototypes, add as-is
84+
if (!processed) {
85+
expanded_connections.push_back(conn);
86+
}
7787
}
7888

79-
// Update program with expanded operators and connections
89+
// Update program JSON
8090
program_json["operators"] = expanded_operators;
8191
program_json["connections"] = expanded_connections;
92+
program_json.erase("prototypes");
8293
}
8394

8495
private:
@@ -129,32 +140,42 @@ class PrototypeHandler {
129140
}
130141
}
131142

132-
static void expand_prototype_instance(const std::string& instance_id, const PrototypeDefinition& proto,
133-
const json& params, json& expanded_operators, json& expanded_connections,
134-
std::map<std::string, std::pair<std::string, std::string>>& instance_mappings) {
143+
static void expand_prototype_instance(
144+
const std::string& instance_id, const PrototypeDefinition& proto, const json& params, json& expanded_operators,
145+
json& expanded_connections, std::map<std::string, std::pair<std::string, std::string>>& instance_mappings,
146+
const std::map<std::string, PrototypeDefinition>& all_prototypes // Add prototypes map
147+
) {
135148
json resolved_params = resolve_parameters(instance_id, params, proto.parameters);
136149

137-
// Create a copy of operators for modification
138150
json instance_operators = proto.operators;
139151

140-
// First resolve parameters in all operators, including nested Pipelines
152+
// First pass: Resolve parameters
141153
resolve_pipeline_operators(instance_operators, resolved_params);
142154

143-
// Expand operators with scoped IDs
155+
// Second pass: Handle nested prototypes
144156
for (auto& op : instance_operators) {
145-
std::string local_id = op["id"];
146-
op["id"] = instance_id + "::" + local_id;
147-
expanded_operators.push_back(op);
157+
if (op.contains("prototype")) {
158+
// Recursively expand nested prototype
159+
const auto& nested_proto = all_prototypes.at(op["prototype"].get<std::string>());
160+
const std::string nested_id = instance_id + "::" + op["id"].get<std::string>();
161+
162+
expand_prototype_instance(nested_id, nested_proto, op.value("parameters", json::object()), expanded_operators,
163+
expanded_connections, instance_mappings, all_prototypes);
164+
} else {
165+
// Regular operator - add with scoped ID
166+
std::string local_id = op["id"];
167+
op["id"] = instance_id + "::" + local_id;
168+
expanded_operators.push_back(op);
169+
}
148170
}
149171

150-
// Expand connections
172+
// Expand connections with scoped IDs
151173
for (auto conn : proto.connections) {
152174
conn["from"] = instance_id + "::" + conn["from"].get<std::string>();
153175
conn["to"] = instance_id + "::" + conn["to"].get<std::string>();
154176
expanded_connections.push_back(conn);
155177
}
156178

157-
// Store mappings
158179
instance_mappings[instance_id] =
159180
std::make_pair(instance_id + "::" + proto.entry.operator_id, instance_id + "::" + proto.output.operator_id);
160181
}

libs/api/test/test_program.cpp

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,4 +1030,77 @@ SCENARIO("Program handles nested Pipeline prototypes correctly", "[program][prot
10301030
REQUIRE_THROWS_WITH(Program(invalid_json), Catch::Contains("Unknown parameter reference '${nonexistent_param}'"));
10311031
}
10321032
}
1033+
}
1034+
1035+
SCENARIO("Program with nested prototypes", "[program]") {
1036+
std::string program_json = R"({
1037+
"prototypes": {
1038+
"simpleMA": {
1039+
"parameters": [
1040+
{"name": "window", "type": "number"}
1041+
],
1042+
"operators": [
1043+
{"type": "MovingAverage", "id": "ma", "window_size": "${window}"}
1044+
],
1045+
"connections": [],
1046+
"entry": {"operator": "ma"},
1047+
"output": {"operator": "ma"}
1048+
},
1049+
"dualMA": {
1050+
"parameters": [
1051+
{"name": "window1", "type": "number"},
1052+
{"name": "window2", "type": "number"}
1053+
],
1054+
"operators": [
1055+
{"id": "ma1", "prototype": "simpleMA", "parameters": {"window": "${window1}"}},
1056+
{"id": "ma2", "prototype": "simpleMA", "parameters": {"window": "${window2}"}}
1057+
],
1058+
"connections": [
1059+
{"from": "ma1::ma", "to": "ma2::ma"}
1060+
],
1061+
"entry": {"operator": "ma1::ma"},
1062+
"output": {"operator": "ma2::ma"}
1063+
}
1064+
},
1065+
"operators": [
1066+
{"type": "Input", "id": "input1", "portTypes": ["number"]},
1067+
{"id": "nested_ma", "prototype": "dualMA", "parameters": {"window1": 2, "window2": 3}},
1068+
{"type": "Output", "id": "output1", "portTypes": ["number"]}
1069+
],
1070+
"connections": [
1071+
{"from": "input1", "to": "nested_ma::ma1::ma"},
1072+
{"from": "nested_ma::ma2::ma", "to": "output1"}
1073+
],
1074+
"entryOperator": "input1",
1075+
"output": {
1076+
"output1": ["o1"]
1077+
}
1078+
})";
1079+
1080+
Program program(program_json);
1081+
WHEN("Processing messages through nested moving averages") {
1082+
ProgramMsgBatch batch;
1083+
std::vector<double> results;
1084+
1085+
// Send messages one by one
1086+
batch = program.receive(Message<NumberData>(1, NumberData{2.0}));
1087+
REQUIRE(batch.empty()); // First MA collecting
1088+
1089+
batch = program.receive(Message<NumberData>(2, NumberData{4.0}));
1090+
REQUIRE(batch.empty()); // First MA starts outputting, second MA collecting
1091+
1092+
batch = program.receive(Message<NumberData>(3, NumberData{6.0}));
1093+
REQUIRE(batch.empty()); // Second MA still collecting
1094+
1095+
batch = program.receive(Message<NumberData>(4, NumberData{8.0}));
1096+
REQUIRE(batch.size() == 1);
1097+
REQUIRE(batch["output1"]["o1"].size() == 1);
1098+
auto* msg = dynamic_cast<const Message<NumberData>*>(batch["output1"]["o1"].back().get());
1099+
REQUIRE(msg->data.value == Approx(5.0)); // First output from nested MAs
1100+
1101+
batch = program.receive(Message<NumberData>(5, NumberData{10.0}));
1102+
REQUIRE(batch.size() == 1);
1103+
msg = dynamic_cast<const Message<NumberData>*>(batch["output1"]["o1"].back().get());
1104+
REQUIRE(msg->data.value == Approx(7.0));
1105+
}
10331106
}

0 commit comments

Comments
 (0)