|
16 | 16 | #include "fields/division.h"
|
17 | 17 | #include "fields/exponential_function.h"
|
18 | 18 | #include "fields/summation.h"
|
| 19 | +#include "fields/field_activation_function.h" |
19 | 20 |
|
20 | 21 |
|
21 | 22 | // ----------------
|
@@ -68,6 +69,10 @@ PYBIND11_MODULE(aika, m)
|
68 | 69 | // Bind Summation (inherits from AbstractFunctionDefinition)
|
69 | 70 | py::class_<Summation, AbstractFunctionDefinition>(m, "Summation");
|
70 | 71 |
|
| 72 | + // Bind FieldActivationFunction (inherits from AbstractFunctionDefinition) |
| 73 | + py::class_<FieldActivationFunction, AbstractFunctionDefinition>(m, "FieldActivationFunction") |
| 74 | + .def(py::init<Type*, const std::string&, ActivationFunction*, double>()); |
| 75 | + |
71 | 76 | py::class_<InputField, FieldDefinition>(m, "InputField")
|
72 | 77 | .def(py::init<Type*, const std::string &>())
|
73 | 78 | .def("__str__", [](const InputField &f) {
|
@@ -120,6 +125,14 @@ PYBIND11_MODULE(aika, m)
|
120 | 125 | const_cast<Type*>(&ref),
|
121 | 126 | name
|
122 | 127 | );
|
| 128 | + }, py::return_value_policy::reference_internal) |
| 129 | + .def("fieldActivationFunc", [](const Type &ref, const std::string &name, ActivationFunction* actFunction, double tolerance) { |
| 130 | + return new FieldActivationFunction( |
| 131 | + const_cast<Type*>(&ref), |
| 132 | + name, |
| 133 | + actFunction, |
| 134 | + tolerance |
| 135 | + ); |
123 | 136 | }, py::return_value_policy::reference_internal);
|
124 | 137 |
|
125 | 138 | py::class_<Obj>(m, "Obj")
|
|
0 commit comments