Skip to content

Commit 6e8e7db

Browse files
authored
Update dependency Updater (#9302)
1 parent 6ed3ca1 commit 6e8e7db

File tree

1 file changed

+148
-34
lines changed

1 file changed

+148
-34
lines changed

scripts/update_deps.py

Lines changed: 148 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import re
2020
import sys
2121
from typing import Optional
22+
from html.parser import HTMLParser
23+
import urllib.request
2224

2325
logger = logging.getLogger(__name__)
2426

@@ -35,8 +37,67 @@
3537

3638
# Page listing libtpu nightly builds.
3739
_LIBTPU_BUILDS_URL = 'https://storage.googleapis.com/libtpu-wheels/index.html'
38-
# Page listing jax nightly builds.
39-
_JAX_BUILDS_URL = 'https://storage.googleapis.com/jax-releases/jax_nightly_releases.html'
40+
# New JAX package index URLs (PEP 503 compliant)
41+
_JAX_INDEX_URL = 'https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/'
42+
_JAX_PROJECT_URL = _JAX_INDEX_URL + 'jax/'
43+
_JAXLIB_PROJECT_URL = _JAX_INDEX_URL + 'jaxlib/'
44+
45+
46+
class PEP503Parser(HTMLParser):
47+
"""Parser for PEP 503 simple repository API pages.
48+
49+
This parser extracts all links and their link text from the
50+
HTML content of a PEP 503 index page.
51+
"""
52+
53+
links: list[tuple[str, str]]
54+
"""List of (href, text) tuples for all links found."""
55+
56+
_current_link: str | None
57+
"""The current link being processed."""
58+
59+
_current_text: str
60+
"""The text content of the current link being processed."""
61+
62+
def __init__(self):
63+
super().__init__()
64+
self.links = []
65+
self._current_link = None
66+
self._current_text = ""
67+
68+
def handle_starttag(self, tag: str, attrs: list[tuple[str,
69+
str | None]]) -> None:
70+
"""Handles the start of an HTML tag.
71+
72+
Starts processing a link if the tag is an anchor (<a>).
73+
"""
74+
if tag == 'a':
75+
href = None
76+
for attr, value in attrs:
77+
if attr == 'href':
78+
href = value
79+
break
80+
if href:
81+
self._current_link = href
82+
self._current_text = ""
83+
84+
def handle_data(self, data: str) -> None:
85+
"""Handles the text data within an HTML tag.
86+
87+
If currently processing a link, appends the data to the current text.
88+
"""
89+
if self._current_link:
90+
self._current_text += data
91+
92+
def handle_endtag(self, tag: str) -> None:
93+
"""Handles the end of an HTML tag.
94+
95+
If the tag is an anchor (<a>), adds the link and its text to the list.
96+
"""
97+
if tag == 'a' and self._current_link:
98+
self.links.append((self._current_link, self._current_text.strip()))
99+
self._current_link = None
100+
self._current_text = ""
40101

41102

42103
def clean_tmp_dir() -> None:
@@ -48,7 +109,7 @@ def clean_tmp_dir() -> None:
48109

49110
def get_last_xla_commit_and_date() -> tuple[str, str]:
50111
"""Finds the latest commit in the master branch of https://github.com/openxla/xla.
51-
112+
52113
Returns:
53114
A tuple of the latest commit SHA and its date (YYYY-MM-DD).
54115
"""
@@ -68,7 +129,7 @@ def get_last_xla_commit_and_date() -> tuple[str, str]:
68129

69130
def update_openxla() -> bool:
70131
"""Updates the OpenXLA version in the WORKSPACE file to the latest commit.
71-
132+
72133
Returns:
73134
True if the WORKSPACE file was updated, False otherwise.
74135
"""
@@ -100,7 +161,7 @@ def update_openxla() -> bool:
100161
def find_latest_nightly(html_lines: list[str],
101162
build_re: str) -> Optional[tuple[str, str, str]]:
102163
"""Finds the latest nightly build from the list of HTML lines.
103-
164+
104165
Args:
105166
html_lines: A list of HTML lines to search for the nightly build.
106167
build_re: A regular expression for matching the nightly build line.
@@ -130,7 +191,7 @@ def find_latest_nightly(html_lines: list[str],
130191

131192
def find_latest_libtpu_nightly() -> Optional[tuple[str, str, str]]:
132193
"""Finds the latest libtpu nightly build for the current platform.
133-
194+
134195
Returns:
135196
A tuple of the version, date, and suffix of the latest libtpu nightly build,
136197
or None if no build is found.
@@ -151,53 +212,106 @@ def find_latest_libtpu_nightly() -> Optional[tuple[str, str, str]]:
151212
_PLATFORM + r'\.whl</a>')
152213

153214

