Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
225 changes: 34 additions & 191 deletions scripts/install_venv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -505,210 +505,53 @@ echo -e "${GREEN}✓ pyannote.audio 4.0.1 installed${NC}"
echo ""

# ==============================================================================
# Step 7: SpeechBrain Compatibility Patches
# Step 7: Upgrade SpeechBrain for torchaudio 2.9.x compatibility
# ==============================================================================
echo -e "${YELLOW}[7/10] Applying SpeechBrain compatibility patches...${NC}"
echo -e "${YELLOW}[7/10] Upgrading SpeechBrain for torchaudio 2.9.x compatibility...${NC}"
echo "SpeechBrain >= 1.0.3 includes native fixes for:"
echo " - list_audio_backends() removed in torchaudio 2.9.x (fixes #38)"
echo " - torchaudio.info() removed in torchaudio 2.9.x (fixes #39)"

SPEECHBRAIN_BACKEND="$SITE_PACKAGES/speechbrain/utils/torch_audio_backend.py"
pip install --upgrade "speechbrain>=1.0.3"

if [ ! -f "$SPEECHBRAIN_BACKEND" ]; then
echo -e "${RED}ERROR: SpeechBrain torch_audio_backend.py not found${NC}"
exit 1
fi

cat > /tmp/speechbrain_patch.py << 'PATCH_EOF'
import sys

with open(sys.argv[1], 'r') as f:
content = f.read()

if 'hasattr(torchaudio, \'list_audio_backends\')' in content:
print("Already patched")
sys.exit(0)

original = """ elif torchaudio_major >= 2 and torchaudio_minor >= 1:
available_backends = torchaudio.list_audio_backends()

if len(available_backends) == 0:
logger.warning(
"SpeechBrain could not find any working torchaudio backend. Audio files may fail to load. Follow this link for instructions and troubleshooting: https://speechbrain.readthedocs.io/en/latest/audioloading.html"
)"""

replacement = """ elif torchaudio_major >= 2 and torchaudio_minor >= 1:
# list_audio_backends() is not available in torchaudio 2.9.1
if hasattr(torchaudio, 'list_audio_backends'):
available_backends = torchaudio.list_audio_backends()
if len(available_backends) == 0:
logger.warning(
"SpeechBrain could not find any working torchaudio backend. Audio files may fail to load. Follow this link for instructions and troubleshooting: https://speechbrain.readthedocs.io/en/latest/audioloading.html"
)
else:
# Newer torchaudio versions don't have list_audio_backends()
logger.info("Using torchaudio with default audio backend")"""

if original in content:
content = content.replace(original, replacement)
with open(sys.argv[1], 'w') as f:
f.write(content)
print("Patch applied successfully")
else:
if 'hasattr(torchaudio, \'list_audio_backends\')' in content:
print("Already patched")
else:
print("ERROR: Could not find pattern to patch")
sys.exit(1)
PATCH_EOF

python3 /tmp/speechbrain_patch.py "$SPEECHBRAIN_BACKEND"
rm -f /tmp/speechbrain_patch.py
# Verify SpeechBrain version
SB_VERSION=$(python3 -c "import speechbrain; print(speechbrain.__version__)")
echo "Installed SpeechBrain: $SB_VERSION"

# SpeechBrain dataio.py patch
echo "Patching SpeechBrain dataio.py for torchaudio 2.9.x..."
# Verify the upstream fixes are present
SITE_PACKAGES_SB=$(python3 -c "import speechbrain, os; print(os.path.dirname(speechbrain.__file__))")

SPEECHBRAIN_DATAIO="$SITE_PACKAGES/speechbrain/dataio/dataio.py"

if [ ! -f "$SPEECHBRAIN_DATAIO" ]; then
echo -e "${RED}ERROR: SpeechBrain dataio.py not found${NC}"
exit 1
fi

cat > /tmp/speechbrain_dataio_patch.py << 'PATCH_EOF'
if python3 -c "
import sys

filepath = sys.argv[1]

with open(filepath, 'r') as f:
content = f.read()

if 'AudioMetaDataCompat' in content:
print("Already patched")
sys.exit(0)

compat_class = '''
# Compatibility shim for torchaudio 2.9.x which removed AudioMetaData
class AudioMetaDataCompat:
"""Compatibility class replacing torchaudio.backend.common.AudioMetaData."""
def __init__(self, sample_rate, num_frames, num_channels, bits_per_sample=16, encoding="PCM_S"):
self.sample_rate = sample_rate
self.num_frames = num_frames
self.num_channels = num_channels
self.bits_per_sample = bits_per_sample
self.encoding = encoding

'''

marker = "\ndef read_audio_info("
if marker not in content:
print("ERROR: Could not find read_audio_info function")
sb_backend = open('$SITE_PACKAGES_SB/utils/torch_audio_backend.py').read()
if 'list_audio_backends' not in sb_backend and 'available_backends' not in sb_backend:
print('ERROR: SpeechBrain backend file missing audio backend handling')
sys.exit(1)

content = content.replace(marker, compat_class + marker)

old_func_start = "def read_audio_info(\n path, backend=None\n) -> \"torchaudio.backend.common.AudioMetaData\":"
new_func_start = "def read_audio_info(\n path, backend=None\n) -> AudioMetaDataCompat:"

