From 87b0bea53156a1d522ae668b38e6af898604093f Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Sat, 7 Jun 2025 21:22:18 +0100 Subject: [PATCH 01/20] Configure mypy settings in pyproject.toml for stricter type checking --- pyproject.toml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 7b1e5ad..8c7c672 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -95,3 +95,12 @@ disable = "C0330, C0326" [tool.pylint.format] max-line-length = "88" + +[tool.mypy] +python_version = 3.12 +exclude = ["prov/tests/*"] + +[[tool.mypy.overrides]] +module = "prov.*" +disallow_untyped_defs = true +check_untyped_defs = true From 1c822590bab62e1796902b6a58d0bdf1b1e98376 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Sun, 8 Jun 2025 19:09:14 +0100 Subject: [PATCH 02/20] Set up GitHub Actions workflow for mypy type checks Also testing with the branch 143-type-hints --- .github/workflows/mypy.yml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 .github/workflows/mypy.yml diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml new file mode 100644 index 0000000..4118b44 --- /dev/null +++ b/.github/workflows/mypy.yml @@ -0,0 +1,27 @@ +name: mypy check +on: + push: + branches: [main, master, 143-type-hints] + pull_request: + branches: [main, master, dev] + +jobs: + static-type-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.x' + - run: pip install mypy + - name: Get Python changed files + id: changed-py-files + uses: tj-actions/changed-files@v46 + with: + files: | + *.py + **/*.py + - name: Run if any of the listed files above is changed + if: steps.changed-py-files.outputs.any_changed == 'true' + run: mypy ${{ steps.changed-py-files.outputs.all_changed_files }} --ignore-missing-imports + From ca814ea3208c92025d17ae4f5d51196815ba1974 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Sun, 8 Jun 2025 19:10:42 +0100 Subject: [PATCH 03/20] Add type hints and improve docstrings for `identifier.py` #143 --- src/prov/identifier.py | 89 +++++++++++++++++++++++++++--------------- 1 file changed, 57 insertions(+), 32 deletions(-) diff --git a/src/prov/identifier.py b/src/prov/identifier.py index b1333c4..b40fd0d 100644 --- a/src/prov/identifier.py +++ b/src/prov/identifier.py @@ -1,3 +1,6 @@ +from __future__ import annotations # needed for | type annotations in Python < 3.10 +from typing import Any + __author__ = "Trung Dong Huynh" __email__ = "trungdong@donggiang.com" @@ -8,45 +11,67 @@ class Identifier(object): # TODO: make Identifier an "abstract" base class and move xsd:anyURI # into a subclass - def __init__(self, uri): + def __init__(self, uri: str): """ Constructor. :param uri: URI string for the long namespace identifier. """ - self._uri = str(uri) # Ensure this is a unicode string + self._uri: str = str(uri) # Ensure this is a unicode string @property - def uri(self): - """Identifier's URI.""" + def uri(self) -> str: + """ + Returns the URI associated with the current identifier. + + Returns: + str: The URI representing the resource identifier. + """ return self._uri - def __str__(self): + def __str__(self) -> str: return self._uri - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return self.uri == other.uri if isinstance(other, Identifier) else False - def __hash__(self): + def __hash__(self) -> int: return hash((self.uri, self.__class__)) - def __repr__(self): + def __repr__(self) -> str: return "<%s: %s>" % (self.__class__.__name__, self._uri) - def provn_representation(self): - """PROV-N representation of qualified name in a string.""" + def provn_representation(self) -> str: + """ + Returns the PROV-N representation of the URI. + + Returns: + str: The PROV-N representation of the URI. + """ return '"%s" %%%% xsd:anyURI' % self._uri class QualifiedName(Identifier): - """Qualified name of an identifier in a particular namespace.""" + """ + Represents a `qualified name `_, + which combines a namespace and a local part for use in identifying entities in a + namespace-aware context. - def __init__(self, namespace, localpart): - """ - Constructor. + This class facilitates handling and manipulation of qualified names, which + combine a namespace and a local identifier. It supports string representation, + hashing, and retrieval of individual components (namespace or local part). + """ - :param namespace: Namespace to use for qualified name resolution. - :param localpart: Portion of identifier not part of the namespace prefix. + def __init__(self, namespace: "Namespace", localpart: str): + """ + Initializes a new qualified name with the provided namespace and localpart + values. It combines the namespace URI and localpart to form an identifier and + constructs a string representation including optional namespace prefix. + + Args: + namespace (Namespace): The namespace object containing a URI and optional + prefix associated with this qualified name. + localpart (str): The local part of the qualified name. """ Identifier.__init__(self, "".join([namespace.uri, localpart])) self._namespace = namespace @@ -56,25 +81,25 @@ def __init__(self, namespace, localpart): ) @property - def namespace(self): + def namespace(self) -> "Namespace": """Namespace of qualified name.""" return self._namespace @property - def localpart(self): + def localpart(self) -> str: """Local part of qualified name.""" return self._localpart - def __str__(self): + def __str__(self) -> str: return self._str - def __repr__(self): + def __repr__(self) -> str: return "<%s: %s>" % (self.__class__.__name__, self._str) - def __hash__(self): + def __hash__(self) -> int: return hash(self.uri) - def provn_representation(self): + def provn_representation(self) -> str: """PROV-N representation of qualified name in a string.""" return "'%s'" % self._str @@ -93,19 +118,19 @@ def __init__(self, prefix: str, uri: str): raise ValueError("Not a valid URI to create a namespace.") self._prefix = prefix self._uri = uri - self._cache = dict() + self._cache: dict[str, QualifiedName] = dict() @property - def uri(self): + def uri(self) -> str: """Namespace URI.""" return self._uri @property - def prefix(self): + def prefix(self) -> str: """Namespace prefix.""" return self._prefix - def contains(self, identifier): + def contains(self, identifier: Identifier) -> bool: """ Indicates whether the identifier provided is contained in this namespace. @@ -119,7 +144,7 @@ def contains(self, identifier): ) return uri.startswith(self._uri) if uri else False - def qname(self, identifier): + def qname(self, identifier: str | Identifier) -> QualifiedName | None: """ Returns the qualified name of the identifier given using the namespace prefix. @@ -137,27 +162,27 @@ def qname(self, identifier): else: return None - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( (self._uri == other.uri and self._prefix == other.prefix) if isinstance(other, Namespace) else False ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return ( not isinstance(other, Namespace) or self._uri != other.uri or self._prefix != other.prefix ) - def __hash__(self): + def __hash__(self) -> int: return hash((self._uri, self._prefix)) - def __repr__(self): + def __repr__(self) -> str: return "<%s: %s {%s}>" % (self.__class__.__name__, self._prefix, self._uri) - def __getitem__(self, localpart): + def __getitem__(self, localpart: str) -> QualifiedName: if localpart in self._cache: return self._cache[localpart] else: From 4e0e91e0a3b44984ad4c39395db113b21697c32c Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Sun, 8 Jun 2025 22:52:08 +0100 Subject: [PATCH 04/20] Add type hints and refactor `serializers` module with abstract base class implementation #143 --- src/prov/serializers/__init__.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/prov/serializers/__init__.py b/src/prov/serializers/__init__.py index e9d9577..586ca55 100644 --- a/src/prov/serializers/__init__.py +++ b/src/prov/serializers/__init__.py @@ -1,4 +1,8 @@ +from __future__ import annotations # needed for | type annotations in Python < 3.10 +from abc import ABC, abstractmethod +import io from prov import Error +from prov.model import ProvDocument __author__ = "Trung Dong Huynh" __email__ = "trungdong@donggiang.com" @@ -6,13 +10,13 @@ __all__ = ["get", "Serializer"] -class Serializer(object): +class Serializer(ABC): """Serializer for PROV documents.""" document = None """PROV document to serialise.""" - def __init__(self, document=None): + def __init__(self, document: ProvDocument | None = None): """ Constructor. @@ -20,19 +24,23 @@ def __init__(self, document=None): """ self.document = document - def serialize(self, stream, **kwargs): + @abstractmethod + def serialize(self, stream: io.IOBase) -> None: """ Abstract method for serializing. :param stream: Stream object to serialize the document into. """ + pass - def deserialize(self, stream, **kwargs): + @abstractmethod + def deserialize(self, stream: io.IOBase) -> ProvDocument | None: """ Abstract method for deserializing. :param stream: Stream object to deserialize the document from. """ + pass class DoNotExist(Error): @@ -44,11 +52,11 @@ class DoNotExist(Error): class Registry: """Registry of serializers.""" - serializers = None + serializers = None # type: dict[str, type[Serializer]] """Property caching all available serializers in a dict.""" @staticmethod - def load_serializers(): + def load_serializers() -> None: """Loads all available serializers into the registry.""" from prov.serializers.provjson import ProvJSONSerializer from prov.serializers.provn import ProvNSerializer @@ -63,7 +71,7 @@ def load_serializers(): } -def get(format_name): +def get(format_name: str) -> type[Serializer]: """ Returns the serializer class for the specified format. Raises a DoNotExist """ From 4bff595ab95798892b12adf2e379d2125de1e00a Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Sun, 8 Jun 2025 22:59:04 +0100 Subject: [PATCH 05/20] Add type hints to `read` function #143 --- src/prov/__init__.py | 7 ++++++- src/prov/serializers/__init__.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/prov/__init__.py b/src/prov/__init__.py index 3db3248..3d5c028 100644 --- a/src/prov/__init__.py +++ b/src/prov/__init__.py @@ -1,3 +1,7 @@ +from __future__ import annotations # needed for | type annotations in Python < 3.10 +import os +from prov.model import ProvDocument + __author__ = "Trung Dong Huynh" __email__ = "trungdong@donggiang.com" __version__ = "2.0.2" @@ -11,7 +15,7 @@ class Error(Exception): pass -def read(source, format=None): +def read(source: os.PathLike, format: str | None = None) -> ProvDocument | None: """ Convenience function returning a ProvDocument instance. @@ -39,6 +43,7 @@ def read(source, format=None): try: return ProvDocument.deserialize(source=source, format=format) except: + # TODO: Specify an exception type for failing to deserialize. pass else: raise TypeError( diff --git a/src/prov/serializers/__init__.py b/src/prov/serializers/__init__.py index 586ca55..85822c0 100644 --- a/src/prov/serializers/__init__.py +++ b/src/prov/serializers/__init__.py @@ -7,7 +7,7 @@ __author__ = "Trung Dong Huynh" __email__ = "trungdong@donggiang.com" -__all__ = ["get", "Serializer"] +__all__ = ["get", "Registry", "Serializer"] class Serializer(ABC): From b5df843db530d29bcfd9418d4c8b37d7be1b4ad1 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Mon, 9 Jun 2025 16:56:21 +0100 Subject: [PATCH 06/20] Add `TYPE_CHECKING` conditional import for `ProvDocument` to avoid circular imports --- src/prov/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/prov/__init__.py b/src/prov/__init__.py index 3d5c028..fd1f1cf 100644 --- a/src/prov/__init__.py +++ b/src/prov/__init__.py @@ -1,6 +1,9 @@ from __future__ import annotations # needed for | type annotations in Python < 3.10 import os -from prov.model import ProvDocument +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from prov.model import ProvDocument __author__ = "Trung Dong Huynh" __email__ = "trungdong@donggiang.com" From 64550fa09fd78cd071c0264001a979758cae4755 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Mon, 9 Jun 2025 17:04:38 +0100 Subject: [PATCH 07/20] Update abstract methods in `serializers` to accept additional keyword arguments and add `Any` type hint --- src/prov/serializers/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/prov/serializers/__init__.py b/src/prov/serializers/__init__.py index 85822c0..44b1b8b 100644 --- a/src/prov/serializers/__init__.py +++ b/src/prov/serializers/__init__.py @@ -1,6 +1,7 @@ from __future__ import annotations # needed for | type annotations in Python < 3.10 from abc import ABC, abstractmethod import io +from typing import Any from prov import Error from prov.model import ProvDocument @@ -25,7 +26,7 @@ def __init__(self, document: ProvDocument | None = None): self.document = document @abstractmethod - def serialize(self, stream: io.IOBase) -> None: + def serialize(self, stream: io.IOBase, **args: Any) -> None: """ Abstract method for serializing. @@ -34,7 +35,7 @@ def serialize(self, stream: io.IOBase) -> None: pass @abstractmethod - def deserialize(self, stream: io.IOBase) -> ProvDocument | None: + def deserialize(self, stream: io.IOBase, **args: Any) -> ProvDocument | None: """ Abstract method for deserializing. From 22abe7c424e74c92401fbea536000180cf7ba659 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Mon, 9 Jun 2025 17:23:11 +0100 Subject: [PATCH 08/20] Add `TYPE_CHECKING` import in `serializers` to prevent circular dependency --- src/prov/serializers/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/prov/serializers/__init__.py b/src/prov/serializers/__init__.py index 44b1b8b..043fa69 100644 --- a/src/prov/serializers/__init__.py +++ b/src/prov/serializers/__init__.py @@ -1,9 +1,11 @@ from __future__ import annotations # needed for | type annotations in Python < 3.10 from abc import ABC, abstractmethod import io -from typing import Any +from typing import Any, TYPE_CHECKING from prov import Error -from prov.model import ProvDocument + +if TYPE_CHECKING: + from prov.model import ProvDocument __author__ = "Trung Dong Huynh" __email__ = "trungdong@donggiang.com" From 15f6bd474ba47c0dd753e77100fbcab198b162c1 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Mon, 9 Jun 2025 18:56:20 +0100 Subject: [PATCH 09/20] Expand supported input types for the `read` function in `prov` module --- src/prov/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/prov/__init__.py b/src/prov/__init__.py index fd1f1cf..911fb44 100644 --- a/src/prov/__init__.py +++ b/src/prov/__init__.py @@ -12,13 +12,15 @@ __all__ = ["Error", "model", "read"] + + class Error(Exception): """Base class for all errors in this package.""" pass -def read(source: os.PathLike, format: str | None = None) -> ProvDocument | None: +def read(source: str | bytes | os.PathLike, format: str | None = None) -> ProvDocument | None: """ Convenience function returning a ProvDocument instance. From 4c804049686e9a98ad2988eff8aa899794c6e99d Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Mon, 9 Jun 2025 21:27:35 +0100 Subject: [PATCH 10/20] Add type annotations and improve typing consistency in `model.py` #143 This commit introduces detailed type annotations for improved clarity and stricter type checks. Includes adjustments in `test_extras.py`, `serializers`, and additional type alias definitions. --- src/prov/model.py | 996 +++++++++++++++++++------------ src/prov/serializers/__init__.py | 2 +- src/prov/tests/test_extras.py | 11 +- 3 files changed, 634 insertions(+), 375 deletions(-) diff --git a/src/prov/model.py b/src/prov/model.py index 9c47e9d..ce37e7b 100644 --- a/src/prov/model.py +++ b/src/prov/model.py @@ -6,6 +6,9 @@ PROV-DM: http://www.w3.org/TR/prov-dm/ PROV-JSON: https://openprovenance.org/prov-json/ """ + +from __future__ import annotations # needed for | type annotations in Python < 3.10 + from collections import defaultdict import datetime import io @@ -14,6 +17,15 @@ import os import shutil import tempfile +from io import IOBase +from typing import ( + Any, + Callable, + Iterable, + Optional, + Union, +) +import typing # to use typing.TypeAlias in comments for compatibility with Python 3.9 from urllib.parse import urlparse import dateutil.parser @@ -29,15 +41,35 @@ logger = logging.getLogger(__name__) +# Type aliases for convenience +QualifiedNameCandidate = Union[QualifiedName, str, Identifier] # type: typing.TypeAlias +OptionalID = Optional[QualifiedNameCandidate] # type: typing.TypeAlias +EntityRef = Union["ProvEntity", QualifiedNameCandidate] # type: typing.TypeAlias +ActivityRef = Union["ProvActivity", QualifiedNameCandidate] # type: typing.TypeAlias +AgentRef = Union[ + "ProvAgent", "ProvEntity", "ProvActivity", QualifiedNameCandidate +] # type: typing.TypeAlias +GenrationRef = Union["ProvGeneration", QualifiedNameCandidate] # type: typing.TypeAlias +UsageRef = Union["ProvUsage", QualifiedNameCandidate] # type: typing.TypeAlias +RecordAttributesArg = Union[ + dict[QualifiedNameCandidate, Any], + Iterable[tuple[QualifiedNameCandidate, Any]], +] # type: typing.TypeAlias +NameValuePair = tuple[QualifiedName, Any] # type: typing.TypeAlias +DatetimeOrStr = Union[datetime.datetime, str] # type: typing.TypeAlias +NSCollection = Union[dict[str, str], Iterable[Namespace]] # type: typing.TypeAlias +PathLike = Union[str, bytes, os.PathLike] # type: typing.TypeAlias + + # Data Types -def _ensure_datetime(value): +def _ensure_datetime(value: Optional[DatetimeOrStr]) -> Optional[datetime.datetime]: if isinstance(value, str): return dateutil.parser.parse(value) else: return value -def parse_xsd_datetime(value): +def parse_xsd_datetime(value: str) -> Optional[datetime.datetime]: try: return dateutil.parser.parse(value) except ValueError: @@ -45,7 +77,7 @@ def parse_xsd_datetime(value): return None -def parse_boolean(value): +def parse_boolean(value: str) -> Optional[bool]: if value.lower() in ("false", "0"): return False elif value.lower() in ("true", "1"): @@ -60,7 +92,10 @@ def parse_boolean(value): # Mappings for XSD datatypes to Python standard types -XSD_DATATYPE_PARSERS = { +SupportedXSDParsedTypes = Union[ + str, datetime.datetime, float, int, bool, Identifier, None +] # type: typing.TypeAlias +XSD_DATATYPE_PARSERS: dict[QualifiedName, Callable[[str], SupportedXSDParsedTypes]] = { XSD_STRING: str, XSD_DOUBLE: float, XSD_LONG: int, @@ -71,7 +106,7 @@ def parse_boolean(value): } -def parse_xsd_types(value, datatype): +def parse_xsd_types(value: str, datatype: QualifiedName) -> SupportedXSDParsedTypes: return ( XSD_DATATYPE_PARSERS[datatype](value) if datatype in XSD_DATATYPE_PARSERS @@ -79,11 +114,11 @@ def parse_xsd_types(value, datatype): ) -def first(a_set): +def first(a_set: set[Any]) -> Any | None: return next(iter(a_set), None) -def _ensure_multiline_string_triple_quoted(value): +def _ensure_multiline_string_triple_quoted(value: str) -> str: # converting the value to a string s = str(value) # Escaping any double quote @@ -94,7 +129,9 @@ def _ensure_multiline_string_triple_quoted(value): return '"%s"' % s -def encoding_provn_value(value): +def encoding_provn_value( + value: str | datetime.datetime | float | bool | QualifiedName, +) -> str: if isinstance(value, str): return _ensure_multiline_string_triple_quoted(value) elif isinstance(value, datetime.datetime): @@ -109,8 +146,13 @@ def encoding_provn_value(value): class Literal(object): - def __init__(self, value, datatype=None, langtag=None): - self._value = str(value) # value is always a string + def __init__( + self, + value: Any, + datatype: Optional[QualifiedName] = None, + langtag: Optional[str] = None, + ): + self._value: str = str(value) # value is always a string if langtag: if datatype is None: logger.debug( @@ -127,17 +169,17 @@ def __init__(self, value, datatype=None, langtag=None): "prov:InternationalizedString." % (datatype, value, langtag) ) datatype = PROV["InternationalizedString"] - self._datatype = datatype + self._datatype: Optional[QualifiedName] = datatype # langtag is always a string - self._langtag = str(langtag) if langtag is not None else None + self._langtag: Optional[str] = str(langtag) if langtag is not None else None - def __str__(self): + def __str__(self) -> str: return self.provn_representation() - def __repr__(self): + def __repr__(self) -> str: return "" % self.provn_representation() - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( ( self._value == other.value @@ -148,28 +190,28 @@ def __eq__(self, other): else False ) - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) - def __hash__(self): + def __hash__(self) -> int: return hash((self._value, self._datatype, self._langtag)) @property - def value(self): + def value(self) -> str: return self._value @property - def datatype(self): + def datatype(self) -> QualifiedName | None: return self._datatype @property - def langtag(self): + def langtag(self) -> str | None: return self._langtag - def has_no_langtag(self): + def has_no_langtag(self) -> bool: return self._langtag is None - def provn_representation(self): + def provn_representation(self) -> str: if self._langtag: # a language tag can only go with prov:InternationalizedString return "%s@%s" % ( @@ -202,7 +244,7 @@ class ProvExceptionInvalidQualifiedName(ProvException): qname = None """Intended qualified name.""" - def __init__(self, qname): + def __init__(self, qname: Any): """ Constructor. @@ -210,14 +252,14 @@ def __init__(self, qname): """ self.qname = qname - def __str__(self): + def __str__(self) -> str: return "Invalid Qualified Name: %s" % self.qname class ProvElementIdentifierRequired(ProvException): """Exception for a missing element identifier.""" - def __str__(self): + def __str__(self) -> str: return ( "An identifier is missing. All PROV elements require a valid " "identifier." ) @@ -227,12 +269,18 @@ def __str__(self): class ProvRecord(object): """Base class for PROV records.""" - FORMAL_ATTRIBUTES = () + FORMAL_ATTRIBUTES = () # type: tuple[QualifiedName, ...] + """Formal attributes names of this record type, in the expected order.""" - _prov_type = None + _prov_type: Optional[QualifiedName] = None """PROV type of record.""" - def __init__(self, bundle, identifier, attributes=None): + def __init__( + self, + bundle: ProvBundle, + identifier: Optional[QualifiedName], + attributes: Optional[RecordAttributesArg] = None, + ): """ Constructor. @@ -246,10 +294,10 @@ def __init__(self, bundle, identifier, attributes=None): if attributes: self.add_attributes(attributes) - def __hash__(self): + def __hash__(self) -> int: return hash((self.get_type(), self._identifier, frozenset(self.attributes))) - def copy(self): + def copy(self) -> "ProvRecord": """ Return an exact copy of this record. """ @@ -257,15 +305,18 @@ def copy(self): self._bundle, self.identifier, self.attributes ) - def get_type(self): + def get_type(self) -> QualifiedName: """Returns the PROV type of the record.""" - return self._prov_type + if self._prov_type is not None: + return self._prov_type + else: + raise NotImplementedError("Type not defined for this record.") - def get_asserted_types(self): + def get_asserted_types(self) -> set[QualifiedName]: """Returns the set of all asserted PROV types of this record.""" return self._attributes[PROV_TYPE] - def add_asserted_type(self, type_identifier): + def add_asserted_type(self, type_identifier: QualifiedName) -> None: """ Adds a PROV type assertion to the record. @@ -273,24 +324,24 @@ def add_asserted_type(self, type_identifier): """ self._attributes[PROV_TYPE].add(type_identifier) - def get_attribute(self, attr_name) -> set: + def get_attribute(self, attr_name: QualifiedNameCandidate) -> set: """ - Returns the attribute of the given name. + Returns the attribute values (if any) for the specified attribute name). :param attr_name: Name of the attribute. :return: Set of value(s) of the specified attribute. :rtype: set """ - attr_name = self._bundle.valid_qualified_name(attr_name) - return self._attributes[attr_name] + attr_name_qn = self._bundle.mandatory_valid_qname(attr_name) + return self._attributes[attr_name_qn] @property - def identifier(self): + def identifier(self) -> QualifiedName | None: """Record's identifier.""" return self._identifier @property - def attributes(self): + def attributes(self) -> list[tuple[QualifiedName, Any]]: """ All record attributes. @@ -303,7 +354,7 @@ def attributes(self): ] @property - def args(self): + def args(self) -> tuple: """ All values of the record's formal attributes. @@ -314,7 +365,7 @@ def args(self): ) @property - def formal_attributes(self): + def formal_attributes(self) -> tuple[tuple[QualifiedName, Any], ...]: """ All names and values of the record's formal attributes. @@ -326,21 +377,21 @@ def formal_attributes(self): ) @property - def extra_attributes(self): + def extra_attributes(self) -> tuple[tuple[QualifiedName, Any], ...]: """ All names and values of the record's attributes that are not formal attributes. :return: Tuple of tuples (name, value) """ - return [ + return tuple( (attr_name, attr_value) for attr_name, attr_value in self.attributes if attr_name not in self.FORMAL_ATTRIBUTES - ] + ) @property - def bundle(self): + def bundle(self) -> ProvBundle: """ Bundle of the record. @@ -349,21 +400,21 @@ def bundle(self): return self._bundle @property - def label(self): + def label(self) -> str: """Identifying label of the record.""" - return ( + return str( first(self._attributes[PROV_LABEL]) if self._attributes[PROV_LABEL] else self._identifier ) @property - def value(self): + def value(self) -> Any: """Value of the record.""" return self._attributes[PROV_VALUE] # Handling attributes - def _auto_literal_conversion(self, literal): + def _auto_literal_conversion(self, literal: Any) -> Any: # This method normalise datatype for literals if isinstance(literal, ProvRecord): @@ -376,8 +427,8 @@ def _auto_literal_conversion(self, literal): return self._bundle.valid_qualified_name(literal) elif isinstance(literal, Literal) and literal.has_no_langtag(): if literal.datatype: - # try convert generic Literal object to Python standard type - # this is to match JSON decoding's literal conversion + # try to convert a generic Literal object to Python standard type + # to match the JSON decoding's literal conversion value = parse_xsd_types(literal.value, literal.datatype) else: # A literal with no datatype nor langtag defined @@ -389,12 +440,12 @@ def _auto_literal_conversion(self, literal): # No conversion possible, return the original value return literal - def add_attributes(self, attributes): + def add_attributes(self, attributes: RecordAttributesArg) -> None: """ Add attributes to the record. :param attributes: Dictionary of attributes, with keys being qualified - identifiers. Alternatively an iterable of tuples (key, value) with the + identifiers. Alternatively, an iterable of tuples (key, value) with the keys satisfying the same condition. """ if attributes: @@ -416,24 +467,33 @@ def add_attributes(self, attributes): continue # make sure the attribute name is valid - attr = self._bundle.valid_qualified_name(attr_name) - if attr is None: - raise ProvExceptionInvalidQualifiedName(attr_name) + attr = self._bundle.mandatory_valid_qname(attr_name) if attr in PROV_ATTRIBUTE_QNAMES: # Expecting a qualified name - qname = ( - original_value.identifier - if isinstance(original_value, ProvRecord) - else original_value - ) - value = self._bundle.valid_qualified_name(qname) + if isinstance(original_value, ProvRecord): + # Use the identifier of the record, which must exist, as the value for this attribute + qname = original_value.identifier + if qname is None: + raise ProvException( + f"Invalid value for attribute {attr}: {original_value}." + f" The record has no identifier." + ) + else: + qname = original_value + value = self._bundle.mandatory_valid_qname(qname) # type: Any elif attr in PROV_ATTRIBUTE_LITERALS: - value = ( - original_value - if isinstance(original_value, datetime.datetime) - else parse_xsd_datetime(original_value) - ) + # Expecting a datetime object or a string that can be parsed as a datetime + if isinstance(original_value, str): + value = parse_xsd_datetime(original_value) + else: + value = original_value + if not isinstance(value, datetime.datetime): + raise ProvException( + f"Invalid value for attribute {attr}: {original_value}. " + f"Expected a datetime object or a string that can be parsed" + f" as a datetime." + ) else: value = self._auto_literal_conversion(original_value) @@ -465,7 +525,7 @@ def add_attributes(self, attributes): self._attributes[attr].add(value) - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, ProvRecord): return False if self.get_type() != other.get_type(): @@ -475,10 +535,10 @@ def __eq__(self, other): return set(self.attributes) == set(other.attributes) - def __str__(self): + def __str__(self) -> str: return self.get_provn() - def get_provn(self): + def get_provn(self) -> str: """ Returns the PROV-N representation of the record. @@ -493,8 +553,7 @@ def get_provn(self): if self.is_element(): items.append(identifier) else: - # this is a relation - # relations use ; to separate identifiers + # this is a relation, which relation uses a semicolon to separate identifiers relation_id = identifier + "; " # Writing out the formal attributes @@ -533,7 +592,7 @@ def get_provn(self): ) return prov_n - def is_element(self): + def is_element(self) -> bool: """ True, if the record is an element, False otherwise. @@ -541,7 +600,7 @@ def is_element(self): """ return False - def is_relation(self): + def is_relation(self) -> bool: """ True, if the record is a relation, False otherwise. @@ -554,14 +613,19 @@ def is_relation(self): class ProvElement(ProvRecord): """Provenance Element (nodes in the provenance graph).""" - def __init__(self, bundle, identifier, attributes=None): + def __init__( + self, + bundle: ProvBundle, + identifier: Optional[QualifiedName], + attributes: Optional[RecordAttributesArg] = None, + ): if identifier is None: # All types of PROV elements require a valid identifier raise ProvElementIdentifierRequired() super(ProvElement, self).__init__(bundle, identifier, attributes) - def is_element(self): + def is_element(self) -> bool: """ True, if the record is an element, False otherwise. @@ -569,14 +633,14 @@ def is_element(self): """ return True - def __repr__(self): + def __repr__(self) -> str: return "<%s: %s>" % (self.__class__.__name__, self._identifier) class ProvRelation(ProvRecord): """Provenance Relationship (edge between nodes).""" - def is_relation(self): + def is_relation(self) -> bool: """ True, if the record is a relation, False otherwise. @@ -584,7 +648,7 @@ def is_relation(self): """ return True - def __repr__(self): + def __repr__(self) -> str: identifier = " %s" % self._identifier if self._identifier else "" element_1, element_2 = [qname for _, qname in self.formal_attributes[:2]] return "<%s:%s (%s, %s)>" % ( @@ -603,7 +667,12 @@ class ProvEntity(ProvElement): # Convenient assertions that take the current ProvEntity as the first # (formal) argument - def wasGeneratedBy(self, activity, time=None, attributes=None): + def wasGeneratedBy( + self, + activity: Optional[ActivityRef] = None, + time: Optional[DatetimeOrStr] = None, + attributes: Optional[RecordAttributesArg] = None, + ) -> ProvEntity: """ Creates a new generation record to this entity. @@ -618,7 +687,12 @@ def wasGeneratedBy(self, activity, time=None, attributes=None): self._bundle.generation(self, activity, time, other_attributes=attributes) return self - def wasInvalidatedBy(self, activity, time=None, attributes=None): + def wasInvalidatedBy( + self, + activity: Optional[ActivityRef], + time: Optional[DatetimeOrStr] = None, + attributes: Optional[RecordAttributesArg] = None, + ) -> ProvEntity: """ Creates a new invalidation record for this entity. @@ -634,17 +708,22 @@ def wasInvalidatedBy(self, activity, time=None, attributes=None): return self def wasDerivedFrom( - self, usedEntity, activity=None, generation=None, usage=None, attributes=None - ): + self, + usedEntity: EntityRef, + activity: Optional[ActivityRef] = None, + generation: Optional[GenrationRef] = None, + usage: Optional[UsageRef] = None, + attributes: Optional[RecordAttributesArg] = None, + ) -> ProvEntity: """ Creates a new derivation record for this entity from a used entity. :param usedEntity: Entity or a string identifier for the used entity. :param activity: Activity or string identifier of the activity involved in the derivation (default: None). - :param generation: Optionally extra activity to state qualified derivation + :param generation: Optional generation record to state qualified derivation through an internal generation (default: None). - :param usage: Optionally extra entity to state qualified derivation through + :param usage: Optional usage record to state qualified derivation through an internal usage (default: None). :param attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). @@ -654,7 +733,9 @@ def wasDerivedFrom( ) return self - def wasAttributedTo(self, agent, attributes=None): + def wasAttributedTo( + self, agent: AgentRef, attributes: Optional[RecordAttributesArg] = None + ) -> ProvEntity: """ Creates a new attribution record between this entity and an agent. @@ -666,7 +747,7 @@ def wasAttributedTo(self, agent, attributes=None): self._bundle.attribution(self, agent, other_attributes=attributes) return self - def alternateOf(self, alternate2): + def alternateOf(self, alternate2: EntityRef) -> ProvEntity: """ Creates a new alternate record between this and another entity. @@ -675,7 +756,7 @@ def alternateOf(self, alternate2): self._bundle.alternate(self, alternate2) return self - def specializationOf(self, generalEntity): + def specializationOf(self, generalEntity: EntityRef) -> ProvEntity: """ Creates a new specialisation record for this from a general entity. @@ -684,7 +765,7 @@ def specializationOf(self, generalEntity): self._bundle.specialization(self, generalEntity) return self - def hadMember(self, entity): + def hadMember(self, entity: EntityRef) -> ProvEntity: """ Creates a new membership record to an entity for a collection. @@ -702,7 +783,11 @@ class ProvActivity(ProvElement): _prov_type = PROV_ACTIVITY # Convenient methods - def set_time(self, startTime=None, endTime=None): + def set_time( + self, + startTime: Optional[datetime.datetime] = None, + endTime: Optional[datetime.datetime] = None, + ) -> None: """ Sets the time this activity took place. @@ -718,7 +803,7 @@ def set_time(self, startTime=None, endTime=None): if endTime is not None: self._attributes[PROV_ATTR_ENDTIME] = {endTime} - def get_startTime(self): + def get_startTime(self) -> datetime.datetime | None: """ Returns the time the activity started. @@ -727,7 +812,7 @@ def get_startTime(self): values = self._attributes[PROV_ATTR_STARTTIME] return first(values) if values else None - def get_endTime(self): + def get_endTime(self) -> datetime.datetime | None: """ Returns the time the activity ended. @@ -738,7 +823,12 @@ def get_endTime(self): # Convenient assertions that take the current ProvActivity as the first # (formal) argument - def used(self, entity, time=None, attributes=None): + def used( + self, + entity: EntityRef, + time: Optional[DatetimeOrStr] = None, + attributes: Optional[RecordAttributesArg] = None, + ) -> ProvActivity: """ Creates a new usage record for this activity. @@ -753,7 +843,9 @@ def used(self, entity, time=None, attributes=None): self._bundle.usage(self, entity, time, other_attributes=attributes) return self - def wasInformedBy(self, informant, attributes=None): + def wasInformedBy( + self, informant: ActivityRef, attributes: Optional[RecordAttributesArg] = None + ) -> ProvActivity: """ Creates a new communication record for this activity. @@ -764,13 +856,19 @@ def wasInformedBy(self, informant, attributes=None): self._bundle.communication(self, informant, other_attributes=attributes) return self - def wasStartedBy(self, trigger, starter=None, time=None, attributes=None): + def wasStartedBy( + self, + trigger: Optional[EntityRef], + starter: Optional[ActivityRef] = None, + time: Optional[DatetimeOrStr] = None, + attributes: Optional[RecordAttributesArg] = None, + ) -> ProvActivity: """ Creates a new start record for this activity. The activity did not exist before the start by the trigger. :param trigger: Entity triggering the start of this activity. - :param starter: Optionally extra activity to state a qualified start + :param starter: Optional extra activity to state a qualified start through which the trigger entity for the start is generated (default: None). :param time: Optional time for the start (default: None). @@ -782,7 +880,13 @@ def wasStartedBy(self, trigger, starter=None, time=None, attributes=None): self._bundle.start(self, trigger, starter, time, other_attributes=attributes) return self - def wasEndedBy(self, trigger, ender=None, time=None, attributes=None): + def wasEndedBy( + self, + trigger: Optional[EntityRef], + ender: Optional[ActivityRef] = None, + time: Optional[DatetimeOrStr] = None, + attributes: Optional[RecordAttributesArg] = None, + ) -> ProvActivity: """ Creates a new end record for this activity. @@ -798,7 +902,12 @@ def wasEndedBy(self, trigger, ender=None, time=None, attributes=None): self._bundle.end(self, trigger, ender, time, other_attributes=attributes) return self - def wasAssociatedWith(self, agent, plan=None, attributes=None): + def wasAssociatedWith( + self, + agent: AgentRef, + plan: Optional[EntityRef] = None, + attributes: Optional[RecordAttributesArg] = None, + ) -> ProvActivity: """ Creates a new association record for this activity. @@ -894,7 +1003,12 @@ class ProvAgent(ProvElement): # Convenient assertions that take the current ProvAgent as the first # (formal) argument - def actedOnBehalfOf(self, responsible, activity=None, attributes=None): + def actedOnBehalfOf( + self, + responsible: AgentRef, + activity: Optional[ActivityRef] = None, + attributes: Optional[RecordAttributesArg] = None, + ) -> ProvAgent: """ Creates a new delegation record on behalf of this agent. @@ -946,7 +1060,10 @@ class ProvInfluence(ProvRelation): class ProvSpecialization(ProvRelation): """Provenance Specialization relationship.""" - FORMAL_ATTRIBUTES = (PROV_ATTR_SPECIFIC_ENTITY, PROV_ATTR_GENERAL_ENTITY) + FORMAL_ATTRIBUTES = ( + PROV_ATTR_SPECIFIC_ENTITY, + PROV_ATTR_GENERAL_ENTITY, + ) # type: tuple[QualifiedName, ...] _prov_type = PROV_SPECIALIZATION @@ -1010,10 +1127,15 @@ class ProvMembership(ProvRelation): class NamespaceManager(dict): """Manages namespaces for PROV documents and bundles.""" - parent = None + parent = None # type: Optional[NamespaceManager] """Parent :py:class:`NamespaceManager` this manager one is a child of.""" - def __init__(self, namespaces=None, default=None, parent=None): + def __init__( + self, + namespaces: Optional[NSCollection] = None, + default: Optional[str] = None, + parent: Optional[NamespaceManager] = None, + ): """ Constructor. @@ -1026,21 +1148,22 @@ def __init__(self, namespaces=None, default=None, parent=None): dict.__init__(self) self._default_namespaces = DEFAULT_NAMESPACES self.update(self._default_namespaces) - self._namespaces = {} + self._namespaces = {} # type: dict[str, Namespace] if default is not None: self.set_default_namespace(default) else: - self._default = None + self._default = None # type: Optional[Namespace] self.parent = parent # TODO check if default is in the default namespaces self._anon_id_count = 0 - self._uri_map = dict() - self._rename_map = dict() - self._prefix_renamed_map = dict() - self.add_namespaces(namespaces) + self._uri_map = dict() # type: dict[str, Namespace] + self._rename_map = dict() # type: dict[Namespace, Namespace] + self._prefix_renamed_map = dict() # type: dict[str, Namespace] + if namespaces is not None: + self.add_namespaces(namespaces) - def get_namespace(self, uri): + def get_namespace(self, uri: str) -> Namespace | None: """ Returns the namespace prefix for the given URI. @@ -1052,7 +1175,7 @@ def get_namespace(self, uri): return namespace return None - def get_registered_namespaces(self): + def get_registered_namespaces(self) -> Iterable[Namespace]: """ Returns all registered namespaces. @@ -1060,7 +1183,7 @@ def get_registered_namespaces(self): """ return self._namespaces.values() - def set_default_namespace(self, uri): + def set_default_namespace(self, uri: str) -> None: """ Sets the default namespace to the one of a given URI. @@ -1069,7 +1192,7 @@ def set_default_namespace(self, uri): self._default = Namespace("", uri) self[""] = self._default - def get_default_namespace(self): + def get_default_namespace(self) -> Namespace | None: """ Returns the default namespace. @@ -1077,7 +1200,7 @@ def get_default_namespace(self): """ return self._default - def add_namespace(self, namespace): + def add_namespace(self, namespace: Namespace) -> Namespace: """ Adds a namespace (if not available, yet). @@ -1119,7 +1242,7 @@ def add_namespace(self, namespace): return namespace - def add_namespaces(self, namespaces): + def add_namespaces(self, namespaces: NSCollection) -> None: """ Add multiple namespaces into this manager. @@ -1136,7 +1259,9 @@ def add_namespaces(self, namespaces): for ns in namespaces: self.add_namespace(ns) - def valid_qualified_name(self, qname): + def valid_qualified_name( + self, qname: QualifiedNameCandidate + ) -> QualifiedName | None: """ Resolves an identifier to a valid qualified name. @@ -1184,12 +1309,12 @@ def valid_qualified_name(self, qname): # returning the new qname return new_qname - # Trying to guess from here + # Trying to generate a valid qualified name from here if not isinstance(qname, (str, Identifier)): - # Only proceed for string or URI values + # Only proceed with a string or URI value return None - # Try to generate a Qualified Name - str_value = qname.uri if isinstance(qname, Identifier) else str(qname) + # Extract the URI string value if it is an identifier + str_value = qname.uri if isinstance(qname, Identifier) else qname if str_value.startswith("_:"): # this is a blank node ID return None @@ -1203,14 +1328,15 @@ def valid_qualified_name(self, qname): # return a new QualifiedName return self._prefix_renamed_map[prefix][local_part] else: - # treat as a URI (with the first part as its scheme) - # check if the URI can be compacted + # assuming it is a URI (with the first part as its scheme) + # check if the URI can be compacted by any of the registered namespaces for namespace in self.values(): if str_value.startswith(namespace.uri): # create a QName with the namespace return namespace[str_value.replace(namespace.uri, "")] - elif self._default: - # create and return an identifier in the default namespace + elif self._default and isinstance(qname, str): + # no colon in the identifier and a default namespace is defined, + # create and return a qualified name in the default namespace return self._default[qname] if self.parent: @@ -1221,7 +1347,7 @@ def valid_qualified_name(self, qname): # Default to FAIL return None - def get_anonymous_identifier(self, local_prefix="id"): + def get_anonymous_identifier(self, local_prefix: str = "id") -> Identifier: """ Returns an anonymous identifier (without a namespace prefix). @@ -1232,7 +1358,7 @@ def get_anonymous_identifier(self, local_prefix="id"): self._anon_id_count += 1 return Identifier("_:%s%d" % (local_prefix, self._anon_id_count)) - def _get_unused_prefix(self, original_prefix): + def _get_unused_prefix(self, original_prefix: str) -> str: if original_prefix not in self: return original_prefix count = 1 @@ -1247,7 +1373,13 @@ def _get_unused_prefix(self, original_prefix): class ProvBundle(object): """PROV Bundle""" - def __init__(self, records=None, identifier=None, namespaces=None, document=None): + def __init__( + self, + records: Optional[Iterable[ProvRecord]] = None, + identifier: Optional[QualifiedName] = None, + namespaces: Optional[NSCollection] = None, + document: Optional["ProvDocument"] = None, + ): """ Constructor. @@ -1260,21 +1392,21 @@ def __init__(self, records=None, identifier=None, namespaces=None, document=None """ # Initializing bundle-specific attributes self._identifier = identifier - self._records = list() - self._id_map = defaultdict(list) + self._records = list() # type: list[ProvRecord] + self._id_map = defaultdict(list) # type: dict[QualifiedName, list[ProvRecord]] self._document = document self._namespaces = NamespaceManager( namespaces, parent=(document._namespaces if document is not None else None) - ) + ) # type: NamespaceManager if records: for record in records: self.add_record(record) - def __repr__(self): + def __repr__(self) -> str: return "<%s: %s>" % (self.__class__.__name__, self._identifier) @property - def namespaces(self): + def namespaces(self) -> set[Namespace]: """ Returns the set of registered namespaces. @@ -1283,7 +1415,7 @@ def namespaces(self): return set(self._namespaces.get_registered_namespaces()) @property - def default_ns_uri(self): + def default_ns_uri(self) -> str | None: """ Returns the default namespace's URI, if any. @@ -1293,7 +1425,7 @@ def default_ns_uri(self): return default_ns.uri if default_ns else None @property - def document(self): + def document(self) -> ProvDocument | None: """ Returns the parent document, if any. @@ -1302,21 +1434,21 @@ def document(self): return self._document @property - def identifier(self): + def identifier(self) -> QualifiedName | None: """ Returns the bundle's identifier """ return self._identifier @property - def records(self): + def records(self) -> list[ProvRecord]: """ Returns the list of all records in the current bundle """ return list(self._records) # Bundle configurations - def set_default_namespace(self, uri): + def set_default_namespace(self, uri: str) -> None: """ Sets the default namespace through a given URI. @@ -1324,7 +1456,7 @@ def set_default_namespace(self, uri): """ self._namespaces.set_default_namespace(uri) - def get_default_namespace(self): + def get_default_namespace(self) -> Namespace | None: """ Returns the default namespace. @@ -1332,7 +1464,9 @@ def get_default_namespace(self): """ return self._namespaces.get_default_namespace() - def add_namespace(self, namespace_or_prefix, uri=None): + def add_namespace( + self, namespace_or_prefix: Namespace | str, uri: Optional[str] = None + ) -> Namespace: """ Adds a namespace (if not available, yet). @@ -1341,12 +1475,17 @@ def add_namespace(self, namespace_or_prefix, uri=None): :param uri: Namespace URI (default: None). Must be present if only a prefix is given in the previous parameter. """ - if uri is None: + if isinstance(namespace_or_prefix, Namespace): return self._namespaces.add_namespace(namespace_or_prefix) else: - return self._namespaces.add_namespace(Namespace(namespace_or_prefix, uri)) + if uri is not None: + return self._namespaces.add_namespace( + Namespace(namespace_or_prefix, uri) + ) + else: + raise ProvException("Cannot add a namespace without a URI") - def get_registered_namespaces(self): + def get_registered_namespaces(self) -> Iterable[Namespace]: """ Returns all registered namespaces. @@ -1354,10 +1493,27 @@ def get_registered_namespaces(self): """ return self._namespaces.get_registered_namespaces() - def valid_qualified_name(self, identifier): + def valid_qualified_name( + self, identifier: QualifiedNameCandidate + ) -> Optional[QualifiedName]: return self._namespaces.valid_qualified_name(identifier) - def get_records(self, class_or_type_or_tuple=None): + def mandatory_valid_qname( + self, identifier: QualifiedNameCandidate + ) -> QualifiedName: + """ + Determines if the given identifier is a valid qualified name and returns it. + If the provided identifier is not valid, an exception is raised. + """ + valid_qname = self.valid_qualified_name(identifier) + if valid_qname is not None: + return valid_qname + else: + raise ProvExceptionInvalidQualifiedName(identifier) + + def get_records( + self, class_or_type_or_tuple: Optional[type | tuple[type]] = None + ) -> Iterable[ProvRecord]: """ Returns all records. Returned records may be filtered by the optional argument. @@ -1367,35 +1523,24 @@ def get_records(self, class_or_type_or_tuple=None): record using the `isinstance` check on the record. :return: List of :py:class:`ProvRecord` objects. """ - results = list(self._records) + results = list(self._records) # make a (shallow) copy of the record list if class_or_type_or_tuple: return filter(lambda rec: isinstance(rec, class_or_type_or_tuple), results) else: return results - def get_record(self, identifier): + def get_record(self, identifier: QualifiedNameCandidate) -> list[ProvRecord]: """ - Returns a specific record matching a given identifier. + Returns one or more records matching a given identifier. :param identifier: Record identifier. - :return: :py:class:`ProvRecord` + :return: List of :py:class:`ProvRecord` """ - # TODO: This will not work with the new _id_map, which is now a map of - # (QName, list(ProvRecord)) - if identifier is None: - return None valid_id = self.valid_qualified_name(identifier) - try: - return self._id_map[valid_id] - except KeyError: - # looking up the parent bundle - if self.is_bundle(): - return self.document.get_record(valid_id) - else: - return None + return list(self._id_map[valid_id]) if valid_id is not None else [] # Miscellaneous functions - def is_document(self): + def is_document(self) -> bool: """ `True` if the object is a document, `False` otherwise. @@ -1403,7 +1548,7 @@ def is_document(self): """ return False - def is_bundle(self): + def is_bundle(self) -> bool: """ `True` if the object is a bundle, `False` otherwise. @@ -1411,7 +1556,7 @@ def is_bundle(self): """ return True - def has_bundles(self): + def has_bundles(self) -> bool: """ `True` if the object has at least one bundle, `False` otherwise. @@ -1420,15 +1565,15 @@ def has_bundles(self): return False @property - def bundles(self): + def bundles(self) -> Iterable[ProvBundle]: """ Returns bundles contained in the document :return: Iterable of :py:class:`ProvBundle`. """ - return frozenset() + raise ProvException("A PROV bundle does not contain sub-bundles") - def get_provn(self, _indent_level=0): + def get_provn(self, _indent_level: int = 0) -> str: """ Returns the PROV-N representation of the bundle. @@ -1471,7 +1616,7 @@ def get_provn(self, _indent_level=0): ) return provn_str - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, ProvBundle): return False other_records = set(other.get_records()) @@ -1495,13 +1640,13 @@ def __eq__(self, other): return False return True - def __ne__(self, other): + def __ne__(self, other: Any) -> bool: return not (self == other) - __hash__ = None + __hash__ = None # type: ignore - # Transformations - def _unified_records(self): + # type: ignore # type: ignore # Transformations + def _unified_records(self) -> list[ProvRecord]: """Returns a list of unified records.""" # TODO: Check unification rules in the PROV-CONSTRAINTS document # This method simply merges the records having the same name @@ -1533,7 +1678,7 @@ def _unified_records(self): unified_records.append(record) return unified_records - def unified(self): + def unified(self) -> ProvBundle: """ Unifies all records in the bundle that haves same identifiers @@ -1543,7 +1688,7 @@ def unified(self): bundle = ProvBundle(records=unified_records, identifier=self.identifier) return bundle - def update(self, other): + def update(self, other: ProvBundle) -> None: """ Append all the records of the *other* ProvBundle into this bundle. @@ -1567,7 +1712,7 @@ def update(self, other): ) # Provenance statements - def _add_record(self, record): + def _add_record(self, record: ProvRecord) -> None: # IMPORTANT: All records need to be added to a bundle/document via this # method. Otherwise, the _id_map dict will not be correctly updated identifier = record.identifier @@ -1576,8 +1721,12 @@ def _add_record(self, record): self._records.append(record) def new_record( - self, record_type, identifier, attributes=None, other_attributes=None - ): + self, + record_type: QualifiedName, + identifier: OptionalID, + attributes: Optional[RecordAttributesArg] = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvRecord: """ Creates a new record. @@ -1588,7 +1737,7 @@ def new_record( :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ - attr_list = [] + attr_list = [] # type: list[tuple[QualifiedNameCandidate, Any]] if attributes: if isinstance(attributes, dict): attr_list.extend((attr, value) for attr, value in attributes.items()) @@ -1601,13 +1750,14 @@ def new_record( if isinstance(other_attributes, dict) else other_attributes ) - new_record = PROV_REC_CLS[record_type]( - self, self.valid_qualified_name(identifier), attr_list + record_identifier = ( + self.valid_qualified_name(identifier) if identifier else None ) + new_record = PROV_REC_CLS[record_type](self, record_identifier, attr_list) self._add_record(new_record) return new_record - def add_record(self, record): + def add_record(self, record: ProvRecord) -> ProvRecord: """ Adds a new record that to the bundle. @@ -1620,7 +1770,11 @@ def add_record(self, record): record.extra_attributes, ) - def entity(self, identifier, other_attributes=None): + def entity( + self, + identifier: QualifiedNameCandidate, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvEntity: """ Creates a new entity. @@ -1628,9 +1782,15 @@ def entity(self, identifier, other_attributes=None): :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ - return self.new_record(PROV_ENTITY, identifier, None, other_attributes) + return self.new_record(PROV_ENTITY, identifier, None, other_attributes) # type: ignore - def activity(self, identifier, startTime=None, endTime=None, other_attributes=None): + def activity( + self, + identifier: QualifiedNameCandidate, + startTime: Optional[DatetimeOrStr] = None, + endTime: Optional[DatetimeOrStr] = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvActivity: """ Creates a new activity. @@ -1644,19 +1804,25 @@ def activity(self, identifier, startTime=None, endTime=None, other_attributes=No :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_STARTTIME: _ensure_datetime(startTime), + PROV_ATTR_ENDTIME: _ensure_datetime(endTime), + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_ACTIVITY, identifier, - { - PROV_ATTR_STARTTIME: _ensure_datetime(startTime), - PROV_ATTR_ENDTIME: _ensure_datetime(endTime), - }, + attributes, other_attributes, - ) + ) # type: ignore def generation( - self, entity, activity=None, time=None, identifier=None, other_attributes=None - ): + self, + entity: EntityRef, + activity: Optional[ActivityRef] = None, + time: Optional[DatetimeOrStr] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvRecord: """ Creates a new generation record for an entity. @@ -1670,20 +1836,26 @@ def generation( :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_ENTITY: entity, + PROV_ATTR_ACTIVITY: activity, + PROV_ATTR_TIME: _ensure_datetime(time), + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_GENERATION, identifier, - { - PROV_ATTR_ENTITY: entity, - PROV_ATTR_ACTIVITY: activity, - PROV_ATTR_TIME: _ensure_datetime(time), - }, + attributes, other_attributes, ) def usage( - self, activity, entity=None, time=None, identifier=None, other_attributes=None - ): + self, + activity: ActivityRef, + entity: Optional[EntityRef] = None, + time: Optional[DatetimeOrStr] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvUsage: """ Creates a new usage record for an activity. @@ -1697,26 +1869,27 @@ def usage( :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_ACTIVITY: activity, + PROV_ATTR_ENTITY: entity, + PROV_ATTR_TIME: _ensure_datetime(time), + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_USAGE, identifier, - { - PROV_ATTR_ACTIVITY: activity, - PROV_ATTR_ENTITY: entity, - PROV_ATTR_TIME: _ensure_datetime(time), - }, + attributes, other_attributes, - ) + ) # type: ignore def start( self, - activity, - trigger=None, - starter=None, - time=None, - identifier=None, - other_attributes=None, - ): + activity: ActivityRef, + trigger: Optional[EntityRef] = None, + starter: Optional[ActivityRef] = None, + time: Optional[DatetimeOrStr] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvStart: """ Creates a new start record for an activity. @@ -1732,27 +1905,28 @@ def start( :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_ACTIVITY: activity, + PROV_ATTR_TRIGGER: trigger, + PROV_ATTR_STARTER: starter, + PROV_ATTR_TIME: _ensure_datetime(time), + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_START, identifier, - { - PROV_ATTR_ACTIVITY: activity, - PROV_ATTR_TRIGGER: trigger, - PROV_ATTR_STARTER: starter, - PROV_ATTR_TIME: _ensure_datetime(time), - }, + attributes, other_attributes, - ) + ) # type: ignore def end( self, - activity, - trigger=None, - ender=None, - time=None, - identifier=None, - other_attributes=None, - ): + activity: ActivityRef, + trigger: Optional[EntityRef] = None, + ender: Optional[ActivityRef] = None, + time: Optional[DatetimeOrStr] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvEnd: """ Creates a new end record for an activity. @@ -1768,21 +1942,27 @@ def end( :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_ACTIVITY: activity, + PROV_ATTR_TRIGGER: trigger, + PROV_ATTR_ENDER: ender, + PROV_ATTR_TIME: _ensure_datetime(time), + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_END, identifier, - { - PROV_ATTR_ACTIVITY: activity, - PROV_ATTR_TRIGGER: trigger, - PROV_ATTR_ENDER: ender, - PROV_ATTR_TIME: _ensure_datetime(time), - }, + attributes, other_attributes, - ) + ) # type: ignore def invalidation( - self, entity, activity=None, time=None, identifier=None, other_attributes=None - ): + self, + entity: EntityRef, + activity: Optional[ActivityRef] = None, + time: Optional[DatetimeOrStr] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvInvalidation: """ Creates a new invalidation record for an entity. @@ -1792,24 +1972,29 @@ def invalidation( :param time: Optional time for the invalidation (default: None). Either a :py:class:`datetime.datetime` object or a string that can be parsed by :py:func:`dateutil.parser`. - :param identifier: Identifier for new invalidation record. + :param identifier: Identifier for the new invalidation record. :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_ENTITY: entity, + PROV_ATTR_ACTIVITY: activity, + PROV_ATTR_TIME: _ensure_datetime(time), + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_INVALIDATION, identifier, - { - PROV_ATTR_ENTITY: entity, - PROV_ATTR_ACTIVITY: activity, - PROV_ATTR_TIME: _ensure_datetime(time), - }, + attributes, other_attributes, - ) + ) # type: ignore def communication( - self, informed, informant, identifier=None, other_attributes=None - ): + self, + informed: ActivityRef, + informant: ActivityRef, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvCommunication: """ Creates a new communication record for an entity. @@ -1819,14 +2004,22 @@ def communication( :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_INFORMED: informed, + PROV_ATTR_INFORMANT: informant, + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_COMMUNICATION, identifier, - {PROV_ATTR_INFORMED: informed, PROV_ATTR_INFORMANT: informant}, + attributes, other_attributes, - ) + ) # type: ignore - def agent(self, identifier, other_attributes=None): + def agent( + self, + identifier: QualifiedNameCandidate, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvAgent: """ Creates a new agent. @@ -1834,9 +2027,15 @@ def agent(self, identifier, other_attributes=None): :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ - return self.new_record(PROV_AGENT, identifier, None, other_attributes) + return self.new_record(PROV_AGENT, identifier, None, other_attributes) # type: ignore - def attribution(self, entity, agent, identifier=None, other_attributes=None): + def attribution( + self, + entity: EntityRef, + agent: AgentRef, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvAttribution: """ Creates a new attribution record between an entity and an agent. @@ -1848,16 +2047,25 @@ def attribution(self, entity, agent, identifier=None, other_attributes=None): :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_ENTITY: entity, + PROV_ATTR_AGENT: agent, + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_ATTRIBUTION, identifier, - {PROV_ATTR_ENTITY: entity, PROV_ATTR_AGENT: agent}, + attributes, other_attributes, - ) + ) # type: ignore def association( - self, activity, agent=None, plan=None, identifier=None, other_attributes=None - ): + self, + activity: ActivityRef, + agent: Optional[AgentRef] = None, + plan: Optional[EntityRef] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvAssociation: """ Creates a new association record for an activity. @@ -1870,25 +2078,26 @@ def association( :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_ACTIVITY: activity, + PROV_ATTR_AGENT: agent, + PROV_ATTR_PLAN: plan, + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_ASSOCIATION, identifier, - { - PROV_ATTR_ACTIVITY: activity, - PROV_ATTR_AGENT: agent, - PROV_ATTR_PLAN: plan, - }, + attributes, other_attributes, - ) + ) # type: ignore def delegation( self, - delegate, - responsible, - activity=None, - identifier=None, - other_attributes=None, - ): + delegate: AgentRef, + responsible: AgentRef, + activity: Optional[ActivityRef] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvDelegation: """ Creates a new delegation record on behalf of an agent. @@ -1901,18 +2110,25 @@ def delegation( :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_DELEGATE: delegate, + PROV_ATTR_RESPONSIBLE: responsible, + PROV_ATTR_ACTIVITY: activity, + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_DELEGATION, identifier, - { - PROV_ATTR_DELEGATE: delegate, - PROV_ATTR_RESPONSIBLE: responsible, - PROV_ATTR_ACTIVITY: activity, - }, + attributes, other_attributes, - ) + ) # type: ignore - def influence(self, influencee, influencer, identifier=None, other_attributes=None): + def influence( + self, + influencee: EntityRef | ActivityRef | AgentRef, + influencer: EntityRef | ActivityRef | AgentRef, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvInfluence: """ Creates a new influence record between two entities, activities or agents. @@ -1924,23 +2140,27 @@ def influence(self, influencee, influencer, identifier=None, other_attributes=No :param other_attributes: Optional other attributes as a dictionary or list of tuples to be added to the record optionally (default: None). """ + attributes = { + PROV_ATTR_INFLUENCEE: influencee, + PROV_ATTR_INFLUENCER: influencer, + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_INFLUENCE, identifier, - {PROV_ATTR_INFLUENCEE: influencee, PROV_ATTR_INFLUENCER: influencer}, + attributes, other_attributes, - ) + ) # type: ignore def derivation( self, - generatedEntity, - usedEntity, - activity=None, - generation=None, - usage=None, - identifier=None, - other_attributes=None, - ): + generatedEntity: EntityRef, + usedEntity: EntityRef, + activity: Optional[ActivityRef] = None, + generation: Optional[GenrationRef] = None, + usage: Optional[UsageRef] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvDerivation: """ Creates a new derivation record for a generated entity from a used entity. @@ -1963,21 +2183,21 @@ def derivation( PROV_ATTR_ACTIVITY: activity, PROV_ATTR_GENERATION: generation, PROV_ATTR_USAGE: usage, - } + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_DERIVATION, identifier, attributes, other_attributes - ) + ) # type: ignore def revision( self, - generatedEntity, - usedEntity, - activity=None, - generation=None, - usage=None, - identifier=None, - other_attributes=None, - ): + generatedEntity: EntityRef, + usedEntity: EntityRef, + activity: Optional[ActivityRef] = None, + generation: Optional[GenrationRef] = None, + usage: Optional[UsageRef] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvDerivation: """ Creates a new revision record for a generated entity from a used entity. @@ -2008,14 +2228,14 @@ def revision( def quotation( self, - generatedEntity, - usedEntity, - activity=None, - generation=None, - usage=None, - identifier=None, - other_attributes=None, - ): + generatedEntity: EntityRef, + usedEntity: EntityRef, + activity: Optional[ActivityRef] = None, + generation: Optional[GenrationRef] = None, + usage: Optional[UsageRef] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvDerivation: """ Creates a new quotation record for a generated entity from a used entity. @@ -2046,14 +2266,14 @@ def quotation( def primary_source( self, - generatedEntity, - usedEntity, - activity=None, - generation=None, - usage=None, - identifier=None, - other_attributes=None, - ): + generatedEntity: EntityRef, + usedEntity: EntityRef, + activity: Optional[ActivityRef] = None, + generation: Optional[GenrationRef] = None, + usage: Optional[UsageRef] = None, + identifier: OptionalID = None, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvDerivation: """ Creates a new primary source record for a generated entity from a used entity. @@ -2081,9 +2301,11 @@ def primary_source( other_attributes, ) record.add_asserted_type(PROV["PrimarySource"]) - return record + return record # type: ignore - def specialization(self, specificEntity, generalEntity): + def specialization( + self, specificEntity: EntityRef, generalEntity: EntityRef + ) -> ProvSpecialization: """ Creates a new specialisation record for a specific from a general entity. @@ -2092,16 +2314,17 @@ def specialization(self, specificEntity, generalEntity): :param generalEntity: Entity or a string identifier for the general entity (relationship destination). """ + attributes = { + PROV_ATTR_SPECIFIC_ENTITY: specificEntity, + PROV_ATTR_GENERAL_ENTITY: generalEntity, + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_SPECIALIZATION, None, - { - PROV_ATTR_SPECIFIC_ENTITY: specificEntity, - PROV_ATTR_GENERAL_ENTITY: generalEntity, - }, - ) + attributes, + ) # type: ignore - def alternate(self, alternate1, alternate2): + def alternate(self, alternate1: EntityRef, alternate2: EntityRef) -> ProvAlternate: """ Creates a new alternate record between two entities. @@ -2110,13 +2333,19 @@ def alternate(self, alternate1, alternate2): :param alternate2: Entity or a string identifier for the second entity (relationship destination). """ + attributes = { + PROV_ATTR_ALTERNATE1: alternate1, + PROV_ATTR_ALTERNATE2: alternate2, + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_ALTERNATE, None, - {PROV_ATTR_ALTERNATE1: alternate1, PROV_ATTR_ALTERNATE2: alternate2}, - ) + attributes, + ) # type: ignore - def mention(self, specificEntity, generalEntity, bundle): + def mention( + self, specificEntity: EntityRef, generalEntity: EntityRef, bundle: EntityRef + ) -> ProvMention: """ Creates a new mention record for a specific from a general entity. @@ -2126,17 +2355,22 @@ def mention(self, specificEntity, generalEntity, bundle): (relationship destination). :param bundle: XXX """ + attributes = { + PROV_ATTR_SPECIFIC_ENTITY: specificEntity, + PROV_ATTR_GENERAL_ENTITY: generalEntity, + PROV_ATTR_BUNDLE: bundle, + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_MENTION, None, - { - PROV_ATTR_SPECIFIC_ENTITY: specificEntity, - PROV_ATTR_GENERAL_ENTITY: generalEntity, - PROV_ATTR_BUNDLE: bundle, - }, - ) + attributes, + ) # type: ignore - def collection(self, identifier, other_attributes=None): + def collection( + self, + identifier: QualifiedNameCandidate, + other_attributes: Optional[RecordAttributesArg] = None, + ) -> ProvEntity: """ Creates a new collection record for a particular record. @@ -2146,29 +2380,33 @@ def collection(self, identifier, other_attributes=None): """ record = self.new_record(PROV_ENTITY, identifier, None, other_attributes) record.add_asserted_type(PROV["Collection"]) - return record + return record # type: ignore - def membership(self, collection, entity): + def membership(self, collection: EntityRef, entity: EntityRef) -> ProvMembership: """ Creates a new membership record for an entity to a collection. :param collection: Collection the entity is to be added to. :param entity: Entity to be added to the collection. """ + attributes = { + PROV_ATTR_COLLECTION: collection, + PROV_ATTR_ENTITY: entity, + } # type: dict[QualifiedNameCandidate, Any] return self.new_record( PROV_MEMBERSHIP, None, - {PROV_ATTR_COLLECTION: collection, PROV_ATTR_ENTITY: entity}, - ) + attributes, + ) # type: ignore def plot( self, - filename=None, - show_nary=True, - use_labels=False, - show_element_attributes=True, - show_relation_attributes=True, - ): + filename: Optional[PathLike] = None, + show_nary: bool = True, + use_labels: bool = False, + show_element_attributes: bool = True, + show_relation_attributes: bool = True, + ) -> None: """ Convenience function to plot a PROV document. @@ -2191,7 +2429,7 @@ def plot( from prov import dot if filename: - format = os.path.splitext(filename)[-1].lower().strip(os.path.extsep) + format = str(os.path.splitext(filename))[-1].lower().strip(os.path.extsep) else: format = "png" format = format.lower() @@ -2214,9 +2452,9 @@ def plot( fh.write(buf.read()) else: # Use matplotlib to show the image as it likely is more - # widespread then PIL and works nicely in the ipython notebook. - import matplotlib.pylab as plt - import matplotlib.image as mpimg + # widespread than PIL and works nicely in the ipython notebook. + import matplotlib.pylab as plt # type: ignore + import matplotlib.image as mpimg # type: ignore max_size = 30 @@ -2262,7 +2500,11 @@ def plot( class ProvDocument(ProvBundle): """Provenance Document.""" - def __init__(self, records=None, namespaces=None): + def __init__( + self, + records: Optional[Iterable[ProvRecord]] = None, + namespaces: Optional[NSCollection] = None, + ): """ Constructor. @@ -2273,12 +2515,12 @@ def __init__(self, records=None, namespaces=None): ProvBundle.__init__( self, records=records, identifier=None, namespaces=namespaces ) - self._bundles = dict() + self._bundles = dict() # type: dict[QualifiedName, ProvBundle] - def __repr__(self): + def __repr__(self) -> str: return "" - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: if not isinstance(other, ProvDocument): return False # Comparing the documents' content @@ -2296,7 +2538,7 @@ def __eq__(self, other): # Everything is the same return True - def is_document(self): + def is_document(self) -> bool: """ `True` if the object is a document, `False` otherwise. @@ -2304,7 +2546,7 @@ def is_document(self): """ return True - def is_bundle(self): + def is_bundle(self) -> bool: """ `True` if the object is a bundle, `False` otherwise. @@ -2312,7 +2554,7 @@ def is_bundle(self): """ return False - def has_bundles(self): + def has_bundles(self) -> bool: """ `True` if the object has at least one bundle, `False` otherwise. @@ -2321,7 +2563,7 @@ def has_bundles(self): return len(self._bundles) > 0 @property - def bundles(self): + def bundles(self) -> Iterable[ProvBundle]: """ Returns bundles contained in the document @@ -2330,7 +2572,7 @@ def bundles(self): return self._bundles.values() # Transformations - def flattened(self): + def flattened(self) -> ProvDocument: """ Flattens the document by moving all the records in its bundles up to the document level. @@ -2350,9 +2592,9 @@ def flattened(self): # returning the same document return self - def unified(self): + def unified(self) -> ProvDocument: """ - Returns a new document containing all records having same identifiers + Returns a new document containing all records having the same identifiers unified (including those inside bundles). :return: :py:class:`ProvDocument` @@ -2364,10 +2606,10 @@ def unified(self): document.add_bundle(unified_bundle) return document - def update(self, other): + def update(self, other: ProvBundle) -> None: """ Append all the records of the *other* document/bundle into this document. - Bundles having same identifiers will be merged. + Bundles having the same identifiers will be merged. :param other: The other document/bundle whose records to be appended. :type other: :py:class:`ProvDocument` or :py:class:`ProvBundle` @@ -2378,10 +2620,12 @@ def update(self, other): self.add_record(record) if other.has_bundles(): for bundle in other.bundles: + bundle_id = bundle.identifier + assert bundle_id is not None if bundle.identifier in self._bundles: self._bundles[bundle.identifier].update(bundle) else: - new_bundle = self.bundle(bundle.identifier) + new_bundle = self.bundle(bundle_id) new_bundle.update(bundle) else: raise ProvException( @@ -2390,7 +2634,9 @@ def update(self, other): ) # Bundle operations - def add_bundle(self, bundle, identifier=None): + def add_bundle( + self, bundle: ProvBundle, identifier: Optional[QualifiedName] = None + ) -> None: """ Add a bundle to the current document. @@ -2425,7 +2671,7 @@ def add_bundle(self, bundle, identifier=None): # Link the bundle namespace manager to the document's bundle._namespaces.parent = self._namespaces - valid_id = bundle.valid_qualified_name(identifier) + valid_id = bundle.mandatory_valid_qname(identifier) # IMPORTANT: Rewriting the bundle identifier for consistency bundle._identifier = valid_id @@ -2435,7 +2681,7 @@ def add_bundle(self, bundle, identifier=None): self._bundles[valid_id] = bundle bundle._document = self - def bundle(self, identifier): + def bundle(self, identifier: QualifiedNameCandidate) -> ProvBundle: """ Returns a new bundle from the current document. @@ -2458,7 +2704,12 @@ def bundle(self, identifier): return b # Serializing and deserializing - def serialize(self, destination=None, format="json", **args): + def serialize( + self, + destination: Optional[io.IOBase | PathLike] = None, + format: str = "json", + **args: Any, + ) -> str | None: """ Serialize the :py:class:`ProvDocument` to the destination. @@ -2475,20 +2726,21 @@ def serialize(self, destination=None, format="json", **args): """ serializer = serializers.get(format)(self) if destination is None: - stream = io.StringIO() - serializer.serialize(stream, **args) - return stream.getvalue() - if hasattr(destination, "write"): + buffer = io.StringIO() + serializer.serialize(buffer, **args) + return buffer.getvalue() + + if isinstance(destination, IOBase): stream = destination serializer.serialize(stream, **args) else: - location = destination + location = str(destination) scheme, netloc, path, params, _query, fragment = urlparse(location) if netloc != "": print( "WARNING: not saving as location " + "is not a local file reference" ) - return + return None fd, name = tempfile.mkstemp() stream = os.fdopen(fd, "wb") serializer.serialize(stream, **args) @@ -2498,9 +2750,15 @@ def serialize(self, destination=None, format="json", **args): else: shutil.copy(name, path) os.remove(name) + return None @staticmethod - def deserialize(source=None, content=None, format="json", **args): + def deserialize( + source: Optional[io.IOBase | PathLike] = None, + content: Optional[str | bytes] = None, + format: str = "json", + **args: Any, + ) -> ProvDocument: """ Deserialize the :py:class:`ProvDocument` from source (a stream or a file path) or directly from a string content. @@ -2529,14 +2787,18 @@ def deserialize(source=None, content=None, format="json", **args): return serializer.deserialize(stream, **args) if source is not None: - if hasattr(source, "read"): + if isinstance(source, io.IOBase): return serializer.deserialize(source, **args) else: with open(source) as f: return serializer.deserialize(f, **args) + raise TypeError("Either source or content must be provided") + -def sorted_attributes(element, attributes): +def sorted_attributes( + element: QualifiedName, attributes: Iterable[NameValuePair] +) -> list[NameValuePair]: """ Helper function sorting attributes into the order required by PROV-XML. @@ -2555,8 +2817,8 @@ def sorted_attributes(element, attributes): # sorting. We now interpret it as sorting by tag including the prefix # first and then sorting by the text, also including the namespace # prefix if given. - def sort_fct(x): - return (str(x[0]), str(x[1].value if hasattr(x[1], "value") else x[1])) + def sort_fct(x: NameValuePair) -> tuple[str, str]: + return str(x[0]), str(x[1].value if hasattr(x[1], "value") else x[1]) sorted_elements = [] for item in order: diff --git a/src/prov/serializers/__init__.py b/src/prov/serializers/__init__.py index 043fa69..573a8cd 100644 --- a/src/prov/serializers/__init__.py +++ b/src/prov/serializers/__init__.py @@ -37,7 +37,7 @@ def serialize(self, stream: io.IOBase, **args: Any) -> None: pass @abstractmethod - def deserialize(self, stream: io.IOBase, **args: Any) -> ProvDocument | None: + def deserialize(self, stream: io.IOBase, **args: Any) -> ProvDocument: """ Abstract method for deserializing. diff --git a/src/prov/tests/test_extras.py b/src/prov/tests/test_extras.py index 3c9c0bf..e6c85a1 100644 --- a/src/prov/tests/test_extras.py +++ b/src/prov/tests/test_extras.py @@ -135,7 +135,7 @@ def test_extra_attributes(self): add_further_attributes(inf) self.assertEqual( - len(inf.attributes), len(list(inf.formal_attributes) + inf.extra_attributes) + len(inf.attributes), len(list(inf.formal_attributes) + list(inf.extra_attributes)) ) def test_serialize_to_path(self): @@ -191,13 +191,10 @@ def test_bundle_is_bundle(self): def test_bundle_get_record_by_id(self): document = ProvDocument() - self.assertEqual(document.get_record(None), None) + self.assertEqual(0, len(document.get_record("nonexistentid"))) - # record = document.entity(identifier=EX_NS['e1']) - # self.assertEqual(document.get_record(EX_NS['e1']), record) - # - # bundle = document.bundle(EX_NS['b']) - # self.assertEqual(bundle.get_record(EX_NS['e1']), record) + record = document.entity(identifier=EX_NS['e1']) + self.assertEqual(record, document.get_record(EX_NS['e1'])[0]) def test_bundle_get_records(self): document = ProvDocument() From ce03334d8f720244f59a86876ec53a84e803c922 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Mon, 9 Jun 2025 21:39:15 +0100 Subject: [PATCH 11/20] Add type annotations in `provn.py` #143 --- src/prov/serializers/provn.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/prov/serializers/provn.py b/src/prov/serializers/provn.py index bfd3c32..e423cd7 100644 --- a/src/prov/serializers/provn.py +++ b/src/prov/serializers/provn.py @@ -2,24 +2,31 @@ __email__ = "trungdong@donggiang.com" import io +from typing import Any +from prov.model import ProvDocument from prov.serializers import Serializer class ProvNSerializer(Serializer): """PROV-N serializer for ProvDocument""" - def serialize(self, stream, **kwargs): + def serialize(self, stream: io.IOBase, **args: Any) -> None: """ Serializes a :class:`prov.model.ProvDocument` instance to a `PROV-N `_. :param stream: Where to save the output. """ + if self.document is None: + raise Exception("No document to serialize") + provn_content = self.document.get_provn() - if not isinstance(stream, io.TextIOBase): - provn_content = provn_content.encode("utf-8") - stream.write(provn_content) + stream.write( + provn_content + if isinstance(stream, io.TextIOBase) + else provn_content.encode("utf-8") + ) - def deserialize(self, stream, **kwargs): + def deserialize(self, stream: io.IOBase, **args: Any) -> ProvDocument: raise NotImplementedError From 6275cb3dcab9c8774bc8755ce8d1b3d1e31ef297 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Mon, 9 Jun 2025 22:01:50 +0100 Subject: [PATCH 12/20] Add type annotations to `dot.py` and `graph.py` #143 --- src/prov/dot.py | 61 ++++++++++++++++++++++++++--------------------- src/prov/graph.py | 6 ++--- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/src/prov/dot.py b/src/prov/dot.py index 0ec44e6..459d040 100644 --- a/src/prov/dot.py +++ b/src/prov/dot.py @@ -11,9 +11,14 @@ .. moduleauthor:: Trung Dong Huynh """ + +from __future__ import annotations # needed for | type annotations in Python < 3.10 from datetime import datetime +from html import escape +from typing import Any, Optional from prov.graph import INFERRED_ELEMENT_CLASS +from prov.identifier import QualifiedName from prov.model import ( ProvEntity, ProvActivity, @@ -42,13 +47,10 @@ PROV_ATTRIBUTE_QNAMES, sorted_attributes, ProvException, + ProvRecord, + ProvElement, ) -import pydot - -try: - from html import escape -except ImportError: - from cgi import escape +import pydot # type: ignore[import] __author__ = "Trung Dong Huynh" __email__ = "trungdong@donggiang.com" @@ -87,7 +89,7 @@ "fillcolor": "lightgray", "color": "dimgray", }, -} +} # type: dict[Optional[type[ProvElement | ProvBundle]], dict[str, str]] DOT_PROV_STYLE = { # Generic node 0: { @@ -166,7 +168,7 @@ ANNOTATION_END_ROW = " >" -def htlm_link_if_uri(value): +def htlm_link_if_uri(value: Any) -> str: try: uri = value.uri return '%s' % (uri, str(value)) @@ -175,13 +177,13 @@ def htlm_link_if_uri(value): def prov_to_dot( - bundle, - show_nary=True, - use_labels=False, - direction="BT", - show_element_attributes=True, - show_relation_attributes=True, -): + bundle: ProvBundle, + show_nary: bool = True, + use_labels: bool = False, + direction: str = "BT", + show_element_attributes: bool = True, + show_relation_attributes: bool = True, +) -> pydot.Dot: """ Convert a provenance bundle/document into a DOT graphical representation. @@ -203,11 +205,11 @@ def prov_to_dot( direction = "BT" # reset it to the default value maindot = pydot.Dot(graph_type="digraph", rankdir=direction, charset="utf-8") - node_map = {} + node_map = {} # type: dict[str, pydot.Node] count = [0, 0, 0, 0] # counters for node ids - def _bundle_to_dot(dot, bundle): - def _attach_attribute_annotation(node, record): + def _bundle_to_dot(dot: pydot.Dot | pydot.Cluster, bundle: ProvBundle) -> None: + def _attach_attribute_annotation(node: pydot.Node, record: ProvRecord) -> None: # Adding a node to show all attributes attributes = list( (attr_name, value) @@ -244,17 +246,17 @@ def _attach_attribute_annotation(node, record): dot.add_node(annotations) dot.add_edge(pydot.Edge(annotations, node, **ANNOTATION_LINK_STYLE)) - def _add_bundle(bundle): + def _add_bundle(bundle: ProvBundle) -> pydot.Cluster: count[2] += 1 subdot = pydot.Cluster( - graph_name="c%d" % count[2], URL=f'"{bundle.identifier.uri}"' + graph_name="c%d" % count[2], URL=f'"{bundle.identifier.uri}"' # type: ignore[union-attr] ) subdot.set_label('"%s"' % str(bundle.identifier)) _bundle_to_dot(subdot, bundle) dot.add_subgraph(subdot) return subdot - def _add_node(record): + def _add_node(record: ProvRecord) -> pydot.Node: count[0] += 1 node_id = "n%d" % count[0] if use_labels: @@ -267,12 +269,12 @@ def _add_node(record): node_label = ( f"<{record.label}
" f'' - f'{record.identifier}>' + f"{record.identifier}>" ) else: node_label = f'"{record.identifier}"' - uri = record.identifier.uri + uri = record.identifier.uri # type: ignore[union-attr] style = DOT_PROV_STYLE[record.get_type()] node = pydot.Node(node_id, label=node_label, URL='"%s"' % uri, **style) node_map[uri] = node @@ -282,7 +284,9 @@ def _add_node(record): _attach_attribute_annotation(node, rec) return node - def _add_generic_node(qname, prov_type=None): + def _add_generic_node( + qname: QualifiedName, prov_type: Optional[type[ProvElement]] = None + ) -> pydot.Node: count[0] += 1 node_id = "n%d" % count[0] node_label = f'"{qname}"' @@ -294,14 +298,17 @@ def _add_generic_node(qname, prov_type=None): dot.add_node(node) return node - def _get_bnode(): + def _get_bnode() -> pydot.Node: count[1] += 1 bnode_id = "b%d" % count[1] bnode = pydot.Node(bnode_id, label='""', shape="point", color="gray") dot.add_node(bnode) return bnode - def _get_node(qname, prov_type=None): + def _get_node( + qname: Optional[QualifiedName], + prov_type: Optional[type[ProvElement]] = None, + ) -> pydot.Node: if qname is None: return _get_bnode() uri = qname.uri @@ -382,7 +389,7 @@ def _get_node(qname, prov_type=None): if add_attribute_annotation: _attach_attribute_annotation(bnode, rec) else: - # show a simple binary relations with no annotation + # show a simple binary relation with no annotation dot.add_edge( pydot.Edge( _get_node(nodes[0], inferred_types[0]), diff --git a/src/prov/graph.py b/src/prov/graph.py index 2c1a5af..2e7452f 100644 --- a/src/prov/graph.py +++ b/src/prov/graph.py @@ -56,7 +56,7 @@ } -def prov_to_graph(prov_document): +def prov_to_graph(prov_document: ProvDocument) -> nx.MultiDiGraph: """ Convert a :class:`~prov.model.ProvDocument` to a `MultiDiGraph `_ @@ -64,7 +64,7 @@ def prov_to_graph(prov_document): :param prov_document: The :class:`~prov.model.ProvDocument` instance to convert. """ - g = nx.MultiDiGraph() + g = nx.MultiDiGraph() # type: nx.MultiDiGraph unified = prov_document.unified() node_map = dict() for element in unified.get_records(ProvElement): @@ -89,7 +89,7 @@ def prov_to_graph(prov_document): return g -def graph_to_prov(g): +def graph_to_prov(g: nx.MultiDiGraph) -> ProvDocument: """ Convert a `MultiDiGraph `_ From 495d97dfad52a0f657286cf7946c391344957977 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Mon, 9 Jun 2025 22:09:25 +0100 Subject: [PATCH 13/20] Add type annotations to `convert.py` and `compare.py` #143 --- src/prov/scripts/compare.py | 11 ++++++----- src/prov/scripts/convert.py | 14 ++++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/prov/scripts/compare.py b/src/prov/scripts/compare.py index 19b633e..b24ff4e 100755 --- a/src/prov/scripts/compare.py +++ b/src/prov/scripts/compare.py @@ -18,13 +18,14 @@ import sys import logging import traceback +from typing import Optional from prov.model import ProvDocument logger = logging.getLogger(__name__) -__all__ = [] +__all__ = [] # type: ignore __version__ = 0.1 __date__ = "2015-06-16" __updated__ = "2025-06-07" @@ -37,15 +38,15 @@ class CLIError(Exception): """Generic exception to raise and log different fatal errors.""" - def __init__(self, msg): - super(CLIError).__init__(type(self)) + def __init__(self, msg: str): + super(CLIError, self).__init__(type(self)) self.msg = "E: %s" % msg - def __str__(self): + def __str__(self) -> str: return self.msg -def main(argv=None): # IGNORE:C0111 +def main(argv: Optional[list] = None) -> int: # IGNORE:C0111 """Command line options.""" if argv is None: diff --git a/src/prov/scripts/convert.py b/src/prov/scripts/convert.py index c4d5183..c53cebc 100755 --- a/src/prov/scripts/convert.py +++ b/src/prov/scripts/convert.py @@ -14,10 +14,12 @@ """ from argparse import ArgumentParser, RawDescriptionHelpFormatter, FileType +import io import os import sys import logging import traceback +from typing import Optional from prov.model import ProvDocument from prov import serializers @@ -25,7 +27,7 @@ logger = logging.getLogger(__name__) -__all__ = [] +__all__ = [] # type: ignore __version__ = 0.1 __date__ = "2014-03-14" __updated__ = "2025-06-07" @@ -74,15 +76,15 @@ class CLIError(Exception): """Generic exception to raise and log different fatal errors.""" - def __init__(self, msg): - super(CLIError).__init__(type(self)) + def __init__(self, msg: str): + super(CLIError, self).__init__(type(self)) self.msg = "E: %s" % msg - def __str__(self): + def __str__(self) -> str: return self.msg -def convert_file(infile, outfile, output_format): +def convert_file(infile: io.FileIO, outfile: io.FileIO, output_format: str) -> None: prov_doc = ProvDocument.deserialize(infile) # Formats not supported by prov.serializers @@ -102,7 +104,7 @@ def convert_file(infile, outfile, output_format): raise CLIError('Output format "%s" is not supported.' % output_format) -def main(argv=None): # IGNORE:C0111 +def main(argv: Optional[list] = None) -> int: # IGNORE:C0111 """Command line options.""" if argv is None: From 8441a6a7f5d78a431fd6b845d273d6e36e9afe50 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Mon, 9 Jun 2025 22:48:28 +0100 Subject: [PATCH 14/20] Add type annotations to `provjson.py` #143 --- src/prov/serializers/provjson.py | 77 ++++++++++++++++++-------------- 1 file changed, 43 insertions(+), 34 deletions(-) diff --git a/src/prov/serializers/provjson.py b/src/prov/serializers/provjson.py index 67a17b7..eb00a05 100644 --- a/src/prov/serializers/provjson.py +++ b/src/prov/serializers/provjson.py @@ -1,7 +1,9 @@ +from __future__ import annotations # needed for | type annotations in Python < 3.10 from collections import defaultdict import datetime import io import json +from typing import Any, Optional from prov import Error from prov.serializers import Serializer @@ -15,6 +17,8 @@ ProvBundle, first, parse_xsd_datetime, + ProvRecord, + QualifiedNameCandidate, ) import logging @@ -25,16 +29,19 @@ __email__ = "trungdong@donggiang.com" +ProvJSONDict = dict[str, dict[str, Any]] + + class ProvJSONException(Error): pass class AnonymousIDGenerator: - def __init__(self): - self._cache = {} - self._count = 0 + def __init__(self) -> None: + self._cache = {} # type: dict[ProvRecord, Identifier] + self._count = 0 # type: int - def get_anon_id(self, obj, local_prefix="id"): + def get_anon_id(self, obj: ProvRecord, local_prefix: str = "id") -> Identifier: if obj not in self._cache: self._count += 1 self._cache[obj] = Identifier("_:%s%d" % (local_prefix, self._count)) @@ -44,7 +51,7 @@ def get_anon_id(self, obj, local_prefix="id"): # Reverse map for prov.model.XSD_DATATYPE_PARSERS LITERAL_XSDTYPE_MAP = { float: "xsd:double", - int: "xsd:int" + int: "xsd:int", # boolean, string values are supported natively by PROV-JSON # datetime values are converted separately } @@ -55,7 +62,7 @@ class ProvJSONSerializer(Serializer): PROV-JSON serializer for :class:`~prov.model.ProvDocument` """ - def serialize(self, stream, **kwargs): + def serialize(self, stream: io.IOBase, **args: Any) -> None: """ Serializes a :class:`~prov.model.ProvDocument` instance to `PROV-JSON `_. @@ -64,10 +71,10 @@ def serialize(self, stream, **kwargs): """ buf = io.StringIO() try: - json.dump(self.document, buf, cls=ProvJSONEncoder, **kwargs) + json.dump(self.document, buf, cls=ProvJSONEncoder, **args) buf.seek(0, 0) # Right now this is a bytestream. If the object to stream to is - # a text object is must be decoded. We assume utf-8 here which + # a text object, it must be decoded. We assume utf-8 here, which # should be fine for almost every case. if isinstance(stream, io.TextIOBase): stream.write(buf.read()) @@ -76,7 +83,7 @@ def serialize(self, stream, **kwargs): finally: buf.close() - def deserialize(self, stream, **kwargs): + def deserialize(self, stream: io.IOBase, **args: Any) -> ProvDocument: """ Deserialize from the `PROV JSON `_ representation to a @@ -87,11 +94,11 @@ def deserialize(self, stream, **kwargs): if not isinstance(stream, io.TextIOBase): buf = io.StringIO(stream.read().decode("utf-8")) stream = buf - return json.load(stream, cls=ProvJSONDecoder, **kwargs) + return json.load(stream, cls=ProvJSONDecoder, **args) class ProvJSONEncoder(json.JSONEncoder): - def default(self, o): + def default(self, o: Any) -> Any: if isinstance(o, ProvDocument): return encode_json_document(o) else: @@ -99,7 +106,7 @@ def default(self, o): class ProvJSONDecoder(json.JSONDecoder): - def decode(self, s, *args, **kwargs): + def decode(self, s: str, *args: Any, **kwargs: Any) -> Any: container = super(ProvJSONDecoder, self).decode(s, *args, **kwargs) document = ProvDocument() decode_json_document(container, document) @@ -107,14 +114,16 @@ def decode(self, s, *args, **kwargs): # Encoding/decoding functions -def valid_qualified_name(bundle, value): +def valid_qualified_name( + bundle: ProvBundle, value: Optional[QualifiedNameCandidate] +) -> QualifiedName | None: if value is None: return None qualified_name = bundle.valid_qualified_name(value) return qualified_name -def encode_json_document(document): +def encode_json_document(document: ProvDocument) -> ProvJSONDict: container = encode_json_container(document) for bundle in document.bundles: # encoding the sub-bundle @@ -123,9 +132,9 @@ def encode_json_document(document): return container -def encode_json_container(bundle): - container = defaultdict(dict) - prefixes = {} +def encode_json_container(bundle: ProvBundle) -> ProvJSONDict: + container = defaultdict(dict) # type: dict[str, dict] + prefixes = {} # type: dict[str, str] for namespace in bundle._namespaces.get_registered_namespaces(): prefixes[namespace.prefix] = namespace.uri if bundle._namespaces._default: @@ -135,7 +144,7 @@ def encode_json_container(bundle): id_generator = AnonymousIDGenerator() - def real_or_anon_id(r): + def real_or_anon_id(r: ProvRecord) -> Identifier: return r._identifier if r._identifier else id_generator.get_anon_id(r) for record in bundle._records: @@ -143,9 +152,9 @@ def real_or_anon_id(r): rec_label = PROV_N_MAP[rec_type] identifier = str(real_or_anon_id(record)) - record_json = {} + record_json = {} # type: dict[str, Any] if record._attributes: - for (attr, values) in record._attributes.items(): + for attr, values in record._attributes.items(): if not values: continue attr_name = str(attr) @@ -153,7 +162,7 @@ def real_or_anon_id(r): # TODO: QName export record_json[attr_name] = str(first(values)) elif attr in PROV_ATTRIBUTE_LITERALS: - record_json[attr_name] = first(values).isoformat() + record_json[attr_name] = first(values).isoformat() # type: ignore[union-attr] else: if len(values) == 1: # single value @@ -182,7 +191,7 @@ def real_or_anon_id(r): return container -def decode_json_document(content, document): +def decode_json_document(content: ProvJSONDict, document: ProvDocument) -> None: bundles = dict() if "bundle" in content: bundles = content["bundle"] @@ -196,12 +205,12 @@ def decode_json_document(content, document): document.add_bundle(bundle, bundle.valid_qualified_name(bundle_id)) -def decode_json_container(jc, bundle): +def decode_json_container(jc: ProvJSONDict, bundle: ProvBundle) -> None: if "prefix" in jc: prefixes = jc["prefix"] for prefix, uri in prefixes.items(): if prefix != "default": - bundle.add_namespace(Namespace(prefix, uri)) + bundle.add_namespace(Namespace(prefix, uri)) # type: ignore else: bundle.set_default_namespace(uri) del jc["prefix"] @@ -217,16 +226,16 @@ def decode_json_container(jc, bundle): elements = content for element in elements: - attributes = dict() - other_attributes = [] + attributes = dict() # type: dict[QualifiedNameCandidate, Any] + other_attributes = [] # type: list[tuple[QualifiedNameCandidate, Any]] # this is for the multiple-entity membership hack to come membership_extra_members = None for attr_name, values in element.items(): attr = ( PROV_ATTRIBUTES_ID_MAP[attr_name] if attr_name in PROV_ATTRIBUTES_ID_MAP - else valid_qualified_name(bundle, attr_name) - ) + else bundle.mandatory_valid_qname(attr_name) + ) # type: QualifiedName if attr in PROV_ATTRIBUTES: if isinstance(values, list): # only one value is allowed @@ -280,11 +289,11 @@ def decode_json_container(jc, bundle): collection = attributes[PROV_ATTR_COLLECTION] for member in membership_extra_members: bundle.membership( - collection, valid_qualified_name(bundle, member) + collection, bundle.mandatory_valid_qname(member) ) -def encode_json_representation(value): +def encode_json_representation(value: Any) -> Any: if isinstance(value, Literal): return literal_json_representation(value) elif isinstance(value, datetime.datetime): @@ -301,12 +310,12 @@ def encode_json_representation(value): return value -def decode_json_representation(literal, bundle): +def decode_json_representation(literal: Any, bundle: ProvBundle) -> Any: if isinstance(literal, dict): # complex type value = literal["$"] - datatype = literal["type"] if "type" in literal else None - datatype = valid_qualified_name(bundle, datatype) + datatype_str = literal["type"] if "type" in literal else None # type: Optional[str] + datatype = valid_qualified_name(bundle, datatype_str) langtag = literal["lang"] if "lang" in literal else None if datatype == XSD_ANYURI: return Identifier(value) @@ -322,7 +331,7 @@ def decode_json_representation(literal, bundle): return literal -def literal_json_representation(literal): +def literal_json_representation(literal: Literal) -> dict[str, str]: # TODO: QName export value, datatype, langtag = literal.value, literal.datatype, literal.langtag if langtag: From 03e76ed2dcab3f545e86392a1ee0378a0241be36 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Tue, 10 Jun 2025 11:35:03 +0100 Subject: [PATCH 15/20] Add type annotations and improve typing consistency in `provrdf.py` #143 --- src/prov/serializers/provrdf.py | 285 ++++++++++++++++++-------------- 1 file changed, 164 insertions(+), 121 deletions(-) diff --git a/src/prov/serializers/provrdf.py b/src/prov/serializers/provrdf.py index 0ff162d..2938b0a 100644 --- a/src/prov/serializers/provrdf.py +++ b/src/prov/serializers/provrdf.py @@ -1,9 +1,12 @@ -"""PROV-RDF serializers for ProvDocument -""" +"""PROV-RDF serializers for ProvDocument""" + +from __future__ import annotations # needed for | type annotations in Python < 3.10 import base64 from collections import OrderedDict import datetime import io +from typing import Any, Optional, Generator +import warnings import dateutil.parser @@ -44,6 +47,7 @@ PROV_ATTR_USED_ENTITY, PROV_ASSOCIATION, ) +from prov.identifier import QualifiedName from prov.serializers import Serializer @@ -56,14 +60,14 @@ class ProvRDFException(Error): class AnonymousIDGenerator: - def __init__(self): - self._cache = {} - self._count = 0 + def __init__(self) -> None: + self._cache = {} # type: dict[Any, str] + self._count = 0 # type: int - def get_anon_id(self, obj, local_prefix="id"): + def get_anon_id(self, obj: pm.ProvRecord, local_prefix: str = "id") -> str: if obj not in self._cache: self._count += 1 - self._cache[obj] = pm.Identifier("_:%s%d" % (local_prefix, self._count)).uri + self._cache[obj] = "_:%s%d" % (local_prefix, self._count) return self._cache[obj] @@ -76,7 +80,7 @@ def get_anon_id(self, obj, local_prefix="id"): # datetime values are converted separately } -relation_mapper = { +RELATION_MAP = { URIRef(PROV["alternateOf"].uri): "alternate", URIRef(PROV["actedOnBehalfOf"].uri): "delegation", URIRef(PROV["specializationOf"].uri): "specialization", @@ -93,7 +97,7 @@ def get_anon_id(self, obj, local_prefix="id"): URIRef(PROV["hadMember"].uri): "membership", URIRef(PROV["used"].uri): "usage", } -predicate_mapper = { +PREDICATE_MAP = { RDFS.label: pm.PROV["label"], URIRef(PROV["atLocation"].uri): PROV_LOCATION, URIRef(PROV["startedAtTime"].uri): PROV_ATTR_STARTTIME, @@ -107,25 +111,22 @@ def get_anon_id(self, obj, local_prefix="id"): } -def attr2rdf(attr): +def attr2rdf(attr: QualifiedName) -> URIRef: return URIRef(PROV[PROV_ID_ATTRIBUTES_MAP[attr].split("prov:")[1]].uri) -def valid_qualified_name(bundle, value, xsd_qname=False): - if value is None: - return None - qualified_name = bundle.valid_qualified_name(value) - return qualified_name if not xsd_qname else XSD_QNAME(qualified_name) - - class ProvRDFSerializer(Serializer): """ PROV-O serializer for :class:`~prov.model.ProvDocument` """ def serialize( - self, stream=None, rdf_format="trig", PROV_N_MAP=PROV_N_MAP, **kwargs - ): + self, + stream: io.IOBase, + rdf_format: str = "trig", + PROV_N_MAP: dict[pm.QualifiedName, str] = PROV_N_MAP, + **kwargs: Any, + ) -> None: """ Serializes a :class:`~prov.model.ProvDocument` instance to `PROV-O `_. @@ -133,6 +134,9 @@ def serialize( :param stream: Where to save the output. :param rdf_format: The RDF format of the output, default to TRiG. """ + if self.document is None: + raise ProvRDFException("No document to serialize.") + container = self.encode_document(self.document, PROV_N_MAP=PROV_N_MAP) newargs = kwargs.copy() newargs["format"] = rdf_format @@ -153,12 +157,12 @@ def serialize( def deserialize( self, - stream, - rdf_format="trig", - relation_mapper=relation_mapper, - predicate_mapper=predicate_mapper, - **kwargs, - ): + stream: io.IOBase, + rdf_format: str = "trig", + relation_mapper: dict[URIRef, str] = RELATION_MAP, + predicate_mapper: dict[URIRef, pm.QualifiedName] = PREDICATE_MAP, + **kwargs: Any, + ) -> pm.ProvDocument: """ Deserialize from the `PROV-O `_ representation to a :class:`~prov.model.ProvDocument` instance. @@ -170,20 +174,21 @@ def deserialize( newargs["format"] = rdf_format container = ConjunctiveGraph() container.parse(stream, **newargs) - document = pm.ProvDocument() - self.document = document + self.document = pm.ProvDocument() self.decode_document( container, - document, + self.document, relation_mapper=relation_mapper, predicate_mapper=predicate_mapper, ) - return document + return self.document - def valid_identifier(self, value): - return self.document.valid_qualified_name(value) + def valid_identifier( + self, value: pm.QualifiedNameCandidate + ) -> pm.QualifiedName | None: + return self.document.valid_qualified_name(value) # type: ignore[union-attr] - def encode_rdf_representation(self, value): + def encode_rdf_representation(self, value: Any) -> RDFLiteral | URIRef: if isinstance(value, URIRef): return value elif isinstance(value, pm.Literal): @@ -199,11 +204,11 @@ def encode_rdf_representation(self, value): else: return RDFLiteral(value) - def decode_rdf_representation(self, literal, graph): + def decode_rdf_representation(self, literal: Any, graph: ConjunctiveGraph) -> Any: if isinstance(literal, RDFLiteral): value = literal.value if literal.value is not None else literal - datatype = literal.datatype if hasattr(literal, "datatype") else None - langtag = literal.language if hasattr(literal, "language") else None + datatype = literal.datatype + langtag = literal.language if datatype and "XMLLiteral" in datatype: value = literal if datatype and "base64Binary" in datatype: @@ -232,26 +237,34 @@ def decode_rdf_representation(self, literal, graph): rval = self.valid_identifier(literal) if rval is None: prefix, iri, _ = graph.namespace_manager.compute_qname(literal) - ns = self.document.add_namespace(prefix, iri) + ns = self.document.add_namespace(prefix, iri) # type: ignore[union-attr] rval = pm.QualifiedName(ns, literal.replace(ns.uri, "")) return rval else: # simple type, just return it return literal - def encode_document(self, document, PROV_N_MAP=PROV_N_MAP): + def encode_document( + self, + document: pm.ProvDocument, + PROV_N_MAP: dict[pm.QualifiedName, str] = PROV_N_MAP, + ) -> ConjunctiveGraph: container = self.encode_container(document) for item in document.bundles: # encoding the sub-bundle bundle = self.encode_container( - item, identifier=item.identifier.uri, PROV_N_MAP=PROV_N_MAP + item, identifier=item.identifier.uri, PROV_N_MAP=PROV_N_MAP # type: ignore[union-attr] ) container.addN(bundle.quads()) return container def encode_container( - self, bundle, PROV_N_MAP=PROV_N_MAP, container=None, identifier=None - ): + self, + bundle: pm.ProvBundle, + PROV_N_MAP: dict[pm.QualifiedName, str] = PROV_N_MAP, + container: Optional[ConjunctiveGraph] = None, + identifier: Optional[str] = None, + ) -> ConjunctiveGraph: if container is None: container = ConjunctiveGraph(identifier=identifier) nm = container.namespace_manager @@ -261,8 +274,8 @@ def encode_container( container.bind(namespace.prefix, namespace.uri) id_generator = AnonymousIDGenerator() - real_or_anon_id = ( - lambda record: record._identifier.uri + real_or_anon_id = lambda record: ( + record._identifier.uri if record._identifier else id_generator.get_anon_id(record) ) @@ -482,11 +495,11 @@ def encode_container( def decode_document( self, - content, - document, - relation_mapper=relation_mapper, - predicate_mapper=predicate_mapper, - ): + content: ConjunctiveGraph, + document: pm.ProvDocument, + relation_mapper: dict[URIRef, str] = RELATION_MAP, + predicate_mapper: dict[URIRef, pm.QualifiedName] = PREDICATE_MAP, + ) -> None: for prefix, url in content.namespaces(): document.add_namespace(prefix, str(url)) if hasattr(content, "contexts"): @@ -517,25 +530,34 @@ def decode_document( def decode_container( self, - graph, - bundle, - relation_mapper=relation_mapper, - predicate_mapper=predicate_mapper, - ): - ids = {} - PROV_CLS_MAP = {} - formal_attributes = {} - unique_sets = {} - for key, val in PROV_BASE_CLS.items(): - PROV_CLS_MAP[key.uri] = PROV_BASE_CLS[key] - other_attributes = {} + graph: ConjunctiveGraph, + bundle: pm.ProvBundle, + relation_mapper: dict[URIRef, str] = RELATION_MAP, + predicate_mapper: dict[URIRef, pm.QualifiedName] = PREDICATE_MAP, + ) -> None: + record_types = {} # type: dict[str, pm.QualifiedName] + PROV_CLS_MAP = {} # type: dict[str, pm.QualifiedName] + formal_attributes = ( + {} + ) # type: dict[str, dict[pm.QualifiedName, Optional[pm.QualifiedNameCandidate | datetime.datetime]]] + unique_sets = ( + {} + ) # type: dict[str, dict[pm.QualifiedName, list[pm.QualifiedNameCandidate | datetime.datetime]]] + for prov_type, _ in PROV_BASE_CLS.items(): + PROV_CLS_MAP[prov_type.uri] = PROV_BASE_CLS[prov_type] + other_attributes = ( + {} + ) # type: dict[str, list[tuple[pm.QualifiedNameCandidate, Any]]] for stmt in graph.triples((None, RDF.type, None)): - id = str(stmt[0]) + subj = str(stmt[0]) obj = str(stmt[2]) if obj in PROV_CLS_MAP: - if not isinstance(stmt[0], BNode) and self.valid_identifier(id) is None: - prefix, iri, _ = graph.namespace_manager.compute_qname(id) - self.document.add_namespace(prefix, iri) + if ( + not isinstance(stmt[0], BNode) + and self.valid_identifier(subj) is None + ): + prefix, iri, _ = graph.namespace_manager.compute_qname(subj) + self.document.add_namespace(prefix, iri) # type: ignore[union-attr] try: prov_obj = PROV_CLS_MAP[obj] except AttributeError: @@ -547,7 +569,7 @@ def decode_container( or pm.PROV["PrimarySource"].uri in stmt[2] ) if ( - id not in ids + subj not in record_types and prov_obj and ( prov_obj.uri == obj @@ -555,12 +577,12 @@ def decode_container( or isinstance(stmt[0], BNode) ) ): - ids[id] = prov_obj + record_types[subj] = prov_obj klass = pm.PROV_REC_CLS[prov_obj] - formal_attributes[id] = OrderedDict( + formal_attributes[subj] = OrderedDict( [(key, None) for key in klass.FORMAL_ATTRIBUTES] ) - unique_sets[id] = OrderedDict( + unique_sets[subj] = OrderedDict( [(key, []) for key in klass.FORMAL_ATTRIBUTES] ) add_attr = False or ( @@ -568,31 +590,33 @@ def decode_container( and prov_obj.uri != obj ) if add_attr: - if id not in other_attributes: - other_attributes[id] = [] + if subj not in other_attributes: + other_attributes[subj] = [] obj_formatted = self.decode_rdf_representation(stmt[2], graph) - other_attributes[id].append((pm.PROV["type"], obj_formatted)) + other_attributes[subj].append((pm.PROV["type"], obj_formatted)) else: - if id not in other_attributes: - other_attributes[id] = [] + if subj not in other_attributes: + other_attributes[subj] = [] obj = self.decode_rdf_representation(stmt[2], graph) - other_attributes[id].append((pm.PROV["type"], obj)) - for id, pred, obj in graph: - id = str(id) - if id not in other_attributes: - other_attributes[id] = [] + other_attributes[subj].append((pm.PROV["type"], obj)) + for subj, pred, obj in graph: + subj = str(subj) + if subj not in other_attributes: + other_attributes[subj] = [] if pred == RDF.type: continue if pred in relation_mapper: if "alternateOf" in pred: - getattr(bundle, relation_mapper[pred])(obj, id) + getattr(bundle, relation_mapper[pred])(obj, subj) elif "mentionOf" in pred: mentionBundle = None for stmt in graph.triples( - (URIRef(id), URIRef(pm.PROV["asInBundle"].uri), None) + (URIRef(subj), URIRef(pm.PROV["asInBundle"].uri), None) ): mentionBundle = stmt[2] - getattr(bundle, relation_mapper[pred])(id, str(obj), mentionBundle) + getattr(bundle, relation_mapper[pred])( + subj, str(obj), mentionBundle + ) elif "actedOnBehalfOf" in pred or "wasAssociatedWith" in pred: qualifier = ( "qualified" @@ -601,77 +625,92 @@ def decode_container( ) qualifier_bnode = None for stmt in graph.triples( - (URIRef(id), URIRef(pm.PROV[qualifier].uri), None) + (URIRef(subj), URIRef(pm.PROV[qualifier].uri), None) ): qualifier_bnode = stmt[2] if qualifier_bnode is None: - getattr(bundle, relation_mapper[pred])(id, str(obj)) + getattr(bundle, relation_mapper[pred])(subj, str(obj)) else: fakeys = list(formal_attributes[str(qualifier_bnode)].keys()) - formal_attributes[str(qualifier_bnode)][fakeys[0]] = id + formal_attributes[str(qualifier_bnode)][fakeys[0]] = subj formal_attributes[str(qualifier_bnode)][fakeys[1]] = str(obj) else: - getattr(bundle, relation_mapper[pred])(id, str(obj)) - elif id in ids: + getattr(bundle, relation_mapper[pred])(subj, str(obj)) + elif subj in record_types: obj1 = self.decode_rdf_representation(obj, graph) if obj is not None and obj1 is None: raise ValueError(("Error transforming", obj)) pred_new = pred if pred in predicate_mapper: pred_new = predicate_mapper[pred] - if ids[id] == PROV_COMMUNICATION and "activity" in str(pred_new): + if record_types[subj] == PROV_COMMUNICATION and "activity" in str( + pred_new + ): pred_new = PROV_ATTR_INFORMANT - if ids[id] == PROV_DELEGATION and "agent" in str(pred_new): + if record_types[subj] == PROV_DELEGATION and "agent" in str(pred_new): pred_new = PROV_ATTR_RESPONSIBLE - if ids[id] in [PROV_END, PROV_START] and "entity" in str(pred_new): + if record_types[subj] in [PROV_END, PROV_START] and "entity" in str( + pred_new + ): pred_new = PROV_ATTR_TRIGGER - if ids[id] in [PROV_END] and "activity" in str(pred_new): + if record_types[subj] in [PROV_END] and "activity" in str(pred_new): pred_new = PROV_ATTR_ENDER - if ids[id] in [PROV_START] and "activity" in str(pred_new): + if record_types[subj] in [PROV_START] and "activity" in str(pred_new): pred_new = PROV_ATTR_STARTER - if ids[id] == PROV_DERIVATION and "entity" in str(pred_new): + if record_types[subj] == PROV_DERIVATION and "entity" in str(pred_new): pred_new = PROV_ATTR_USED_ENTITY - if str(pred_new) in [val.uri for val in formal_attributes[id]]: - qname_key = self.valid_identifier(pred_new) - formal_attributes[id][qname_key] = obj1 - unique_sets[id][qname_key].append(obj1) - if len(unique_sets[id][qname_key]) > 1: - formal_attributes[id][qname_key] = None + if str(pred_new) in [val.uri for val in formal_attributes[subj]]: + qname_key = self.document.mandatory_valid_qname(pred_new) # type: ignore[union-attr] + formal_attributes[subj][qname_key] = obj1 + unique_sets[subj][qname_key].append(obj1) + if len(unique_sets[subj][qname_key]) > 1: + formal_attributes[subj][qname_key] = None else: if "qualified" not in str(pred_new) and "asInBundle" not in str( pred_new ): - other_attributes[id].append((str(pred_new), obj1)) + other_attributes[subj].append((str(pred_new), obj1)) local_key = str(obj) - if local_key in ids: + if local_key in record_types: if "qualified" in pred: formal_attributes[local_key][ list(formal_attributes[local_key].keys())[0] - ] = id - for id in ids: + ] = subj + for subj in record_types: attrs = None - if id in other_attributes: - attrs = other_attributes[id] - items_to_walk = [] - for qname, values in unique_sets[id].items(): + if subj in other_attributes: + attrs = other_attributes[subj] + items_to_walk = ( + [] + ) # type: list[tuple[pm.QualifiedName, list[pm.QualifiedNameCandidate | datetime.datetime]]] + for qname, values in unique_sets[subj].items(): if values and len(values) > 1: items_to_walk.append((qname, values)) if items_to_walk: for subset in list(walk(items_to_walk)): - for key, value in subset.items(): - formal_attributes[id][key] = value - bundle.new_record(ids[id], id, formal_attributes[id], attrs) + for prov_type, value in subset.items(): + formal_attributes[subj][prov_type] = value + bundle.new_record( + record_types[subj], subj, formal_attributes[subj].items(), attrs + ) else: - bundle.new_record(ids[id], id, formal_attributes[id], attrs) - ids[id] = None + bundle.new_record( + record_types[subj], subj, formal_attributes[subj].items(), attrs + ) + if attrs is not None: - other_attributes[id] = [] - for key, val in other_attributes.items(): - if val: - ids[key].add_attributes(val) + del other_attributes[subj] + + if other_attributes: + warnings.warn( + "The following attributes were not converted: " + str(other_attributes), + UserWarning, + ) -def walk(children, level=0, path=None, usename=True): +def walk( + children: list, level: int = 0, path: dict = None, usename: bool = True # type: ignore[assignment] +) -> Generator[dict]: """Generate all the full paths in a tree, as a dict. :Example: @@ -686,6 +725,7 @@ def walk(children, level=0, path=None, usename=True): # Entry point if level == 0: path = {} + # Exit condition if not children: yield path.copy() @@ -704,13 +744,16 @@ def walk(children, level=0, path=None, usename=True): yield child_paths -def literal_rdf_representation(literal): - value = str(literal.value) if literal.value else literal +def literal_rdf_representation(literal: pm.Literal) -> RDFLiteral: if literal.langtag: # a language tag can only go with prov:InternationalizedString - return RDFLiteral(value, lang=str(literal.langtag)) + return RDFLiteral(literal.value, lang=literal.langtag) else: datatype = literal.datatype - if "base64Binary" in datatype.uri: - value = literal.value.encode() - return RDFLiteral(value, datatype=datatype.uri) + if datatype is not None: + if "base64Binary" in datatype.uri: + return RDFLiteral(literal.value.encode(), datatype=datatype.uri) + else: + return RDFLiteral(literal.value, datatype=datatype.uri) + else: + raise ValueError("Literal has no datatype") From 104ffed040a21d6e0212f26f8e82ae0c1e09dbab Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Tue, 10 Jun 2025 13:08:44 +0100 Subject: [PATCH 16/20] Add type annotations and improve typing consistency in `provxml.py` #143 Add PROV_INTERNATIONALIZEDSTRING to `constants.py` for convenience --- src/prov/constants.py | 1 + src/prov/model.py | 6 +- src/prov/serializers/provxml.py | 115 ++++++++++++++++++++------------ 3 files changed, 78 insertions(+), 44 deletions(-) diff --git a/src/prov/constants.py b/src/prov/constants.py index e4417da..3f0b011 100644 --- a/src/prov/constants.py +++ b/src/prov/constants.py @@ -186,6 +186,7 @@ PROV_ROLE = PROV["role"] PROV_QUALIFIEDNAME = PROV["QUALIFIED_NAME"] +PROV_INTERNATIONALIZEDSTRING = PROV["InternationalizedString"] # XSD DATA TYPES XSD_ANYURI = XSD["anyURI"] diff --git a/src/prov/model.py b/src/prov/model.py index ce37e7b..dba8e7d 100644 --- a/src/prov/model.py +++ b/src/prov/model.py @@ -159,16 +159,16 @@ def __init__( "Assuming prov:InternationalizedString as the type of " '"%s"@%s' % (value, langtag) ) - datatype = PROV["InternationalizedString"] + datatype = PROV_INTERNATIONALIZEDSTRING # PROV JSON states that the type field must not be set when # using the lang attribute and PROV XML requires it to be an # internationalized string. - elif datatype != PROV["InternationalizedString"]: + elif datatype != PROV_INTERNATIONALIZEDSTRING: logger.warning( 'Invalid data type (%s) for "%s"@%s, overridden as ' "prov:InternationalizedString." % (datatype, value, langtag) ) - datatype = PROV["InternationalizedString"] + datatype = PROV_INTERNATIONALIZEDSTRING self._datatype: Optional[QualifiedName] = datatype # langtag is always a string self._langtag: Optional[str] = str(langtag) if langtag is not None else None diff --git a/src/prov/serializers/provxml.py b/src/prov/serializers/provxml.py index 324a0b7..e13524e 100644 --- a/src/prov/serializers/provxml.py +++ b/src/prov/serializers/provxml.py @@ -1,8 +1,11 @@ +from __future__ import annotations # needed for | type annotations in Python < 3.10 import datetime import logging from lxml import etree import io +from typing import Any, Optional import warnings + import prov import prov.identifier from prov.model import DEFAULT_NAMESPACES, sorted_attributes @@ -34,7 +37,9 @@ class ProvXMLException(prov.Error): class ProvXMLSerializer(Serializer): """PROV-XML serializer for :class:`~prov.model.ProvDocument`""" - def serialize(self, stream, force_types=False, **kwargs): + def serialize( + self, stream: io.IOBase, force_types: bool = False, **kwargs: Any + ) -> None: """ Serializes a :class:`~prov.model.ProvDocument` instance to `PROV-XML `_. @@ -49,6 +54,9 @@ def serialize(self, stream, force_types=False, **kwargs): types will always be set if the Python type requires it. False is a good default and it should rarely require changing. """ + if self.document is None: + raise ProvXMLException("No document to serialize.") + xml_root = self.serialize_bundle(bundle=self.document, force_types=force_types) for bundle in self.document.bundles: self.serialize_bundle( @@ -65,9 +73,14 @@ def serialize(self, stream, force_types=False, **kwargs): ) ) else: - et.write(stream, pretty_print=True, xml_declaration=True, encoding="UTF-8") - - def serialize_bundle(self, bundle, element=None, force_types=False): + et.write(stream, pretty_print=True, xml_declaration=True, encoding="UTF-8") # type: ignore[arg-type] + + def serialize_bundle( + self, + bundle: prov.model.ProvBundle, + element: Optional[etree._Element] = None, + force_types: bool = False, + ) -> etree._Element: """ Serializes a bundle or document to PROV XML. @@ -86,10 +99,11 @@ def serialize_bundle(self, bundle, element=None, force_types=False): # element. nsmap = { ns.prefix: ns.uri - for ns in self.document._namespaces.get_registered_namespaces() - } - if self.document._namespaces._default: - nsmap[None] = self.document._namespaces._default.uri + for ns in self.document._namespaces.get_registered_namespaces() # type: ignore[union-attr] + } # type: dict[str, str] + if self.document._namespaces._default: # type: ignore[union-attr] + # TODO: Check if the below works as expected. + nsmap[None] = self.document._namespaces._default.uri # type: ignore[union-attr, index] for namespace in bundle.namespaces: if namespace not in nsmap: nsmap[namespace.prefix] = namespace.uri @@ -123,7 +137,7 @@ def serialize_bundle(self, bundle, element=None, force_types=False): # Derive the record label from its attributes which is sometimes # needed. - attributes = list(record.attributes) + attributes = record.attributes rec_label = self._derive_record_label(rec_type, attributes) elem = etree.SubElement(xml_bundle_root, _ns_prov(rec_label), attrs) @@ -133,7 +147,10 @@ def serialize_bundle(self, bundle, element=None, force_types=False): elem, _ns(attr.namespace.uri, attr.localpart) ) if isinstance(value, prov.model.Literal): - if value.datatype not in [None, PROV["InternationalizedString"]]: + if ( + value.datatype is not None + and value.datatype != PROV_INTERNATIONALIZEDSTRING + ): subelem.attrib[_ns_xsi("type")] = "%s:%s" % ( value.datatype.namespace.prefix, value.datatype.localpart, @@ -163,14 +180,13 @@ def serialize_bundle(self, bundle, element=None, force_types=False): # # To enable a mapping of Python types to XML and back, # the XSD type must be written for these types. - ALWAYS_CHECK = [ + ALWAYS_CHECK = { bool, datetime.datetime, float, int, prov.identifier.Identifier, - ] - ALWAYS_CHECK = tuple(ALWAYS_CHECK) + } if ( ( force_types @@ -215,7 +231,7 @@ def serialize_bundle(self, bundle, element=None, force_types=False): subelem.text = v return xml_bundle_root - def deserialize(self, stream, **kwargs): + def deserialize(self, stream: io.IOBase, **kwargs: Any) -> prov.model.ProvDocument: """ Deserialize from `PROV-XML `_ representation to a :class:`~prov.model.ProvDocument` instance. @@ -226,20 +242,24 @@ def deserialize(self, stream, **kwargs): with io.BytesIO() as buf: buf.write(stream.read().encode("utf-8")) buf.seek(0, 0) - xml_doc = etree.parse(buf).getroot() + xml_doc = etree.parse(buf).getroot() # type: etree._Element else: - xml_doc = etree.parse(stream).getroot() + xml_doc = etree.parse(stream).getroot() # type: ignore[arg-type] # Remove all comments. - for c in xml_doc.xpath("//comment()"): - p = c.getparent() - p.remove(c) + for c in xml_doc.xpath("//comment()"): # type: ignore[union-attr] + p = c.getparent() # type: ignore[union-attr] + p.remove(c) # type: ignore[union-attr, arg-type] document = prov.model.ProvDocument() self.deserialize_subtree(xml_doc, document) return document - def deserialize_subtree(self, xml_doc, bundle): + def deserialize_subtree( + self, + xml_doc: etree._Element, + bundle: prov.model.ProvDocument | prov.model.ProvBundle, + ) -> prov.model.ProvDocument | prov.model.ProvBundle: """ Deserialize an etree element containing a PROV document or a bundle and write it to the provided internal object. @@ -265,14 +285,18 @@ def deserialize_subtree(self, xml_doc, bundle): id_tag = _ns_prov("id") rec_id = element.attrib[id_tag] if id_tag in element.attrib else None - - if rec_id is not None: - # Try to make a qualified name out of it! - rec_id = xml_qname_to_QualifiedName(element, rec_id) + # Try to make a qualified name out of it! + prov_rec_id = ( + xml_qname_to_QualifiedName(element, rec_id) # type: ignore[arg-type] + if rec_id is not None + else None + ) # Recursively read bundles. if qname.localname == "bundleContent": - b = bundle.bundle(identifier=rec_id) + assert isinstance(bundle, prov.model.ProvDocument) + assert prov_rec_id is not None + b = bundle.bundle(identifier=prov_rec_id) self.deserialize_subtree(element, b) continue @@ -284,18 +308,22 @@ def deserialize_subtree(self, xml_doc, bundle): if _ns_xsi("type") in element.attrib: value = xml_qname_to_QualifiedName( - element, element.attrib[_ns_xsi("type")] + element, element.attrib[_ns_xsi("type")] # type: ignore[arg-type] ) attributes.append((PROV["type"], value)) - rec = bundle.new_record(rec_type, rec_id, attributes) + rec = bundle.new_record(rec_type, prov_rec_id, attributes) # Add the actual type in case a base type has been used. if rec_type != q_prov_name: rec.add_asserted_type(q_prov_name) return bundle - def _derive_record_label(self, rec_type, attributes): + def _derive_record_label( + self, + rec_type: prov.model.QualifiedName, + attributes: list[tuple[prov.model.QualifiedName, Any]], + ) -> str: """ Helper function trying to derive the record label taking care of subtypes and what not. It will also remove the type declaration for @@ -318,13 +346,15 @@ def _derive_record_label(self, rec_type, attributes): return rec_label -def _extract_attributes(element): +def _extract_attributes( + element: etree._Element, +) -> list[tuple[prov.model.QualifiedName, Any]]: """ Extract the PROV attributes from an etree element. :param element: The lxml.etree.Element instance. """ - attributes = [] + attributes = [] # type: list[tuple[prov.model.QualifiedName, Any]] for subel in element: sqname = etree.QName(subel) _t = xml_qname_to_QualifiedName( @@ -332,16 +362,17 @@ def _extract_attributes(element): ) for key, value in subel.attrib.items(): - if key == _ns_xsi("type"): - datatype = xml_qname_to_QualifiedName(subel, value) + value_str = value.decode("utf-8") if isinstance(value, bytes) else value + if key == _ns_prov("ref"): + _v = xml_qname_to_QualifiedName(subel, value_str) # type: Any + elif key == _ns_xsi("type"): + datatype = xml_qname_to_QualifiedName(subel, value_str) if datatype == XSD_QNAME: - _v = xml_qname_to_QualifiedName(subel, subel.text) + _v = xml_qname_to_QualifiedName(subel, subel.text) # type: ignore[arg-type] else: _v = prov.model.Literal(subel.text, datatype) - elif key == _ns_prov("ref"): - _v = xml_qname_to_QualifiedName(subel, value) elif key == _ns_xml("lang"): - _v = prov.model.Literal(subel.text, langtag=value) + _v = prov.model.Literal(subel.text, langtag=value_str) else: warnings.warn( "The element '%s' contains an attribute %s='%s' " @@ -359,7 +390,9 @@ def _extract_attributes(element): return attributes -def xml_qname_to_QualifiedName(element, qname_str): +def xml_qname_to_QualifiedName( + element: etree._Element, qname_str: str +) -> prov.model.QualifiedName: if ":" in qname_str: prefix, localpart = qname_str.split(":", 1) if prefix in element.nsmap: @@ -383,18 +416,18 @@ def xml_qname_to_QualifiedName(element, qname_str): ) -def _ns(ns, tag): +def _ns(ns: str, tag: str) -> str: return "{%s}%s" % (ns, tag) -def _ns_prov(tag): +def _ns_prov(tag: str) -> str: return _ns(DEFAULT_NAMESPACES["prov"].uri, tag) -def _ns_xsi(tag): +def _ns_xsi(tag: str) -> str: return _ns(DEFAULT_NAMESPACES["xsi"].uri, tag) -def _ns_xml(tag): +def _ns_xml(tag: str) -> str: NS_XML = "http://www.w3.org/XML/1998/namespace" return _ns(NS_XML, tag) From eaf85306a3ca48775d38d172e0b60fb1de530b75 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Tue, 10 Jun 2025 13:16:45 +0100 Subject: [PATCH 17/20] Update mypy workflow to check all in the `src` directory --- .github/workflows/mypy.yml | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 4118b44..014f4ed 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -14,14 +14,8 @@ jobs: with: python-version: '3.x' - run: pip install mypy - - name: Get Python changed files - id: changed-py-files - uses: tj-actions/changed-files@v46 - with: - files: | - *.py - **/*.py - - name: Run if any of the listed files above is changed - if: steps.changed-py-files.outputs.any_changed == 'true' - run: mypy ${{ steps.changed-py-files.outputs.all_changed_files }} --ignore-missing-imports + - name: Install type stubs for third-party packages + run: mypy --install-types src + - name: mypy main + run: mypy --ignore-missing-imports src From 48dfa976459125cee79e6803ebb9519d9be86e93 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Tue, 10 Jun 2025 13:20:45 +0100 Subject: [PATCH 18/20] Update mypy configuration and workflow --- .github/workflows/mypy.yml | 6 ++---- pyproject.toml | 1 + 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 014f4ed..85170ee 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -14,8 +14,6 @@ jobs: with: python-version: '3.x' - run: pip install mypy - - name: Install type stubs for third-party packages - run: mypy --install-types src - - name: mypy main - run: mypy --ignore-missing-imports src + - name: Type checking and install type stubs for third-party packages + run: mypy --install-types --non-interactive src diff --git a/pyproject.toml b/pyproject.toml index 8c7c672..edad2fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,3 +104,4 @@ exclude = ["prov/tests/*"] module = "prov.*" disallow_untyped_defs = true check_untyped_defs = true +ignore_missing_imports = true From 70434c9ba0e744485b34158a3e5192d3dc374fe1 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Tue, 10 Jun 2025 13:48:49 +0100 Subject: [PATCH 19/20] Add type ignore comments for mypy import-not-found errors in `provrdf.py` --- pyproject.toml | 1 - src/prov/serializers/provrdf.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index edad2fe..8c7c672 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -104,4 +104,3 @@ exclude = ["prov/tests/*"] module = "prov.*" disallow_untyped_defs = true check_untyped_defs = true -ignore_missing_imports = true diff --git a/src/prov/serializers/provrdf.py b/src/prov/serializers/provrdf.py index 2938b0a..425fd2d 100644 --- a/src/prov/serializers/provrdf.py +++ b/src/prov/serializers/provrdf.py @@ -10,10 +10,10 @@ import dateutil.parser -from rdflib.term import URIRef, BNode +from rdflib.term import URIRef, BNode # type: ignore[import-not-found] from rdflib.term import Literal as RDFLiteral -from rdflib.graph import ConjunctiveGraph -from rdflib.namespace import RDF, RDFS, XSD +from rdflib.graph import ConjunctiveGraph # type: ignore[import-not-found] +from rdflib.namespace import RDF, RDFS, XSD # type: ignore[import-not-found] from prov import Error import prov.model as pm From d7ebb851842b8ef2db3d48697f4624a87a8457a3 Mon Sep 17 00:00:00 2001 From: Trung Dong Huynh Date: Tue, 10 Jun 2025 13:50:45 +0100 Subject: [PATCH 20/20] Restrict mypy workflow to the master branch only --- .github/workflows/mypy.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml index 85170ee..25515e2 100644 --- a/.github/workflows/mypy.yml +++ b/.github/workflows/mypy.yml @@ -1,7 +1,7 @@ name: mypy check on: push: - branches: [main, master, 143-type-hints] + branches: [main, master] pull_request: branches: [main, master, dev]