215+
def fetch_pep503_page(url: str) -> list[tuple[str, str]]:
216+
"""Fetches and parses a PEP 503 index page.
217+
218+
Args:
219+
url: The URL of the PEP 503 index page.
220+
221+
Returns:
222+
A list of (href, text) tuples for all links on the page.
223+
"""
224+
try:
225+
with urllib.request.urlopen(url) as response:
226+
html = response.read().decode('utf-8')
227+
228+
parser = PEP503Parser()
229+
parser.feed(html)
230+
return parser.links
231+
except Exception as e:
232+
logger.error(f'Failed to fetch {url}: {e}')
233+
return []
234+
235+
154236
def find_latest_jax_nightly() -> Optional[tuple[str, str, str]]:
155-
"""Finds the latest JAX nightly build.
156-
237+
"""Finds the latest JAX nightly build using the new package index.
238+
157239
Returns:
158240
A tuple of the jax version, jaxlib version, and date of the latest JAX nightly build,
159241
or None if no build is found.
160242
"""
161243

162-
# Read the nightly jax build page.
163-
clean_tmp_dir()
164-
os.system('curl -s {} > {}/jax_builds.html'.format(_JAX_BUILDS_URL, _TMP_DIR))
165-
with open(f'{_TMP_DIR}/jax_builds.html', 'r') as f:
166-
html_lines = f.readlines()
244+
def parse_version_date(url: str, pattern: str) -> list[tuple[str, str]]:
245+
links = fetch_pep503_page(url)
246+
if not links:
247+
logger.error(f'Could not fetch packages from {url}')
248+
return []
249+
compiled = re.compile(pattern)
250+
results = []
251+
for href, text in links:
252+
filename = text if text else href.split('/')[-1].split('#')[0]
253+
m = compiled.match(filename)
254+
if m:
255+
version, date = m.groups()
256+
results.append((version, date))
257+
return results
258+
259+
# Find JAX libraries.
260+
#
261+
# Look for patterns like: jax-0.6.1.dev20250428-py3-none-any.whl
262+
# Group 1: Represents the JAX version (formatted as a series of digits and dots).
263+
# Group 2: Represents the build date (an 8-digit string typically in YYYYMMDD format).
264+
jax_versions_dates = parse_version_date(
265+
_JAX_PROJECT_URL, r'jax-([\d.]+)\.dev(\d{8})-py3-none-any\.whl')
266+
if not jax_versions_dates:
267+
logger.error(f"Could not fetch JAX packages from {_JAX_PROJECT_URL}")
268+
return None
167269

168-
# Find lines like
169-
# <a href=...>jax/jax-0.6.1.dev20250428-py3-none-any.whl</a>
170-
jax_build = find_latest_nightly(
171-
html_lines, r'.*<a href=.*?>jax/jax-(.*?)\.dev(\d{8})-(.*)\.whl</a>')
172-
if not jax_build:
173-
logger.error(
174-
f'Could not find latest jax nightly build in {_JAX_BUILDS_URL}.')
270+
# Fetch jaxlib libraries
271+
#
272+
# Look for patterns like: jaxlib-0.6.1.dev20250428-cp310-cp310-manylinux2014_x86_64.whl
273+
# Group 1: Represents the jaxlib version (formatted as a series of digits and dots).
274+
# Group 2: Represents the build date (an 8-digit string typically in YYYYMMDD format).
275+
jaxlib_versions_dates = parse_version_date(
276+
_JAXLIB_PROJECT_URL, r'jaxlib-([\d.]+)\.dev(\d{8})-.*\.whl')
277+
if not jaxlib_versions_dates:
278+
logger.error(f"Could not fetch jaxlib packages from {_JAXLIB_PROJECT_URL}")
175279
return None
176280

177-
# Find lines like
178-
# <a href=...>nocuda/jaxlib-0.6.1.dev20250428-....whl</a>
179-
jaxlib_build = find_latest_nightly(
180-
html_lines,
181-
r'.*<a href=.*?>nocuda/jaxlib-(.*?)\.dev(\d{8})-(.*)\.whl</a>')
182-
if not jaxlib_build:
281+
latest_jax_version = ''
282+
latest_jax_date = ''
283+
for version, date in jax_versions_dates:
284+
if date > latest_jax_date:
285+
latest_jax_version = version
286+
latest_jax_date = date
287+
288+
if not latest_jax_version:
183289
logger.error(
184-
f'Could not find latest jaxlib nightly build in {_JAX_BUILDS_URL}.')
290+
f'Could not find any JAX nightly builds. Tried parsing {_JAX_PROJECT_URL}'
291+
)
185292
return None
186293

187-
jax_version, jax_date, _ = jax_build
188-
jaxlib_version, jaxlib_date, _ = jaxlib_build
189-
if jax_date != jaxlib_date:
294+
latest_jaxlib_version = ''
295+
for version, date in jaxlib_versions_dates:
296+
# Only consider jaxlib builds from the same date as JAX
297+
if date == latest_jax_date and version > latest_jaxlib_version:
298+
latest_jaxlib_version = version
299+
300+
if not latest_jaxlib_version:
190301
logger.error(
191-
f'The latest jax date {jax_date} != the latest jaxlib date {jaxlib_date} in {_JAX_BUILDS_URL}.'
302+
f'Could not find jaxlib nightly build for date {latest_jax_date}. Tried parsing {_JAXLIB_PROJECT_URL}'
192303
)
193304
return None
194305

195-
return jax_version, jaxlib_version, jax_date
306+
logger.info(
307+
f'Found JAX {latest_jax_version} and jaxlib {latest_jaxlib_version} from {latest_jax_date}'
308+
)
309+
return latest_jax_version, latest_jaxlib_version, latest_jax_date
196310

197311

198312
def update_libtpu() -> bool:
199313
"""Updates the libtpu version in setup.py to the latest nightly build.
200-
314+
201315
Returns:
202316
True if the setup.py file was updated, False otherwise.
203317
"""
@@ -248,7 +362,7 @@ def update_libtpu() -> bool:
248362

249363
def update_jax() -> bool:
250364
"""Updates the jax/jaxlib versions in setup.py to the latest nightly build.
251-
365+
252366
Returns:
253367
True if the setup.py file was updated, False otherwise.
254368
"""

0 commit comments

Comments
 (0)