Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,4 +131,4 @@ jobs:

- name: Publish
if: startsWith(github.event.release.tag_name, 'v')
uses: pypa/gh-action-pypi-publish@release/v1
uses: pypa/gh-action-pypi-publish@release/v1
37 changes: 19 additions & 18 deletions src/smbprotocol/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import copy
import datetime
import inspect
import math
import struct
import textwrap
Expand Down Expand Up @@ -148,7 +149,7 @@ def __init__(self, little_endian=True, default=None, size=None):
field_type = self.__class__.__name__
self.little_endian = little_endian

if not (size is None or isinstance(size, int) or isinstance(size, types.LambdaType)):
if not (size is None or isinstance(size, int) or callable(size)):
raise InvalidFieldDefinition(f"{field_type} size for field must be an int or None for a variable length")
self.size = size
self.default = default
Expand Down Expand Up @@ -270,7 +271,7 @@ def _get_calculated_value(self, value):
:param value: The value to calculate/expand
:return: The final value
"""
if isinstance(value, types.LambdaType):
if callable(value):
expanded_value = value(self.structure)
return self._get_calculated_value(expanded_value)
else:
Expand All @@ -294,7 +295,7 @@ def _get_calculated_size(self, size, data):
# is None (last field value)
if size is None:
return len(data)
elif isinstance(size, types.LambdaType):
elif callable(size):
expanded_size = size(self.structure)
return self._get_calculated_size(expanded_size, data)
else:
Expand All @@ -309,7 +310,7 @@ def _get_struct_format(self, size, unsigned=True):
:param size: The size as an int
:return: The struct format specifier for the size specified
"""
if isinstance(size, types.LambdaType):
if callable(size):
size = size(self.structure)

struct_format = {1: "B", 2: "H", 4: "L", 8: "Q"}
Expand Down Expand Up @@ -345,7 +346,7 @@ def _pack_value(self, value):
def _parse_value(self, value):
if value is None:
int_value = 0
elif isinstance(value, types.LambdaType):
elif callable(value):
int_value = value
elif isinstance(value, bytes):
struct_string = self._endian_prefix + self._get_struct_format(self.size, self.unsigned)
Expand Down Expand Up @@ -376,7 +377,7 @@ def _pack_value(self, value):
def _parse_value(self, value):
if value is None:
bytes_value = b""
elif isinstance(value, types.LambdaType):
elif callable(value):
bytes_value = value
elif isinstance(value, int):
struct_string = self._endian_prefix + self._get_struct_format(self.size)
Expand Down Expand Up @@ -425,7 +426,7 @@ def __init__(self, list_count=None, list_type=BytesField(), unpack_func=None, **
used when the list contains variable length values.
:param kwargs: Any other kwarg to be sent to Field()
"""
if list_count is not None and not (isinstance(list_count, int) or isinstance(list_count, types.LambdaType)):
if list_count is not None and not (isinstance(list_count, int) or callable(list_count)):
raise InvalidFieldDefinition(
"ListField list_count must be an int, lambda, or None for a variable list length"
)
Expand All @@ -435,7 +436,7 @@ def __init__(self, list_count=None, list_type=BytesField(), unpack_func=None, **
raise InvalidFieldDefinition("ListField list_type must be a Field definition")
self.list_type = list_type

if unpack_func is not None and not isinstance(unpack_func, types.LambdaType):
if unpack_func is not None and not callable(unpack_func):
raise InvalidFieldDefinition("ListField unpack_func must be a lambda function or None")
elif unpack_func is None and (list_count is None or list_type.size is None):
raise InvalidFieldDefinition(
Expand All @@ -456,7 +457,7 @@ def get_value(self):
# Override default get_value() so we return a list with the actual
# value, not the Field definition
list_value = []
if isinstance(self.value, types.LambdaType):
if callable(self.value):
value = self._get_calculated_value(self.value)
else:
value = self.value
Expand All @@ -474,9 +475,9 @@ def _pack_value(self, value):
def _parse_value(self, value):
if value is None:
list_value = []
elif isinstance(value, types.LambdaType):
elif callable(value):
return value
elif isinstance(value, bytes) and isinstance(self.unpack_func, types.LambdaType):
elif isinstance(value, bytes) and callable(self.unpack_func):
# use the lambda function to parse the bytes to a list
list_value = self.unpack_func(self.structure, value)
elif isinstance(value, bytes):
Expand Down Expand Up @@ -527,7 +528,7 @@ def _to_string(self):

def _create_list_from_bytes(self, list_count, list_type, value):
# calculate the list_count and rerun method if a lambda
if isinstance(list_count, types.LambdaType):
if callable(list_count):
list_count = list_count(self.structure)
return self._create_list_from_bytes(list_count, list_type, value)

Expand Down Expand Up @@ -576,7 +577,7 @@ def _pack_value(self, value):
def _parse_value(self, value):
if value is None:
structure_value = b""
elif isinstance(value, types.LambdaType):
elif callable(value):
structure_value = value
elif isinstance(value, bytes):
structure_value = value
Expand All @@ -586,7 +587,7 @@ def _parse_value(self, value):
raise TypeError(f"Cannot parse value for field {self.name} of type {type(value).__name__} to a structure")

if isinstance(structure_value, bytes) and self.structure_type and structure_value != b"":
if isinstance(self.structure_type, types.LambdaType):
if callable(self.structure_type) and not inspect.isclass(self.structure_type):
structure_type = self.structure_type(self.structure)
else:
structure_type = self.structure_type
Expand Down Expand Up @@ -650,7 +651,7 @@ def _pack_value(self, value):
def _parse_value(self, value):
if value is None:
datetime_value = datetime.datetime.today()
elif isinstance(value, types.LambdaType):
elif callable(value):
datetime_value = value
elif isinstance(value, bytes):
struct_string = self._endian_prefix + self._get_struct_format(8)
Expand Down Expand Up @@ -713,7 +714,7 @@ def _parse_value(self, value):
uuid_value = uuid.UUID(int=value)
elif isinstance(value, uuid.UUID):
uuid_value = value
elif isinstance(value, types.LambdaType):
elif callable(value):
uuid_value = value
else:
raise TypeError(f"Cannot parse value for field {self.name} of type {type(value).__name__} to a uuid")
Expand Down Expand Up @@ -823,7 +824,7 @@ def _parse_value(self, value):
bool_value = value
elif isinstance(value, bytes):
bool_value = value == b"\x01"
elif isinstance(value, types.LambdaType):
elif callable(value):
bool_value = value
else:
raise TypeError(f"Cannot parse value for field {self.name} of type {type(value).__name__} to a bool")
Expand Down Expand Up @@ -854,7 +855,7 @@ def _parse_value(self, value):
text_value = to_text(value, encoding=self.encoding)
elif isinstance(value, str):
text_value = value
elif isinstance(value, types.LambdaType):
elif callable(value):
text_value = value
else:
raise TypeError(f"Cannot parse value for field {self.name} of type {type(value).__name__} to a text string")
Expand Down
Loading