1919import re
2020import sys
2121from typing import Optional
22+ from html .parser import HTMLParser
23+ import urllib .request
2224
2325logger = logging .getLogger (__name__ )
2426
3537
3638# Page listing libtpu nightly builds.
3739_LIBTPU_BUILDS_URL = 'https://storage.googleapis.com/libtpu-wheels/index.html'
38- # Page listing jax nightly builds.
39- _JAX_BUILDS_URL = 'https://storage.googleapis.com/jax-releases/jax_nightly_releases.html'
40+ # New JAX package index URLs (PEP 503 compliant)
41+ _JAX_INDEX_URL = 'https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/'
42+ _JAX_PROJECT_URL = _JAX_INDEX_URL + 'jax/'
43+ _JAXLIB_PROJECT_URL = _JAX_INDEX_URL + 'jaxlib/'
44+
45+
46+ class PEP503Parser (HTMLParser ):
47+ """Parser for PEP 503 simple repository API pages.
48+
49+ This parser extracts all links and their link text from the
50+ HTML content of a PEP 503 index page.
51+ """
52+
53+ links : list [tuple [str , str ]]
54+ """List of (href, text) tuples for all links found."""
55+
56+ _current_link : str | None
57+ """The current link being processed."""
58+
59+ _current_text : str
60+ """The text content of the current link being processed."""
61+
62+ def __init__ (self ):
63+ super ().__init__ ()
64+ self .links = []
65+ self ._current_link = None
66+ self ._current_text = ""
67+
68+ def handle_starttag (self , tag : str , attrs : list [tuple [str ,
69+ str | None ]]) -> None :
70+ """Handles the start of an HTML tag.
71+
72+ Starts processing a link if the tag is an anchor (<a>).
73+ """
74+ if tag == 'a' :
75+ href = None
76+ for attr , value in attrs :
77+ if attr == 'href' :
78+ href = value
79+ break
80+ if href :
81+ self ._current_link = href
82+ self ._current_text = ""
83+
84+ def handle_data (self , data : str ) -> None :
85+ """Handles the text data within an HTML tag.
86+
87+ If currently processing a link, appends the data to the current text.
88+ """
89+ if self ._current_link :
90+ self ._current_text += data
91+
92+ def handle_endtag (self , tag : str ) -> None :
93+ """Handles the end of an HTML tag.
94+
95+ If the tag is an anchor (<a>), adds the link and its text to the list.
96+ """
97+ if tag == 'a' and self ._current_link :
98+ self .links .append ((self ._current_link , self ._current_text .strip ()))
99+ self ._current_link = None
100+ self ._current_text = ""
40101
41102
42103def clean_tmp_dir () -> None :
@@ -48,7 +109,7 @@ def clean_tmp_dir() -> None:
48109
49110def get_last_xla_commit_and_date () -> tuple [str , str ]:
50111 """Finds the latest commit in the master branch of https://github.com/openxla/xla.
51-
112+
52113 Returns:
53114 A tuple of the latest commit SHA and its date (YYYY-MM-DD).
54115 """
@@ -68,7 +129,7 @@ def get_last_xla_commit_and_date() -> tuple[str, str]:
68129
69130def update_openxla () -> bool :
70131 """Updates the OpenXLA version in the WORKSPACE file to the latest commit.
71-
132+
72133 Returns:
73134 True if the WORKSPACE file was updated, False otherwise.
74135 """
@@ -100,7 +161,7 @@ def update_openxla() -> bool:
100161def find_latest_nightly (html_lines : list [str ],
101162 build_re : str ) -> Optional [tuple [str , str , str ]]:
102163 """Finds the latest nightly build from the list of HTML lines.
103-
164+
104165 Args:
105166 html_lines: A list of HTML lines to search for the nightly build.
106167 build_re: A regular expression for matching the nightly build line.
@@ -130,7 +191,7 @@ def find_latest_nightly(html_lines: list[str],
130191
131192def find_latest_libtpu_nightly () -> Optional [tuple [str , str , str ]]:
132193 """Finds the latest libtpu nightly build for the current platform.
133-
194+
134195 Returns:
135196 A tuple of the version, date, and suffix of the latest libtpu nightly build,
136197 or None if no build is found.
@@ -151,53 +212,106 @@ def find_latest_libtpu_nightly() -> Optional[tuple[str, str, str]]:
151212 _PLATFORM + r'\.whl</a>' )
152213
153214
215+ def fetch_pep503_page (url : str ) -> list [tuple [str , str ]]:
216+ """Fetches and parses a PEP 503 index page.
217+
218+ Args:
219+ url: The URL of the PEP 503 index page.
220+
221+ Returns:
222+ A list of (href, text) tuples for all links on the page.
223+ """
224+ try :
225+ with urllib .request .urlopen (url ) as response :
226+ html = response .read ().decode ('utf-8' )
227+
228+ parser = PEP503Parser ()
229+ parser .feed (html )
230+ return parser .links
231+ except Exception as e :
232+ logger .error (f'Failed to fetch { url } : { e } ' )
233+ return []
234+
235+
154236def find_latest_jax_nightly () -> Optional [tuple [str , str , str ]]:
155- """Finds the latest JAX nightly build.
156-
237+ """Finds the latest JAX nightly build using the new package index .
238+
157239 Returns:
158240 A tuple of the jax version, jaxlib version, and date of the latest JAX nightly build,
159241 or None if no build is found.
160242 """
161243
162- # Read the nightly jax build page.
163- clean_tmp_dir ()
164- os .system ('curl -s {} > {}/jax_builds.html' .format (_JAX_BUILDS_URL , _TMP_DIR ))
165- with open (f'{ _TMP_DIR } /jax_builds.html' , 'r' ) as f :
166- html_lines = f .readlines ()
244+ def parse_version_date (url : str , pattern : str ) -> list [tuple [str , str ]]:
245+ links = fetch_pep503_page (url )
246+ if not links :
247+ logger .error (f'Could not fetch packages from { url } ' )
248+ return []
249+ compiled = re .compile (pattern )
250+ results = []
251+ for href , text in links :
252+ filename = text if text else href .split ('/' )[- 1 ].split ('#' )[0 ]
253+ m = compiled .match (filename )
254+ if m :
255+ version , date = m .groups ()
256+ results .append ((version , date ))
257+ return results
258+
259+ # Find JAX libraries.
260+ #
261+ # Look for patterns like: jax-0.6.1.dev20250428-py3-none-any.whl
262+ # Group 1: Represents the JAX version (formatted as a series of digits and dots).
263+ # Group 2: Represents the build date (an 8-digit string typically in YYYYMMDD format).
264+ jax_versions_dates = parse_version_date (
265+ _JAX_PROJECT_URL , r'jax-([\d.]+)\.dev(\d{8})-py3-none-any\.whl' )
266+ if not jax_versions_dates :
267+ logger .error (f"Could not fetch JAX packages from { _JAX_PROJECT_URL } " )
268+ return None
167269
168- # Find lines like
169- # <a href=...>jax/jax-0.6.1.dev20250428-py3-none-any.whl</a>
170- jax_build = find_latest_nightly (
171- html_lines , r'.*<a href=.*?>jax/jax-(.*?)\.dev(\d{8})-(.*)\.whl</a>' )
172- if not jax_build :
173- logger .error (
174- f'Could not find latest jax nightly build in { _JAX_BUILDS_URL } .' )
270+ # Fetch jaxlib libraries
271+ #
272+ # Look for patterns like: jaxlib-0.6.1.dev20250428-cp310-cp310-manylinux2014_x86_64.whl
273+ # Group 1: Represents the jaxlib version (formatted as a series of digits and dots).
274+ # Group 2: Represents the build date (an 8-digit string typically in YYYYMMDD format).
275+ jaxlib_versions_dates = parse_version_date (
276+ _JAXLIB_PROJECT_URL , r'jaxlib-([\d.]+)\.dev(\d{8})-.*\.whl' )
277+ if not jaxlib_versions_dates :
278+ logger .error (f"Could not fetch jaxlib packages from { _JAXLIB_PROJECT_URL } " )
175279 return None
176280
177- # Find lines like
178- # <a href=...>nocuda/jaxlib-0.6.1.dev20250428-....whl</a>
179- jaxlib_build = find_latest_nightly (
180- html_lines ,
181- r'.*<a href=.*?>nocuda/jaxlib-(.*?)\.dev(\d{8})-(.*)\.whl</a>' )
182- if not jaxlib_build :
281+ latest_jax_version = ''
282+ latest_jax_date = ''
283+ for version , date in jax_versions_dates :
284+ if date > latest_jax_date :
285+ latest_jax_version = version
286+ latest_jax_date = date
287+
288+ if not latest_jax_version :
183289 logger .error (
184- f'Could not find latest jaxlib nightly build in { _JAX_BUILDS_URL } .' )
290+ f'Could not find any JAX nightly builds. Tried parsing { _JAX_PROJECT_URL } '
291+ )
185292 return None
186293
187- jax_version , jax_date , _ = jax_build
188- jaxlib_version , jaxlib_date , _ = jaxlib_build
189- if jax_date != jaxlib_date :
294+ latest_jaxlib_version = ''
295+ for version , date in jaxlib_versions_dates :
296+ # Only consider jaxlib builds from the same date as JAX
297+ if date == latest_jax_date and version > latest_jaxlib_version :
298+ latest_jaxlib_version = version
299+
300+ if not latest_jaxlib_version :
190301 logger .error (
191- f'The latest jax date { jax_date } != the latest jaxlib date { jaxlib_date } in { _JAX_BUILDS_URL } . '
302+ f'Could not find jaxlib nightly build for date { latest_jax_date } . Tried parsing { _JAXLIB_PROJECT_URL } '
192303 )
193304 return None
194305
195- return jax_version , jaxlib_version , jax_date
306+ logger .info (
307+ f'Found JAX { latest_jax_version } and jaxlib { latest_jaxlib_version } from { latest_jax_date } '
308+ )
309+ return latest_jax_version , latest_jaxlib_version , latest_jax_date
196310
197311
198312def update_libtpu () -> bool :
199313 """Updates the libtpu version in setup.py to the latest nightly build.
200-
314+
201315 Returns:
202316 True if the setup.py file was updated, False otherwise.
203317 """
@@ -248,7 +362,7 @@ def update_libtpu() -> bool:
248362
249363def update_jax () -> bool :
250364 """Updates the jax/jaxlib versions in setup.py to the latest nightly build.
251-
365+
252366 Returns:
253367 True if the setup.py file was updated, False otherwise.
254368 """
0 commit comments