diff --git a/tiktoken/load.py b/tiktoken/load.py index 295deb9f..bd5ad4a5 100644 --- a/tiktoken/load.py +++ b/tiktoken/load.py @@ -3,10 +3,24 @@ import base64 import hashlib import os +import io +import urllib.parse def read_file(blobpath: str) -> bytes: - if not blobpath.startswith("http://") and not blobpath.startswith("https://"): + url = urllib.parse.urlparse(blobpath) + if url.scheme is None or url.scheme == "": + with open(blobpath, "rb") as f: + with io.BufferedReader(f) as br: + return br.read() + elif url.scheme in ["http", "https"]: + # avoiding blobfile for public files helps avoid auth issues, like MFA prompts + import requests + + resp = requests.get(blobpath) + resp.raise_for_status() + return resp.content + else: try: import blobfile except ImportError as e: @@ -16,13 +30,6 @@ def read_file(blobpath: str) -> bytes: with blobfile.BlobFile(blobpath, "rb") as f: return f.read() - # avoiding blobfile for public files helps avoid auth issues, like MFA prompts - import requests - - resp = requests.get(blobpath) - resp.raise_for_status() - return resp.content - def check_hash(data: bytes, expected_hash: str) -> bool: actual_hash = hashlib.sha256(data).hexdigest()