Skip to content

Commit 0b01ffb

Browse files
committed
feat(std): add function operator
1 parent 215d042 commit 0b01ffb

File tree

3 files changed

+300
-0
lines changed

3 files changed

+300
-0
lines changed
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
#ifndef FUNCTION_H
2+
#define FUNCTION_H
3+
4+
#include <algorithm>
5+
#include <cmath>
6+
#include <stdexcept>
7+
#include <string>
8+
#include <utility>
9+
#include <vector>
10+
11+
#include "rtbot/Operator.h"
12+
13+
namespace rtbot {
14+
15+
template <class T, class V>
16+
struct Function : public Operator<T, V> {
17+
enum class InterpolationType { LINEAR, HERMITE };
18+
19+
Function() = default;
20+
21+
Function(string const& id, vector<pair<V, V>> points, string type = "linear") : Operator<T, V>(id) {
22+
if (points.size() < 2) {
23+
throw runtime_error(typeName() + ": at least 2 points are required for interpolation");
24+
}
25+
26+
sort(points.begin(), points.end());
27+
this->points = points;
28+
29+
if (type == "linear") {
30+
this->type = InterpolationType::LINEAR;
31+
} else if (type == "hermite") {
32+
this->type = InterpolationType::HERMITE;
33+
} else {
34+
throw runtime_error(typeName() + ": invalid interpolation type. Use 'linear' or 'hermite'");
35+
}
36+
37+
this->addDataInput("i1", 1);
38+
this->addOutput("o1");
39+
}
40+
41+
virtual Bytes collect() override {
42+
Bytes bytes = Operator<T, V>::collect();
43+
44+
// Serialize points
45+
size_t pointsSize = points.size();
46+
bytes.insert(bytes.end(), reinterpret_cast<const unsigned char*>(&pointsSize),
47+
reinterpret_cast<const unsigned char*>(&pointsSize) + sizeof(pointsSize));
48+
49+
for (const auto& point : points) {
50+
bytes.insert(bytes.end(), reinterpret_cast<const unsigned char*>(&point.first),
51+
reinterpret_cast<const unsigned char*>(&point.first) + sizeof(point.first));
52+
bytes.insert(bytes.end(), reinterpret_cast<const unsigned char*>(&point.second),
53+
reinterpret_cast<const unsigned char*>(&point.second) + sizeof(point.second));
54+
}
55+
56+
// Serialize interpolation type
57+
auto typeVal = static_cast<int>(type);
58+
bytes.insert(bytes.end(), reinterpret_cast<const unsigned char*>(&typeVal),
59+
reinterpret_cast<const unsigned char*>(&typeVal) + sizeof(typeVal));
60+
61+
return bytes;
62+
}
63+
64+
virtual void restore(Bytes::const_iterator& it) override {
65+
Operator<T, V>::restore(it);
66+
67+
// Deserialize points
68+
size_t pointsSize = *reinterpret_cast<const size_t*>(&(*it));
69+
it += sizeof(pointsSize);
70+
71+
points.clear();
72+
for (size_t i = 0; i < pointsSize; i++) {
73+
V x = *reinterpret_cast<const V*>(&(*it));
74+
it += sizeof(V);
75+
V y = *reinterpret_cast<const V*>(&(*it));
76+
it += sizeof(V);
77+
points.push_back({x, y});
78+
}
79+
80+
// Deserialize interpolation type
81+
int typeVal = *reinterpret_cast<const int*>(&(*it));
82+
it += sizeof(typeVal);
83+
type = static_cast<InterpolationType>(typeVal);
84+
}
85+
86+
string typeName() const override { return "Function"; }
87+
88+
OperatorMessage<T, V> processData() override {
89+
string inputPort;
90+
auto in = this->getDataInputs();
91+
if (in.size() == 1)
92+
inputPort = in.at(0);
93+
else
94+
throw runtime_error(typeName() + " : more than 1 input port found");
95+
96+
OperatorMessage<T, V> outputMsgs;
97+
Message<T, V> msg = this->getDataInputLastMessage(inputPort);
98+
V x = msg.value;
99+
100+
size_t i = 0;
101+
while (i < points.size() - 1 && points[i + 1].first <= x) {
102+
i++;
103+
}
104+
105+
if (type == InterpolationType::LINEAR) {
106+
if (i == 0 && x < points[0].first) {
107+
msg.value = linearInterpolate(points[0].first, points[0].second, points[1].first, points[1].second, x);
108+
} else if (i == points.size() - 1) {
109+
msg.value = linearInterpolate(points[i - 1].first, points[i - 1].second, points[i].first, points[i].second, x);
110+
} else {
111+
msg.value = linearInterpolate(points[i].first, points[i].second, points[i + 1].first, points[i + 1].second, x);
112+
}
113+
} else {
114+
if (i == 0 && x < points[0].first) {
115+
msg.value = hermiteInterpolate(points[0].second, points[1].second, getTangent(0), getTangent(1),
116+
(x - points[0].first) / (points[1].first - points[0].first));
117+
} else if (i >= points.size() - 2) {
118+
size_t last = points.size() - 1;
119+
msg.value =
120+
hermiteInterpolate(points[last - 1].second, points[last].second, getTangent(last - 1), getTangent(last),
121+
(x - points[last - 1].first) / (points[last].first - points[last - 1].first));
122+
} else {
123+
msg.value = hermiteInterpolate(points[i].second, points[i + 1].second, getTangent(i), getTangent(i + 1),
124+
(x - points[i].first) / (points[i + 1].first - points[i].first));
125+
}
126+
}
127+
128+
PortMessage<T, V> v;
129+
v.push_back(msg);
130+
outputMsgs.emplace("o1", v);
131+
return outputMsgs;
132+
}
133+
134+
vector<pair<V, V>> getPoints() const { return points; }
135+
string getInterpolationType() const { return type == InterpolationType::LINEAR ? "linear" : "hermite"; }
136+
137+
private:
138+
vector<pair<V, V>> points;
139+
InterpolationType type;
140+
141+
static V linearInterpolate(V x1, V y1, V x2, V y2, V x) { return y1 + (y2 - y1) * (x - x1) / (x2 - x1); }
142+
143+
static V hermiteInterpolate(V y0, V y1, V m0, V m1, V mu) {
144+
V mu2 = mu * mu;
145+
V mu3 = mu2 * mu;
146+
V h00 = 2 * mu3 - 3 * mu2 + 1;
147+
V h10 = mu3 - 2 * mu2 + mu;
148+
V h01 = -2 * mu3 + 3 * mu2;
149+
V h11 = mu3 - mu2;
150+
151+
return h00 * y0 + h10 * m0 + h01 * y1 + h11 * m1;
152+
}
153+
154+
V getTangent(size_t i) const {
155+
if (i == 0) {
156+
return (points[1].second - points[0].second) / (points[1].first - points[0].first);
157+
} else if (i == points.size() - 1) {
158+
return (points[i].second - points[i - 1].second) / (points[i].first - points[i - 1].first);
159+
} else {
160+
return (points[i + 1].second - points[i - 1].second) / (points[i + 1].first - points[i - 1].first);
161+
}
162+
}
163+
};
164+
165+
} // namespace rtbot
166+
167+
#endif // FUNCTION_H
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
---
2+
behavior:
3+
buffered: false
4+
throughput: constant
5+
view:
6+
shape: circle
7+
latex:
8+
template: |
9+
f(x)
10+
jsonschema:
11+
type: object
12+
properties:
13+
id:
14+
type: string
15+
description: The id of the operator
16+
points:
17+
type: array
18+
description: Array of (x,y) coordinates defining the function
19+
minItems: 2
20+
items:
21+
type: array
22+
minItems: 2
23+
maxItems: 2
24+
items:
25+
type: number
26+
type:
27+
type: string
28+
enum: ["linear", "hermite"]
29+
default: "linear"
30+
description: Interpolation method to use
31+
required: ["id", "points"]
32+
---
33+
34+
# Function
35+
36+
Inputs: `i1`
37+
Outputs: `o1`
38+
39+
The `Function` operator transforms input values through a user-defined function specified by a set of points. It supports both linear and Hermite interpolation methods.
40+
41+
For input values between defined points, the operator performs interpolation according to the specified method:
42+
43+
- `linear`: Simple linear interpolation between adjacent points
44+
- `hermite`: Smooth cubic Hermite interpolation using estimated tangents at each point
45+
46+
For input values outside the defined range, the operator extrapolates using the same method.
47+
48+
The `Function` operator does not hold a message buffer on `i1`. It emits transformed values through `o1` immediately after receiving input.
49+
50+
Example usage for linear interpolation between points (0,0) and (1,1):
51+
52+
```cpp
53+
vector<pair<double, double>> points = {{0.0, 0.0}, {1.0, 1.0}};
54+
auto func = Function<uint64_t, double>("func", points, "linear");
55+
```

