From 79a22ade8285825a6bea92d059841c6d8c079989 Mon Sep 17 00:00:00 2001 From: alangnt Date: Sun, 29 Mar 2026 13:19:12 +0200 Subject: [PATCH 1/4] feat: cuda if cuda available otherwise cpu + propagate device to feature extractors --- tribev2/demo_utils.py | 18 +++++++++++++++++- tribev2/eventstransforms.py | 9 +++++++-- 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/tribev2/demo_utils.py b/tribev2/demo_utils.py index fa735a7..4b5d09c 100644 --- a/tribev2/demo_utils.py +++ b/tribev2/demo_utils.py @@ -190,7 +190,12 @@ def from_pretrained( if cache_folder is not None: Path(cache_folder).mkdir(parents=True, exist_ok=True) if device == "auto": - device = "cuda" if torch.cuda.is_available() else "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + else: + device = "cpu" checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir.exists(): config_path = checkpoint_dir / "config.yaml" @@ -238,6 +243,17 @@ def from_pretrained( model.to(device) model.eval() xp._model = model + # Propagate device to feature extractors (and nested sub-extractors) + for modality in xp.data.features_to_use: + extractor = getattr(xp.data, f"{modality}_feature", None) + if extractor is not None: + if hasattr(extractor, "device"): + extractor.device = device + # Handle nested extractors (e.g. 
HuggingFaceVideo.image) + for field in extractor.model_fields: + sub = getattr(extractor, field, None) + if hasattr(sub, "device"): + sub.device = device return xp def get_events_dataframe( diff --git a/tribev2/eventstransforms.py b/tribev2/eventstransforms.py index 1fd0c69..d2f7fde 100644 --- a/tribev2/eventstransforms.py +++ b/tribev2/eventstransforms.py @@ -104,8 +104,13 @@ def _get_transcript_from_audio(wav_filename: Path, language: str) -> pd.DataFram if language not in language_codes: raise ValueError(f"Language {language} not supported") - device = "cuda" if torch.cuda.is_available() else "cpu" - compute_type = "float16" + # whisperx (ctranslate2) only supports cuda or cpu + if torch.cuda.is_available(): + device = "cuda" + compute_type = "float16" + else: + device = "cpu" + compute_type = "int8" with tempfile.TemporaryDirectory() as output_dir: logger.info("Running whisperx via uvx...") From 083c66c4fafa725e9a5da53496314a319c5e2107 Mon Sep 17 00:00:00 2001 From: alangnt Date: Sun, 29 Mar 2026 13:19:38 +0200 Subject: [PATCH 2/4] chore: updated the README file to explain how to run it properly on Apple Silicon --- README.md | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/README.md b/README.md index b63cd36..d540d25 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,28 @@ pip install -e ".[plotting]" pip install -e ".[training]" ``` +## Running on Apple Silicon (macOS) + +TRIBE v2 works on Apple Silicon Macs (M1–M4). The model runs on CPU for inference (MPS support is limited for some operations). A few notes: + +- **whisperx** (used for audio transcription) does not support MPS so it automatically falls back to CPU with `int8` compute. +- **Multiprocessing**: macOS defaults to the `spawn` start method, which conflicts with PyTorch DataLoaders. 
Use `fork` and wrap your code in `if __name__ == "__main__"`: +```python +import multiprocessing +from tribev2 import TribeModel + +if __name__ == "__main__": + multiprocessing.set_start_method("fork") + + model = TribeModel.from_pretrained("facebook/tribev2", cache_folder="./cache", device="cpu") + df = model.get_events_dataframe(video_path="path/to/video.mp4") + preds, segments = model.predict(events=df) + print(preds.shape) # (n_timesteps, n_vertices) +``` + +- **First run** is slow (feature extraction for LLaMA 3.2, V-JEPA2, Wav2Vec-BERT). Extracted features are cached in `cache_folder`, so subsequent runs on the same video are near-instant. + ## Training a model from scratch ### 1. Set environment variables From 1dc000811b65f2b3805aed16ece100cfbedc7dff Mon Sep 17 00:00:00 2001 From: alangnt Date: Sun, 29 Mar 2026 13:31:28 +0200 Subject: [PATCH 3/4] fix: unnecessary mps check --- tribev2/demo_utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tribev2/demo_utils.py b/tribev2/demo_utils.py index 4b5d09c..8034aaf 100644 --- a/tribev2/demo_utils.py +++ b/tribev2/demo_utils.py @@ -190,12 +190,7 @@ def from_pretrained( if cache_folder is not None: Path(cache_folder).mkdir(parents=True, exist_ok=True) if device == "auto": - if torch.cuda.is_available(): - device = "cuda" - elif torch.backends.mps.is_available(): - device = "mps" - else: - device = "cpu" + device = "cuda" if torch.cuda.is_available() else "cpu" checkpoint_dir = Path(checkpoint_dir) if checkpoint_dir.exists(): config_path = checkpoint_dir / "config.yaml" From d0d524b96f542bcf36e798a999509aca5beddf96 Mon Sep 17 00:00:00 2001 From: alangnt Date: Sun, 29 Mar 2026 13:59:57 +0200 Subject: [PATCH 4/4] fix: set extractor devices via config instead of post-hoc mutation --- tribev2/demo_utils.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/tribev2/demo_utils.py b/tribev2/demo_utils.py index 8034aaf..5e4a10a 100644 --- a/tribev2/demo_utils.py 
+++ b/tribev2/demo_utils.py @@ -220,6 +220,12 @@ def from_pretrained( config["cache_folder"] = ( str(cache_folder) if cache_folder is not None else "./cache" ) + if device in ("cpu", "mps"): # mps not supported by neuralset extractors + # Override all extractor devices to cpu when cuda is unavailable + for modality in ["text", "audio"]: + config[f"data.{modality}_feature.device"] = "cpu" + config["data.image_feature.image.device"] = "cpu" + config["data.video_feature.image.device"] = "cpu" if config_update is not None: config.update(config_update) xp = cls(**config) @@ -238,17 +244,6 @@ def from_pretrained( model.to(device) model.eval() xp._model = model - # Propagate device to feature extractors (and nested sub-extractors) - for modality in xp.data.features_to_use: - extractor = getattr(xp.data, f"{modality}_feature", None) - if extractor is not None: - if hasattr(extractor, "device"): - extractor.device = device - # Handle nested extractors (e.g. HuggingFaceVideo.image) - for field in extractor.model_fields: - sub = getattr(extractor, field, None) - if hasattr(sub, "device"): - sub.device = device return xp def get_events_dataframe(