diff --git a/src/pyshark/packet/layers/xml_layer.py b/src/pyshark/packet/layers/xml_layer.py index 9c85f660..8dfa0cc4 100644 --- a/src/pyshark/packet/layers/xml_layer.py +++ b/src/pyshark/packet/layers/xml_layer.py @@ -8,13 +8,10 @@ class XmlLayer(base.BaseLayer): - __slots__ = [ - "raw_mode", - "_all_fields" - ] + base.BaseLayer.__slots__ + __slots__ = ["raw_mode", "_all_fields"] + base.BaseLayer.__slots__ def __init__(self, xml_obj=None, raw_mode=False): - super().__init__(xml_obj.attrib['name']) + super().__init__(xml_obj.attrib["name"]) self.raw_mode = raw_mode self._all_fields = {} @@ -22,14 +19,14 @@ def __init__(self, xml_obj=None, raw_mode=False): # We copy over all the fields from the XML object # Note: we don't read lazily from the XML because the lxml objects are very memory-inefficient # so we'd rather not save them. - for field in xml_obj.findall('.//field'): + for field in xml_obj.findall(".//field"): attributes = dict(field.attrib) field_obj = LayerField(**attributes) - if attributes['name'] in self._all_fields: + if attributes["name"] in self._all_fields: # Field name already exists, add this field to the container. - self._all_fields[attributes['name']].add_field(field_obj) + self._all_fields[attributes["name"]].add_field(field_obj) else: - self._all_fields[attributes['name']] = LayerFieldsContainer(field_obj) + self._all_fields[attributes["name"]] = LayerFieldsContainer(field_obj) def get_field(self, name) -> typing.Union[LayerFieldsContainer, None]: """Gets the XML field object of the given name.""" @@ -43,7 +40,9 @@ def get_field(self, name) -> typing.Union[LayerFieldsContainer, None]: return field return None - def get_field_value(self, name, raw=False) -> typing.Union[LayerFieldsContainer, None]: + def get_field_value( + self, name, raw=False + ) -> typing.Union[LayerFieldsContainer, None]: """Tries getting the value of the given field. Tries it in the following order: show (standard nice display), value (raw value), @@ -65,18 +64,22 @@ def get_field_value(self, name, raw=False) -> typing.Union[LayerFieldsContainer, @property def field_names(self) -> typing.List[str]: """Gets all XML field names of this layer.""" - return [self._sanitize_field_name(field_name) for field_name in self._all_fields] + return [ + self._sanitize_field_name(field_name) for field_name in self._all_fields + ] @property def layer_name(self): - if self._layer_name == 'fake-field-wrapper': + if self._layer_name == "fake-field-wrapper": return base.DATA_LAYER_NAME return super().layer_name def __getattr__(self, item): val = self.get_field(item) if val is None: - raise AttributeError() + raise AttributeError( + f"Field '{item}' not found in layer '{self.layer_name}'. Available Fields: {', '.join(self.field_names)}" + ) if self.raw_mode: return val.raw_value return val @@ -84,29 +87,35 @@ def __getattr__(self, item): @property def _field_prefix(self) -> str: """Prefix to field names in the XML.""" - if self.layer_name == 'geninfo': - return '' - return self.layer_name + '.' + if self.layer_name == "geninfo": + return "" + return self.layer_name + "." def _sanitize_field_name(self, field_name): """Sanitizes an XML field name An xml field might have characters which would make it inaccessible as a python attribute). """ - field_name = field_name.replace(self._field_prefix, '') - return field_name.replace('.', '_').replace('-', '_').lower() + field_name = field_name.replace(self._field_prefix, "") + return field_name.replace(".", "_").replace("-", "_").lower() def _pretty_print_layer_fields(self, file: io.IOBase): for field_line in self._get_all_field_lines(): - if ':' in field_line: - field_name, field_line = field_line.split(':', 1) - file.write(colored(field_name + ':', "green", attrs=["bold"])) + if ":" in field_line: + field_name, field_line = field_line.split(":", 1) + file.write(colored(field_name + ":", "green", attrs=["bold"])) file.write(colored(field_line, attrs=["bold"])) def _get_all_fields_with_alternates(self): all_fields = list(self._all_fields.values()) - all_fields += sum([field.alternate_fields for field in all_fields - if isinstance(field, LayerFieldsContainer)], []) + all_fields += sum( + [ + field.alternate_fields + for field in all_fields + if isinstance(field, LayerFieldsContainer) + ], + [], + ) return all_fields def _get_all_field_lines(self): @@ -129,7 +138,9 @@ def _get_field_repr(self, field): elif field.raw_value: return f"{self._sanitize_field_name(field.name)}: {field.raw_value}" - def get_field_by_showname(self, showname) -> typing.Union[LayerFieldsContainer, None]: + def get_field_by_showname( + self, showname + ) -> typing.Union[LayerFieldsContainer, None]: """Gets a field by its "showname" This is the name that appears in Wireshark's detailed display i.e. in 'User-Agent: Mozilla...', 'User-Agent' is the .showname