libs/std/test/test_function.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#include <catch2/catch.hpp>
2+
#include <cmath>
3+
4+
#include "rtbot/std/Function.h"
5+
6+
using namespace rtbot;
7+
using namespace std;
8+
9+
TEST_CASE("Function operator - Linear interpolation") {
10+
vector<pair<double, double>> points = {{0.0, 0.0}, {1.0, 2.0}, {2.0, 4.0}, {3.0, 6.0}};
11+
auto func = Function<uint64_t, double>("func", points, "linear");
12+
13+
SECTION("Interpolation between points") {
14+
func.receiveData(Message<uint64_t, double>(1, 0.5));
15+
auto output = func.executeData();
16+
REQUIRE(output.find("func")->second.find("o1")->second.at(0).value == Approx(1.0));
17+
18+
func.receiveData(Message<uint64_t, double>(2, 1.5));
19+
output = func.executeData();
20+
REQUIRE(output.find("func")->second.find("o1")->second.at(0).value == Approx(3.0));
21+
}
22+
23+
SECTION("Extrapolation before first point") {
24+
func.receiveData(Message<uint64_t, double>(1, -1.0));
25+
auto output = func.executeData();
26+
REQUIRE(output.find("func")->second.find("o1")->second.at(0).value == Approx(-2.0));
27+
}
28+
29+
SECTION("Extrapolation after last point") {
30+
func.receiveData(Message<uint64_t, double>(1, 4.0));
31+
auto output = func.executeData();
32+
REQUIRE(output.find("func")->second.find("o1")->second.at(0).value == Approx(8.0));
33+
}
34+
}
35+
36+
TEST_CASE("Function operator - Hermite interpolation") {
37+
vector<pair<double, double>> points = {{0.0, 0.0}, {1.0, 1.0}, {2.0, 0.0}, {3.0, 1.0}};
38+
auto func = Function<uint64_t, double>("func", points, "hermite");
39+
40+
SECTION("Interpolation between points") {
41+
func.receiveData(Message<uint64_t, double>(1, 0.5));
42+
auto output = func.executeData();
43+
REQUIRE(output.find("func")->second.find("o1")->second.at(0).value > 0.5);
44+
45+
func.receiveData(Message<uint64_t, double>(2, 1.5));
46+
output = func.executeData();
47+
double y = output.find("func")->second.find("o1")->second.at(0).value;
48+
REQUIRE(y >= 0.0);
49+
REQUIRE(y <= 1.0);
50+
}
51+
}
52+
53+
TEST_CASE("Function operator - Serialization") {
54+
vector<pair<double, double>> points = {{0.0, 0.0}, {1.0, 2.0}, {2.0, 4.0}, {3.0, 6.0}};
55+
auto func1 = Function<uint64_t, double>("func", points, "linear");
56+
57+
// Add some data and process it
58+
func1.receiveData(Message<uint64_t, double>(1, 1.5));
59+
auto output1 = func1.executeData();
60+
61+
// Serialize
62+
Bytes bytes = func1.collect();
63+
64+
// Create new operator and restore state
65+
auto func2 = Function<uint64_t, double>("func", {{0.0, 0.0}, {1.0, 1.0}}, "linear"); // Different initial state
66+
Bytes::const_iterator it = bytes.begin();
67+
func2.restore(it);
68+
69+
// Verify state was properly restored
70+
REQUIRE(func2.getPoints() == points);
71+
REQUIRE(func2.getInterpolationType() == "linear");
72+
73+
// Verify behavior is identical
74+
func2.receiveData(Message<uint64_t, double>(1, 1.5));
75+
auto output2 = func2.executeData();
76+
REQUIRE(output2.find("func")->second.find("o1")->second.at(0).value ==
77+
output1.find("func")->second.find("o1")->second.at(0).value);
78+
}

0 commit comments

Comments
 (0)