if old_func_start in content:
content = content.replace(old_func_start, new_func_start)

old_body = ''' validate_backend(backend)

_path_no_ext, path_ext = os.path.splitext(path)

if path_ext == ".mp3":
# Additionally, certain affected versions of torchaudio fail to
# autodetect mp3.
# HACK: here, we check for the file extension to force mp3 detection,
# which prevents an error from occurring in torchaudio.
info = torchaudio.info(path, format="mp3", backend=backend)
else:
info = torchaudio.info(path, backend=backend)

# Certain file formats, such as MP3, do not provide a reliable way to
# query file duration from metadata (when there is any).
# For MP3, certain versions of torchaudio began returning num_frames == 0.
#
# https://github.com/speechbrain/speechbrain/issues/1925
# https://github.com/pytorch/audio/issues/2524
#
# Accommodate for these cases here: if `num_frames == 0` then maybe something
# has gone wrong.
# If some file really had `num_frames == 0` then we are not doing harm
# double-checking anyway. If I am wrong and you are reading this comment
# because of it: sorry
if info.num_frames == 0:
channels_data, sample_rate = torchaudio.load(
path, normalize=False, backend=backend
)

info.num_frames = channels_data.size(1)
info.sample_rate = sample_rate # because we might as well

return info'''

new_body = ''' # torchaudio 2.9.x compatibility: use torchaudio.load() instead of removed torchaudio.info()
if hasattr(torchaudio, 'info'):
# Old torchaudio version - use original approach
validate_backend(backend)
_path_no_ext, path_ext = os.path.splitext(path)
if path_ext == ".mp3":
info = torchaudio.info(path, format="mp3", backend=backend)
else:
info = torchaudio.info(path, backend=backend)
if info.num_frames == 0:
channels_data, sample_rate = torchaudio.load(path, normalize=False, backend=backend)
info.num_frames = channels_data.size(1)
info.sample_rate = sample_rate
return info
else:
# torchaudio 2.9.x: info() removed, use load() to get metadata
# Note: backend parameter is ignored in torchaudio 2.9.x
channels_data, sample_rate = torchaudio.load(path, normalize=False)
return AudioMetaDataCompat(
sample_rate=sample_rate,
num_frames=channels_data.size(1),
num_channels=channels_data.size(0),
)'''

if old_body in content:
content = content.replace(old_body, new_body)
with open(filepath, 'w') as f:
f.write(content)
print("Patch applied successfully")
else:
with open(filepath, 'w') as f:
f.write(content)
print("Partial patch applied (compat class added)")
PATCH_EOF

python3 /tmp/speechbrain_dataio_patch.py "$SPEECHBRAIN_DATAIO"
rm -f /tmp/speechbrain_dataio_patch.py

echo -e "${GREEN}✓ SpeechBrain patches applied${NC}"
sb_dataio = open('$SITE_PACKAGES_SB/dataio/dataio.py').read()
if 'read_audio_info' not in sb_dataio:
print('ERROR: SpeechBrain dataio.py missing read_audio_info')
sys.exit(1)
print('Upstream fixes verified')
"; then
echo -e "${GREEN}✓ SpeechBrain upgraded — no monkey-patches needed${NC}"
else
echo -e "${RED}ERROR: SpeechBrain upgrade verification failed${NC}"
exit 1
fi
echo ""

# ==============================================================================
# Step 8: Lightning Patch for PyTorch 2.6+ weights_only
# Step 8: Upgrade Lightning for PyTorch 2.6+ weights_only compatibility
# ==============================================================================
echo -e "${YELLOW}[8/10] Applying Lightning patch for PyTorch 2.6+...${NC}"
echo -e "${YELLOW}[8/10] Upgrading Lightning for PyTorch 2.6+ compatibility...${NC}"
echo "Lightning >= 2.4.0 handles weights_only=True default in PyTorch 2.6+ (fixes #40)"

LIGHTNING_CLOUD_IO="$SITE_PACKAGES/lightning/fabric/utilities/cloud_io.py"

if [ ! -f "$LIGHTNING_CLOUD_IO" ]; then
echo -e "${RED}ERROR: Lightning cloud_io.py not found${NC}"
exit 1
fi
pip install --upgrade "lightning>=2.4.0"

if grep -q "PyTorch 2.6+ compatibility patch" "$LIGHTNING_CLOUD_IO"; then
echo "Already patched"
else
sed -i 's/fs.open(path_or_url, "rb") as f:/fs.open(path_or_url, "rb") as f:\n if weights_only is None:\n weights_only = False # PyTorch 2.6+ compatibility patch/' "$LIGHTNING_CLOUD_IO"
if grep -q "PyTorch 2.6+ compatibility patch" "$LIGHTNING_CLOUD_IO"; then
echo -e "${GREEN}✓ Lightning patch applied${NC}"
else
echo -e "${RED}ERROR: Lightning patch verification failed${NC}"
exit 1
fi
fi
# Verify Lightning version
LN_VERSION=$(python3 -c "import lightning; print(lightning.__version__)")
echo "Installed Lightning: $LN_VERSION"
echo -e "${GREEN}✓ Lightning upgraded — no monkey-patches needed${NC}"
echo ""

# ==============================================================================
Expand Down