Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 35 additions & 24 deletions src/pyshark/packet/layers/xml_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,25 @@


class XmlLayer(base.BaseLayer):
__slots__ = [
"raw_mode",
"_all_fields"
] + base.BaseLayer.__slots__
__slots__ = ["raw_mode", "_all_fields"] + base.BaseLayer.__slots__

def __init__(self, xml_obj=None, raw_mode=False):
super().__init__(xml_obj.attrib['name'])
super().__init__(xml_obj.attrib["name"])
self.raw_mode = raw_mode

self._all_fields = {}

# We copy over all the fields from the XML object
# Note: we don't read lazily from the XML because the lxml objects are very memory-inefficient
# so we'd rather not save them.
for field in xml_obj.findall('.//field'):
for field in xml_obj.findall(".//field"):
attributes = dict(field.attrib)
field_obj = LayerField(**attributes)
if attributes['name'] in self._all_fields:
if attributes["name"] in self._all_fields:
# Field name already exists, add this field to the container.
self._all_fields[attributes['name']].add_field(field_obj)
self._all_fields[attributes["name"]].add_field(field_obj)
else:
self._all_fields[attributes['name']] = LayerFieldsContainer(field_obj)
self._all_fields[attributes["name"]] = LayerFieldsContainer(field_obj)

def get_field(self, name) -> typing.Union[LayerFieldsContainer, None]:
"""Gets the XML field object of the given name."""
Expand All @@ -43,7 +40,9 @@ def get_field(self, name) -> typing.Union[LayerFieldsContainer, None]:
return field
return None

def get_field_value(self, name, raw=False) -> typing.Union[LayerFieldsContainer, None]:
def get_field_value(
self, name, raw=False
) -> typing.Union[LayerFieldsContainer, None]:
"""Tries getting the value of the given field.

Tries it in the following order: show (standard nice display), value (raw value),
Expand All @@ -65,48 +64,58 @@ def get_field_value(self, name, raw=False) -> typing.Union[LayerFieldsContainer,
@property
def field_names(self) -> typing.List[str]:
"""Gets all XML field names of this layer."""
return [self._sanitize_field_name(field_name) for field_name in self._all_fields]
return [
self._sanitize_field_name(field_name) for field_name in self._all_fields
]

@property
def layer_name(self):
if self._layer_name == 'fake-field-wrapper':
if self._layer_name == "fake-field-wrapper":
return base.DATA_LAYER_NAME
return super().layer_name

def __getattr__(self, item):
val = self.get_field(item)
if val is None:
raise AttributeError()
raise AttributeError(
f"Field '{item}' not found in layer '{self.layer_name}'. Available Fields: {', '.join(self.field_names)}"
)
if self.raw_mode:
return val.raw_value
return val

@property
def _field_prefix(self) -> str:
"""Prefix to field names in the XML."""
if self.layer_name == 'geninfo':
return ''
return self.layer_name + '.'
if self.layer_name == "geninfo":
return ""
return self.layer_name + "."

def _sanitize_field_name(self, field_name):
"""Sanitizes an XML field name

An xml field might have characters which would make it inaccessible as a python attribute).
"""
field_name = field_name.replace(self._field_prefix, '')
return field_name.replace('.', '_').replace('-', '_').lower()
field_name = field_name.replace(self._field_prefix, "")
return field_name.replace(".", "_").replace("-", "_").lower()

def _pretty_print_layer_fields(self, file: io.IOBase):
for field_line in self._get_all_field_lines():
if ':' in field_line:
field_name, field_line = field_line.split(':', 1)
file.write(colored(field_name + ':', "green", attrs=["bold"]))
if ":" in field_line:
field_name, field_line = field_line.split(":", 1)
file.write(colored(field_name + ":", "green", attrs=["bold"]))
file.write(colored(field_line, attrs=["bold"]))

def _get_all_fields_with_alternates(self):
all_fields = list(self._all_fields.values())
all_fields += sum([field.alternate_fields for field in all_fields
if isinstance(field, LayerFieldsContainer)], [])
all_fields += sum(
[
field.alternate_fields
for field in all_fields
if isinstance(field, LayerFieldsContainer)
],
[],
)
return all_fields

def _get_all_field_lines(self):
Expand All @@ -129,7 +138,9 @@ def _get_field_repr(self, field):
elif field.raw_value:
return f"{self._sanitize_field_name(field.name)}: {field.raw_value}"

def get_field_by_showname(self, showname) -> typing.Union[LayerFieldsContainer, None]:
def get_field_by_showname(
self, showname
) -> typing.Union[LayerFieldsContainer, None]:
"""Gets a field by its "showname"
This is the name that appears in Wireshark's detailed display i.e. in 'User-Agent: Mozilla...',
'User-Agent' is the .showname
Expand Down