diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 61e4f36..4c8635a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -131,4 +131,4 @@ jobs: - name: Publish if: startsWith(github.event.release.tag_name, 'v') - uses: pypa/gh-action-pypi-publish@release/v1 + uses: pypa/gh-action-pypi-publish@release/v1 \ No newline at end of file diff --git a/src/smbprotocol/structure.py b/src/smbprotocol/structure.py index 9646ce6..bd8b320 100644 --- a/src/smbprotocol/structure.py +++ b/src/smbprotocol/structure.py @@ -3,6 +3,7 @@ import copy import datetime +import inspect import math import struct import textwrap @@ -148,7 +149,7 @@ def __init__(self, little_endian=True, default=None, size=None): field_type = self.__class__.__name__ self.little_endian = little_endian - if not (size is None or isinstance(size, int) or isinstance(size, types.LambdaType)): + if not (size is None or isinstance(size, int) or callable(size)): raise InvalidFieldDefinition(f"{field_type} size for field must be an int or None for a variable length") self.size = size self.default = default @@ -270,7 +271,7 @@ def _get_calculated_value(self, value): :param value: The value to calculate/expand :return: The final value """ - if isinstance(value, types.LambdaType): + if callable(value): expanded_value = value(self.structure) return self._get_calculated_value(expanded_value) else: @@ -294,7 +295,7 @@ def _get_calculated_size(self, size, data): # is None (last field value) if size is None: return len(data) - elif isinstance(size, types.LambdaType): + elif callable(size): expanded_size = size(self.structure) return self._get_calculated_size(expanded_size, data) else: @@ -309,7 +310,7 @@ def _get_struct_format(self, size, unsigned=True): :param size: The size as an int :return: The struct format specifier for the size specified """ - if isinstance(size, types.LambdaType): + if callable(size): size = size(self.structure) struct_format = {1: "B", 2: "H", 4: "L", 8: "Q"} @@ -345,7 +346,7 @@ def _pack_value(self, value): def _parse_value(self, value): if value is None: int_value = 0 - elif isinstance(value, types.LambdaType): + elif callable(value): int_value = value elif isinstance(value, bytes): struct_string = self._endian_prefix + self._get_struct_format(self.size, self.unsigned) @@ -376,7 +377,7 @@ def _pack_value(self, value): def _parse_value(self, value): if value is None: bytes_value = b"" - elif isinstance(value, types.LambdaType): + elif callable(value): bytes_value = value elif isinstance(value, int): struct_string = self._endian_prefix + self._get_struct_format(self.size) @@ -425,7 +426,7 @@ def __init__(self, list_count=None, list_type=BytesField(), unpack_func=None, ** used when the list contains variable length values. :param kwargs: Any other kwarg to be sent to Field() """ - if list_count is not None and not (isinstance(list_count, int) or isinstance(list_count, types.LambdaType)): + if list_count is not None and not (isinstance(list_count, int) or callable(list_count)): raise InvalidFieldDefinition( "ListField list_count must be an int, lambda, or None for a variable list length" ) @@ -435,7 +436,7 @@ def __init__(self, list_count=None, list_type=BytesField(), unpack_func=None, ** raise InvalidFieldDefinition("ListField list_type must be a Field definition") self.list_type = list_type - if unpack_func is not None and not isinstance(unpack_func, types.LambdaType): + if unpack_func is not None and not callable(unpack_func): raise InvalidFieldDefinition("ListField unpack_func must be a lambda function or None") elif unpack_func is None and (list_count is None or list_type.size is None): raise InvalidFieldDefinition( @@ -456,7 +457,7 @@ def get_value(self): # Override default get_value() so we return a list with the actual # value, not the Field definition list_value = [] - if isinstance(self.value, types.LambdaType): + if callable(self.value): value = self._get_calculated_value(self.value) else: value = self.value @@ -474,9 +475,9 @@ def _pack_value(self, value): def _parse_value(self, value): if value is None: list_value = [] - elif isinstance(value, types.LambdaType): + elif callable(value): return value - elif isinstance(value, bytes) and isinstance(self.unpack_func, types.LambdaType): + elif isinstance(value, bytes) and callable(self.unpack_func): # use the lambda function to parse the bytes to a list list_value = self.unpack_func(self.structure, value) elif isinstance(value, bytes): @@ -527,7 +528,7 @@ def _to_string(self): def _create_list_from_bytes(self, list_count, list_type, value): # calculate the list_count and rerun method if a lambda - if isinstance(list_count, types.LambdaType): + if callable(list_count): list_count = list_count(self.structure) return self._create_list_from_bytes(list_count, list_type, value) @@ -576,7 +577,7 @@ def _pack_value(self, value): def _parse_value(self, value): if value is None: structure_value = b"" - elif isinstance(value, types.LambdaType): + elif callable(value): structure_value = value elif isinstance(value, bytes): structure_value = value @@ -586,7 +587,7 @@ def _parse_value(self, value): raise TypeError(f"Cannot parse value for field {self.name} of type {type(value).__name__} to a structure") if isinstance(structure_value, bytes) and self.structure_type and structure_value != b"": - if isinstance(self.structure_type, types.LambdaType): + if callable(self.structure_type) and not inspect.isclass(self.structure_type): structure_type = self.structure_type(self.structure) else: structure_type = self.structure_type @@ -650,7 +651,7 @@ def _pack_value(self, value): def _parse_value(self, value): if value is None: datetime_value = datetime.datetime.today() - elif isinstance(value, types.LambdaType): + elif callable(value): datetime_value = value elif isinstance(value, bytes): struct_string = self._endian_prefix + self._get_struct_format(8) @@ -713,7 +714,7 @@ def _parse_value(self, value): uuid_value = uuid.UUID(int=value) elif isinstance(value, uuid.UUID): uuid_value = value - elif isinstance(value, types.LambdaType): + elif callable(value): uuid_value = value else: raise TypeError(f"Cannot parse value for field {self.name} of type {type(value).__name__} to a uuid") @@ -823,7 +824,7 @@ def _parse_value(self, value): bool_value = value elif isinstance(value, bytes): bool_value = value == b"\x01" - elif isinstance(value, types.LambdaType): + elif callable(value): bool_value = value else: raise TypeError(f"Cannot parse value for field {self.name} of type {type(value).__name__} to a bool") @@ -854,7 +855,7 @@ def _parse_value(self, value): text_value = to_text(value, encoding=self.encoding) elif isinstance(value, str): text_value = value - elif isinstance(value, types.LambdaType): + elif callable(value): text_value = value else: raise TypeError(f"Cannot parse value for field {self.name} of type {type(value).__name__} to a text string") diff --git a/tests/test_structure.py b/tests/test_structure.py index 7d25a83..9114f2b 100644 --- a/tests/test_structure.py +++ b/tests/test_structure.py @@ -378,6 +378,37 @@ def test_set_lambda(self): assert actual == expected assert len(field) == 4 + def test_set_function(self): + def field_resolver(s): + return 8765 + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(field_resolver) + expected = 8765 + actual = field.get_value() + assert isinstance(field.value, types.FunctionType) + assert actual == expected + assert len(field) == 4 + + def test_set_class(self): + class FieldResolver: + def __new__(cls, s): + return 5678 + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(FieldResolver) + expected = 5678 + actual = field.get_value() + assert isinstance(field.value, type) + assert actual == expected + assert len(field) == 4 + def test_set_bytes(self): field = self.StructureTest()["field"] field.set_value(b"\x12\x34\x00\x00") @@ -401,6 +432,24 @@ def test_set_invalid(self): field.set_value([]) assert str(exc.value) == "Cannot parse value for field field of type list to an int" + def test_builtin_default(self): + class LengthStructure(Structure): + def __init__(self): + self.fields = OrderedDict( + [ + ("field1", IntField(size=4, default=b"\x01\x03\x05\x07")), + ("field2", IntField(size=2, default=len)), + ] + ) + super().__init__() + + structure = LengthStructure() + field = structure["field2"] + expected = b"\x06\x00" + actual = field.pack() + assert actual == expected + assert structure.pack() == b"\x01\x03\x05\x07\x06\x00" + def test_byte_order(self): class ByteOrderStructure(Structure): def __init__(self): @@ -470,6 +519,37 @@ def test_set_lambda(self): assert actual == expected assert len(field) == 4 + def test_set_function(self): + def field_resolver(s): + return b"\x11\x12\x13\x14" + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(field_resolver) + expected = b"\x11\x12\x13\x14" + actual = field.get_value() + assert isinstance(field.value, types.FunctionType) + assert actual == expected + assert len(field) == 4 + + def test_set_class(self): + class FieldResolver: + def __new__(cls, s): + return b"\x16\x17\x18\x19" + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(FieldResolver) + expected = b"\x16\x17\x18\x19" + actual = field.get_value() + assert isinstance(field.value, type) + assert actual == expected + assert len(field) == 4 + def test_set_bytes(self): field = self.StructureTest()["field"] field.set_value(b"\x78\x00\x77\x00") @@ -486,6 +566,25 @@ def test_set_int(self): assert isinstance(field.value, bytes) assert actual == expected + def test_function_default(self): + def field_resolver(s): + return b"\x0d\x0e" + + class TestStructure(Structure): + def __init__(self): + self.fields = OrderedDict( + [ + ("field", BytesField(size=2, default=field_resolver)), + ] + ) + super().__init__() + + structure = TestStructure() + field = structure["field"] + expected = b"\x0d\x0e" + actual = field.pack() + assert actual == expected + def test_set_structure(self): field = self.StructureTest()["field"] field.size = 8 @@ -609,6 +708,23 @@ def __init__(self): assert len(field) == 7 assert actual == expected + def test_class_func(self): + class Unpacker: + def __new__(cls, s, d): + return [b"\x01\x02", b"\x03\x04\x05\x06", b"\07"] + + class UnpackListStructure(Structure): + def __init__(self): + self.fields = OrderedDict([("field", ListField(size=7, unpack_func=Unpacker))]) + super().__init__() + + field = UnpackListStructure()["field"] + field.unpack(b"\x00") + expected = [b"\x01\x02", b"\x03\x04\x05\x06", b"\07"] + actual = field.get_value() + assert len(field) == 7 + assert actual == expected + def test_set_none(self): field = self.StructureTest()["field"] field.set_value(None) @@ -631,6 +747,21 @@ def test_set_lambda_as_bytes(self): assert actual == expected assert len(field) == 4 + def test_set_function_as_bytes(self): + def field_resolver(s): + return b"\x10\x11\x12\x13" + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(field_resolver) + expected = [b"\x10\x11", b"\x12\x13"] + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert len(field) == 4 + def test_set_lambda_as_list(self): structure = self.StructureTest() field = structure["field"] @@ -643,6 +774,21 @@ def test_set_lambda_as_list(self): assert actual == expected assert len(field) == 4 + def test_set_function_as_list(self): + def field_resolver(s): + return [b"\x11\x12", b"\x13\x14"] + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(field_resolver) + expected = [b"\x11\x12", b"\x13\x14"] + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert len(field) == 4 + def test_set_bytes_fixed(self): field = self.StructureTest()["field"] field.set_value(b"\x78\x00\x77\x00") @@ -823,6 +969,22 @@ def test_set_lambda(self): assert isinstance(actual, Structure2) assert len(field) == 8 + def test_set_function(self): + def field_resolver(s): + return b"\x10\x07\x04\x01\x03\x05\x07\x09" + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(field_resolver) + expected = b"\x10\x07\x04\x01\x03\x05\x07\x09" + actual = field.get_value() + assert isinstance(field.value, types.FunctionType) + assert actual.pack() == expected + assert isinstance(actual, Structure2) + assert len(field) == 8 + def test_set_lambda_without_type(self): structure = self.StructureTest() field = structure["field"] @@ -837,6 +999,23 @@ def test_set_lambda_without_type(self): assert isinstance(actual, bytes) assert len(field) == 8 + def test_set_function_without_type(self): + def field_resolver(s): + return b"\x10\x07\x04\x01\x03\x05\x07\x09" + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.structure_type = None + field.set_value(field_resolver) + expected = b"\x10\x07\x04\x01\x03\x05\x07\x09" + actual = field.get_value() + assert isinstance(field.value, types.LambdaType) + assert actual == expected + assert isinstance(actual, bytes) + assert len(field) == 8 + def test_set_bytes(self): field = self.StructureTest()["field"] field.set_value(b"\x7d\x00\x00\x00\x14\x15\x16\x17") @@ -979,6 +1158,22 @@ def test_set_lambda(self): assert actual == expected assert len(field) == 16 + def test_set_class(self): + class FieldResolver: + def __new__(cls, s): + return uuid.UUID(bytes=b"\x10" * 16) + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(FieldResolver) + expected = uuid.UUID(bytes=b"\x10" * 16) + actual = field.get_value() + assert isinstance(field.value, type) + assert actual == expected + assert len(field) == 16 + def test_set_bytes(self): field = self.StructureTest()["field"] field.set_value(b"\x22" * 16) @@ -1150,6 +1345,40 @@ def test_set_lambda(self): assert actual == expected assert len(field) == 8 + def test_set_class(self): + class FieldResolver: + def __new__(cls, s): + return datetime( + year=2022, + month=5, + day=27, + hour=12, + minute=36, + second=14, + microsecond=313481, + tzinfo=timezone.utc, + ) + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(FieldResolver) + expected = datetime( + year=2022, + month=5, + day=27, + hour=12, + minute=36, + second=14, + microsecond=313481, + tzinfo=timezone.utc, + ) + actual = field.get_value() + assert isinstance(field.value, type) + assert actual == expected + assert len(field) == 8 + def test_set_bytes(self): field = self.StructureTest()["field"] field.set_value(b"\x00\x67\x7b\x21\x3d\x5d\xd3\x01") @@ -1525,6 +1754,22 @@ def test_set_lambda(self): assert actual == expected assert len(field) == 1 + def test_set_class(self): + class FieldResolver: + def __new__(cls, s): + return True + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(FieldResolver) + expected = True + actual = field.get_value() + assert isinstance(field.value, type) + assert actual == expected + assert len(field) == 1 + def test_set_invalid(self): field = self.StructureTest()["field"] field.name = "field" @@ -1592,6 +1837,25 @@ def test_set_lambda(self): assert actual == expected assert len(field) == actual_length + def test_set_class(self): + class FieldResolver: + def __new__(cls, s): + return self.STRING_VALUE + + structure = self.StructureTest() + field = structure["field"] + field.name = "field" + field.structure = self.StructureTest + field.set_value(FieldResolver) + expected = self.STRING_VALUE + actual = field.get_value() + actual_length = 19 + if field.null_terminated: + actual_length += len("\x00".encode(field.encoding)) + assert isinstance(field.value, type) + assert actual == expected + assert len(field) == actual_length + def test_set_bytes(self): field = self.StructureTest()["field"] field.set_value(self.STRING_VALUE.encode("utf-8"))