Skip to content

Commit bf3e40c

Browse files
Work on subtraction test
1 parent 6d58225 commit bf3e40c

File tree

6 files changed

+160
-7
lines changed

6 files changed

+160
-7
lines changed

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ set(SOURCES
2121
src/fields/obj.cpp
2222
src/fields/type.cpp
2323
src/fields/type_registry.cpp
24-
src/fields/type_registry_python.cpp
2524
src/fields/flattened_type.cpp
2625
src/fields/flattened_type_relation.cpp
2726
src/fields/relation.cpp
@@ -34,6 +33,7 @@ set(SOURCES
3433
src/fields/step.cpp
3534
src/fields/test_type.cpp
3635
src/fields/test_object.cpp
36+
src/fields/python_bindings.cpp
3737
src/network/model.cpp
3838
)
3939

README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ cmake --build . --target install
2828
python tests/subtraction-test.py
2929
```
3030

31+
Additionally, ensure that the `parameterized` package is installed for running parameterized tests. You can install it using:
32+
33+
```bash
34+
pip install parameterized
35+
```
36+
3137
**Steps explained:**
3238

3339
1. Activate the Python virtual environment (here, `.venv`) so that Python dependencies and the installation target are set up correctly.

tests/subtraction-test.py renamed to python-tests/subtraction-test.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import sys
33
import os
4+
from parameterized import parameterized
45

56
# Add the project root to Python's module search path
67
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@@ -9,7 +10,12 @@
910

1011
class MyTestCase(unittest.TestCase):
1112

12-
def testSubtraction(self):
13+
@parameterized.expand([
14+
(0, "linking_pos_0"),
15+
(1, "linking_pos_1"),
16+
(2, "linking_pos_2")
17+
])
18+
def testSubtraction(self, linking_pos, test_name):
1319
print("Module 'aika' was loaded from:", aika.__file__)
1420

1521
TEST_RELATION_FROM = aika.RelationOne(1, "TEST_FROM")
@@ -45,6 +51,23 @@ def testSubtraction(self):
4551
oa = typeA.instantiate()
4652
ob = typeB.instantiate()
4753

54+
if linking_pos == 0:
55+
aika.TestObj.linkObjects(oa, ob)
56+
ob.initFields()
57+
58+
oa.setFieldValue(a, 50.0)
59+
60+
if linking_pos == 1:
61+
aika.TestObj.linkObjects(oa, ob)
62+
ob.initFields()
63+
64+
oa.setFieldValue(b, 20.0)
65+
66+
if linking_pos == 2:
67+
aika.TestObj.linkObjects(oa, ob)
68+
ob.initFields()
69+
70+
self.assertEqual(30.0, ob.getFieldOutput(c).getValue())
4871

4972
if __name__ == '__main__':
5073
unittest.main()
File renamed without changes.

src/fields/flattened_type.cpp

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ FlattenedType::FlattenedType(Direction* dir, Type* type, const std::map<FieldDef
4343
}
4444
}
4545

46+
/**
47+
* @brief Creates a flattened type for input direction
48+
*
49+
* This function creates a flattened type representation for the input direction.
50+
* It takes a type and a set of field definitions, and maps the field definitions
51+
* to their corresponding indices in the flattened type.
52+
*
53+
* @param type The type to flatten
54+
* @param fieldDefs The set of field definitions to include in the flattened type
55+
* @return A new FlattenedType object representing the input flattened type
56+
*/
4657
FlattenedType* FlattenedType::createInputFlattenedType(Type* type, const std::set<FieldDefinition*>& fieldDefs) {
4758
std::map<FieldDefinition*, int> fieldMappings;
4859

@@ -60,6 +71,18 @@ FlattenedType* FlattenedType::createInputFlattenedType(Type* type, const std::se
6071
return new FlattenedType(Direction::INPUT, type, fieldMappings, requiredFields.size());
6172
}
6273

74+
/**
75+
* @brief Creates a flattened type for output direction
76+
*
77+
* This function creates a flattened type representation for the output direction.
78+
* It takes a type and a set of field definitions, and maps the field definitions
79+
* to their corresponding indices in the flattened type.
80+
*
81+
* @param type The type to flatten
82+
* @param fieldDefs The set of field definitions to include in the flattened type
83+
* @param inputSide The flattened type of the input side
84+
* @return A new FlattenedType object representing the output flattened type
85+
*/
6386
FlattenedType* FlattenedType::createOutputFlattenedType(Type* type, const std::set<FieldDefinition*>& fieldDefs, FlattenedType* inputSide) {
6487
std::map<FieldDefinition*, int> fieldMappings;
6588
for (const auto& fd : fieldDefs) {
@@ -71,6 +94,16 @@ FlattenedType* FlattenedType::createOutputFlattenedType(Type* type, const std::s
7194
return new FlattenedType(Direction::OUTPUT, type, fieldMappings, inputSide->numberOfFields);
7295
}
7396

97+
/**
98+
* @brief Checks if all elements in a vector are null
99+
*
100+
* This function checks if all elements in a vector are null.
101+
* It iterates through the vector and returns true if all elements are null,
102+
* otherwise it returns false.
103+
*
104+
* @param vec The vector to check
105+
* @return true if all elements are null, otherwise false
106+
*/
74107
template <typename T>
75108
bool isAllNull(const std::vector<T>& vec) {
76109
for (const auto& element : vec) {
@@ -81,6 +114,13 @@ bool isAllNull(const std::vector<T>& vec) {
81114
return true; // If no non-null element is found, return true
82115
}
83116

117+
/**
118+
* @brief Flattens the type hierarchy
119+
*
120+
* This function flattens the type hierarchy by creating a mapping of relations
121+
* to their corresponding flattened type relations. It iterates through all
122+
* relations and types to build the mapping.
123+
*/
84124
void FlattenedType::flatten() {
85125
mapping = new FlattenedTypeRelation**[type->getRelations().size()];
86126

@@ -101,6 +141,17 @@ void FlattenedType::flatten() {
101141
}
102142
}
103143

144+
/**
145+
* @brief Flattens the type hierarchy per type
146+
*
147+
* This function flattens the type hierarchy per type by creating a mapping of
148+
* field links for a given relation and related type. It iterates through all
149+
* field definitions and their corresponding field link definitions to build the mapping.
150+
*
151+
* @param relation The relation to flatten
152+
* @param relatedType The related type to flatten
153+
* @return A new FlattenedTypeRelation object representing the flattened type relation
154+
*/
104155
FlattenedTypeRelation* FlattenedType::flattenPerType(Relation* relation, Type* relatedType) {
105156
std::vector<FieldLinkDefinition*> fieldLinks;
106157

@@ -128,6 +179,15 @@ FlattenedTypeRelation* FlattenedType::flattenPerType(Relation* relation, Type* r
128179
new FlattenedTypeRelation(this, fieldLinks);
129180
}
130181

182+
/**
183+
* @brief Follows links in the flattened type
184+
*
185+
* This function follows links in the flattened type by iterating through all
186+
* relations and their corresponding flattened type relations. It iterates through
187+
* all relations and their corresponding flattened type relations to follow links.
188+
*
189+
* @param field The field to follow links from
190+
*/
131191
void FlattenedType::followLinks(Field* field) {
132192
for (int relationId = 0; relationId < type->getRelations().size(); relationId++) {
133193
auto& ftr = mapping[relationId];
@@ -146,28 +206,76 @@ void FlattenedType::followLinks(Field* field) {
146206
}
147207
}
148208

209+
/**
210+
* @brief Follows links in the flattened type relation
211+
*
212+
* This function follows links in the flattened type relation by iterating through
213+
* all field links and their corresponding field link definitions.
214+
*
215+
* @param ftr The flattened type relation to follow links from
216+
* @param relatedObj The related object to follow links from
217+
* @param field The field to follow links from
218+
*/
149219
void FlattenedType::followLinks(FlattenedTypeRelation* ftr, Obj* relatedObj, Field* field) {
150220
if (ftr != nullptr) {
151221
ftr->followLinks(direction, relatedObj, field);
152222
}
153223
}
154224

225+
/**
226+
* @brief Gets the field index
227+
*
228+
* This function gets the field index by looking up the field definition in the
229+
* flattened type.
230+
*
231+
* @param fd The field definition to get the index of
232+
* @return The index of the field definition
233+
*/
155234
int FlattenedType::getFieldIndex(FieldDefinition* fd) {
156235
return fields[fd->getId()];
157236
}
158237

238+
/**
239+
* @brief Gets the number of fields
240+
*
241+
* This function gets the number of fields in the flattened type.
242+
*
243+
* @return The number of fields in the flattened type
244+
*/
159245
int FlattenedType::getNumberOfFields() const {
160246
return numberOfFields;
161247
}
162248

249+
/**
250+
* @brief Gets the type
251+
*
252+
* This function gets the type in the flattened type.
253+
*
254+
* @return The type in the flattened type
255+
*/
163256
Type* FlattenedType::getType() const {
164257
return type;
165258
}
166259

260+
/**
261+
* @brief Gets the fields reverse
262+
*
263+
* This function gets the fields reverse in the flattened type.
264+
*
265+
* @return The fields reverse in the flattened type
266+
*/
167267
FieldDefinition*** FlattenedType::getFieldsReverse() {
168268
return fieldsReverse;
169269
}
170270

271+
/**
272+
* @brief Gets the field definition by index
273+
*
274+
* This function gets the field definition by index in the flattened type.
275+
*
276+
* @param idx The index of the field definition to get
277+
* @return The field definition by index in the flattened type
278+
*/
171279
FieldDefinition* FlattenedType::getFieldDefinitionIdByIndex(short idx) {
172280
return fieldsReverse[idx][0];
173281
}

src/fields/type_registry_python.cpp renamed to src/fields/python_bindings.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "fields/subtraction.h"
1111
#include "fields/test_type.h"
1212
#include "fields/test_object.h"
13+
#include "fields/field.h"
1314

1415

1516
// ----------------
@@ -64,29 +65,44 @@ PYBIND11_MODULE(aika, m)
6465
const_cast<Type*>(&ref),
6566
name
6667
);
67-
}, py::return_value_policy::take_ownership)
68+
}, py::return_value_policy::reference_internal)
6869
.def("sub", [](const Type &ref, const std::string &name) {
6970
return new Subtraction(
7071
const_cast<Type*>(&ref),
7172
name
7273
);
73-
}, py::return_value_policy::take_ownership);
74+
}, py::return_value_policy::reference_internal);
7475

7576
py::class_<Obj>(m, "Obj")
7677
.def("__str__", [](const Obj &t) {
7778
return t.toString();
78-
});
79+
})
80+
.def("setFieldValue", &Obj::setFieldValue)
81+
.def("getFieldValue", &Obj::getFieldValue)
82+
.def("initFields", &Obj::initFields)
83+
.def("getType", &Obj::getType, py::return_value_policy::reference_internal)
84+
.def("isInstanceOf", &Obj::isInstanceOf)
85+
.def("getFieldOutput", &Obj::getFieldOutput, py::return_value_policy::reference_internal)
86+
.def("getOrCreateFieldInput", &Obj::getOrCreateFieldInput, py::return_value_policy::reference_internal);
7987

8088
py::class_<TestType, Type>(m, "TestType")
8189
.def(py::init<TypeRegistry*, const std::string&>())
82-
.def("instantiate", &TestType::instantiate, py::return_value_policy::take_ownership);
90+
.def("instantiate", &TestType::instantiate, py::return_value_policy::reference_internal);
8391

8492
py::class_<TestObject, Obj>(m, "TestObj")
85-
.def(py::init<TestType*>());
93+
.def(py::init<TestType*>())
94+
.def_static("linkObjects", &TestObject::linkObjects);
8695

8796
py::class_<TypeRegistry>(m, "TypeRegistry")
8897
.def(py::init<>())
8998
.def("getType", &TypeRegistry::getType)
9099
.def("registerType", &TypeRegistry::registerType)
91100
.def("flattenTypeHierarchy", &TypeRegistry::flattenTypeHierarchy);
101+
102+
py::class_<Field>(m, "Field")
103+
.def("getValue", &Field::getValue)
104+
.def("getUpdatedValue", &Field::getUpdatedValue)
105+
.def("__str__", [](const Field &f) {
106+
return f.toString();
107+
});
92108
}

0 commit comments

Comments
 (0)