|
20 | 20 | import warnings
|
21 | 21 | from collections import OrderedDict, abc
|
22 | 22 | from difflib import get_close_matches
|
| 23 | +from importlib.metadata import requires, version |
23 | 24 | from typing import (
|
24 | 25 | TYPE_CHECKING,
|
25 | 26 | Any,
|
@@ -1092,3 +1093,91 @@ def has_c() -> bool:
|
1092 | 1093 | return True
|
1093 | 1094 | except ImportError:
|
1094 | 1095 | return False
|
| 1096 | + |
| 1097 | + |
| 1098 | +class Version(tuple[int, ...]): |
| 1099 | + """A class that can be used to compare version strings.""" |
| 1100 | + |
| 1101 | + def __new__(cls, *version: int) -> Version: |
| 1102 | + padded_version = cls._padded(version, 4) |
| 1103 | + return super().__new__(cls, tuple(padded_version)) |
| 1104 | + |
| 1105 | + @classmethod |
| 1106 | + def _padded(cls, iter: Any, length: int, padding: int = 0) -> list[int]: |
| 1107 | + as_list = list(iter) |
| 1108 | + if len(as_list) < length: |
| 1109 | + for _ in range(length - len(as_list)): |
| 1110 | + as_list.append(padding) |
| 1111 | + return as_list |
| 1112 | + |
| 1113 | + @classmethod |
| 1114 | + def from_string(cls, version_string: str) -> Version: |
| 1115 | + mod = 0 |
| 1116 | + bump_patch_level = False |
| 1117 | + if version_string.endswith("+"): |
| 1118 | + version_string = version_string[0:-1] |
| 1119 | + mod = 1 |
| 1120 | + elif version_string.endswith("-pre-"): |
| 1121 | + version_string = version_string[0:-5] |
| 1122 | + mod = -1 |
| 1123 | + elif version_string.endswith("-"): |
| 1124 | + version_string = version_string[0:-1] |
| 1125 | + mod = -1 |
| 1126 | + # Deal with .devX substrings |
| 1127 | + if ".dev" in version_string: |
| 1128 | + version_string = version_string[0 : version_string.find(".dev")] |
| 1129 | + mod = -1 |
| 1130 | + # Deal with '-rcX' substrings |
| 1131 | + if "-rc" in version_string: |
| 1132 | + version_string = version_string[0 : version_string.find("-rc")] |
| 1133 | + mod = -1 |
| 1134 | + # Deal with git describe generated substrings |
| 1135 | + elif "-" in version_string: |
| 1136 | + version_string = version_string[0 : version_string.find("-")] |
| 1137 | + mod = -1 |
| 1138 | + bump_patch_level = True |
| 1139 | + |
| 1140 | + version = [int(part) for part in version_string.split(".")] |
| 1141 | + version = cls._padded(version, 3) |
| 1142 | + # Make from_string and from_version_array agree. For example: |
| 1143 | + # MongoDB Enterprise > db.runCommand('buildInfo').versionArray |
| 1144 | + # [ 3, 2, 1, -100 ] |
| 1145 | + # MongoDB Enterprise > db.runCommand('buildInfo').version |
| 1146 | + # 3.2.0-97-g1ef94fe |
| 1147 | + if bump_patch_level: |
| 1148 | + version[-1] += 1 |
| 1149 | + version.append(mod) |
| 1150 | + |
| 1151 | + return Version(*version) |
| 1152 | + |
| 1153 | + @classmethod |
| 1154 | + def from_version_array(cls, version_array: Any) -> Version: |
| 1155 | + version = list(version_array) |
| 1156 | + if version[-1] < 0: |
| 1157 | + version[-1] = -1 |
| 1158 | + version = cls._padded(version, 3) |
| 1159 | + return Version(*version) |
| 1160 | + |
| 1161 | + def at_least(self, *other_version: Any) -> bool: |
| 1162 | + return self >= Version(*other_version) |
| 1163 | + |
| 1164 | + def __str__(self) -> str: |
| 1165 | + return ".".join(map(str, self)) |
| 1166 | + |
| 1167 | + |
| 1168 | +def check_for_min_version(package_name: str) -> tuple[str, str, bool]: |
| 1169 | + """Test whether an installed package is of the desired version.""" |
| 1170 | + package_version_str = version(package_name) |
| 1171 | + package_version = Version.from_string(package_version_str) |
| 1172 | + # Dependency is expected to be in one of the forms: |
| 1173 | + # "pymongocrypt<2.0.0,>=1.13.0; extra == 'encryption'" |
| 1174 | + # 'dnspython<3.0.0,>=1.16.0' |
| 1175 | + # |
| 1176 | + requirements = requires("pymongo") |
| 1177 | + assert requirements is not None |
| 1178 | + requirement = [i for i in requirements if i.startswith(package_name)][0] # noqa: RUF015 |
| 1179 | + if ";" in requirement: |
| 1180 | + requirement = requirement.split(";")[0] |
| 1181 | + required_version = requirement[requirement.find(">=") + 2 :] |
| 1182 | + is_valid = package_version >= Version.from_string(required_version) |
| 1183 | + return package_version_str, required_version, is_valid |
0 commit comments