diff --git a/rosidl_runtime_py/rosidl_runtime_py/__init__.py b/rosidl_runtime_py/rosidl_runtime_py/__init__.py index 5f40b2ed..a047d61c 100644 --- a/rosidl_runtime_py/rosidl_runtime_py/__init__.py +++ b/rosidl_runtime_py/rosidl_runtime_py/__init__.py @@ -15,10 +15,12 @@ from .convert import message_to_csv from .convert import message_to_ordereddict from .convert import message_to_yaml +from .import_message import import_message_type from .set_message import set_message_fields __all__ = [ + 'import_message_type', 'message_to_csv', 'message_to_ordereddict', 'message_to_yaml', diff --git a/rosidl_runtime_py/rosidl_runtime_py/import_message.py b/rosidl_runtime_py/rosidl_runtime_py/import_message.py new file mode 100644 index 00000000..24e97b7b --- /dev/null +++ b/rosidl_runtime_py/rosidl_runtime_py/import_message.py @@ -0,0 +1,33 @@ +# Copyright 2017-2019 Open Source Robotics Foundation, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + + +def import_message_type(topic_name: str, message_type: str): + # TODO(dirk-thomas) this logic should come from a rosidl related package + try: + package_name, message_name = message_type.split('/', 2) + if not package_name or not message_name: + raise ValueError() + except ValueError: + raise RuntimeError('The passed message type is invalid') + + # TODO(sloretz) node API to get topic types should indicate if action or msg + middle_module = 'msg' + if topic_name.endswith('/_action/feedback'): + middle_module = 'action' + + module = importlib.import_module(package_name + '.' + middle_module) + return getattr(module, message_name) diff --git a/rosidl_runtime_py/rosidl_runtime_py/set_message.py b/rosidl_runtime_py/rosidl_runtime_py/set_message.py index cfb36563..91b372d1 100644 --- a/rosidl_runtime_py/rosidl_runtime_py/set_message.py +++ b/rosidl_runtime_py/rosidl_runtime_py/set_message.py @@ -15,6 +15,8 @@ from typing import Any from typing import Dict +from rosidl_runtime_py.import_message import import_message_type + def set_message_fields(msg: Any, values: Dict[str, str]) -> None: """ @@ -33,4 +35,17 @@ def set_message_fields(msg: Any, values: Dict[str, str]) -> None: except TypeError: value = field_type() set_message_fields(value, field_value) + f_type = msg.get_fields_and_field_types()[field_name] + # Check if field is an array of ROS message types + if f_type.find('/') != -1: + if isinstance(field_type(), list): + # strip the 'sequence<' prefix if it is a sequence + if f_type.startswith('sequence'): + f_type = f_type[len('sequence') + 1:] + f_type = f_type[:f_type.rfind('[')] + field_elem_type = import_message_type('', f_type) + for n in range(len(value)): + submsg = field_elem_type() + set_message_fields(submsg, value[n]) + value[n] = submsg setattr(msg, field_name, value) diff --git a/rosidl_runtime_py/test/rosidl_runtime_py/test_set_message.py b/rosidl_runtime_py/test/rosidl_runtime_py/test_set_message.py index a74a6d92..4f178d65 100644 --- a/rosidl_runtime_py/test/rosidl_runtime_py/test_set_message.py +++ b/rosidl_runtime_py/test/rosidl_runtime_py/test_set_message.py @@ -91,3 +91,39 @@ def test_set_message_fields_invalid(): invalid_type['int32_value'] = 'this is not an integer' with pytest.raises(ValueError): set_message_fields(msg, invalid_type) + + +def test_set_nested_namespaced_fields(): + unbounded_sequence_msg = message_fixtures.get_msg_unbounded_sequences()[1] + test_values = { + 'basic_types_values': [ + {'float64_value': 42.42, 'int8_value': 42}, + {'float64_value': 11.11, 'int8_value': 11} + ] + } + set_message_fields(unbounded_sequence_msg, test_values) + assert unbounded_sequence_msg.basic_types_values[0].float64_value == 42.42 + assert unbounded_sequence_msg.basic_types_values[0].int8_value == 42 + assert unbounded_sequence_msg.basic_types_values[0].uint8_value == 0 + assert unbounded_sequence_msg.basic_types_values[1].float64_value == 11.11 + assert unbounded_sequence_msg.basic_types_values[1].int8_value == 11 + assert unbounded_sequence_msg.basic_types_values[1].uint8_value == 0 + + arrays_msg = message_fixtures.get_msg_arrays()[0] + test_values = { + 'basic_types_values': [ + {'float64_value': 42.42, 'int8_value': 42}, + {'float64_value': 11.11, 'int8_value': 11}, + {'float64_value': 22.22, 'int8_value': 22}, + ] + } + set_message_fields(arrays_msg, test_values) + assert arrays_msg.basic_types_values[0].float64_value == 42.42 + assert arrays_msg.basic_types_values[0].int8_value == 42 + assert arrays_msg.basic_types_values[0].uint8_value == 0 + assert arrays_msg.basic_types_values[1].float64_value == 11.11 + assert arrays_msg.basic_types_values[1].int8_value == 11 + assert arrays_msg.basic_types_values[1].uint8_value == 0 + assert arrays_msg.basic_types_values[2].float64_value == 22.22 + assert arrays_msg.basic_types_values[2].int8_value == 22 + assert arrays_msg.basic_types_values[2].uint8_value == 0