diff --git a/README.md b/README.md index b63cd36..d540d25 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,28 @@ pip install -e ".[plotting]" pip install -e ".[training]" ``` +## Running on Apple Silicon (macOS) + +TRIBE v2 works on Apple Silicon Macs (M1–M4). The model runs on CPU for inference (MPS support is limited for some operations). A few notes: + +- **whisperx** (used for audio transcription) does not support MPS, so it automatically falls back to CPU with `int8` compute. +- **Multiprocessing**: macOS defaults to the `spawn` start method, which conflicts with PyTorch DataLoaders. Use `fork` and wrap your code in `if __name__ == "__main__"`: + +```python +import multiprocessing +from tribev2 import TribeModel + +if __name__ == "__main__": + multiprocessing.set_start_method("fork") + + model = TribeModel.from_pretrained("facebook/tribev2", cache_folder="./cache", device="cpu") + df = model.get_events_dataframe(video_path="path/to/video.mp4") + preds, segments = model.predict(events=df) + print(preds.shape) # (n_timesteps, n_vertices) +``` + +- **First run** is slow (feature extraction for LLaMA 3.2, V-JEPA2, Wav2Vec-BERT). Extracted features are cached in `cache_folder`, so subsequent runs on the same video are near-instant. + ## Training a model from scratch ### 1. 
Set environment variables diff --git a/tribev2/demo_utils.py b/tribev2/demo_utils.py index fa735a7..5e4a10a 100644 --- a/tribev2/demo_utils.py +++ b/tribev2/demo_utils.py @@ -220,6 +220,12 @@ def from_pretrained( config["cache_folder"] = ( str(cache_folder) if cache_folder is not None else "./cache" ) + if device in ("cpu", "mps"): # mps not supported by neuralset extractors + # Override all extractor devices to cpu when cuda is unavailable + for modality in ["text", "audio"]: + config[f"data.{modality}_feature.device"] = "cpu" + config["data.image_feature.image.device"] = "cpu" + config["data.video_feature.image.device"] = "cpu" if config_update is not None: config.update(config_update) xp = cls(**config) diff --git a/tribev2/eventstransforms.py b/tribev2/eventstransforms.py index 1fd0c69..d2f7fde 100644 --- a/tribev2/eventstransforms.py +++ b/tribev2/eventstransforms.py @@ -104,8 +104,13 @@ def _get_transcript_from_audio(wav_filename: Path, language: str) -> pd.DataFram if language not in language_codes: raise ValueError(f"Language {language} not supported") - device = "cuda" if torch.cuda.is_available() else "cpu" - compute_type = "float16" + # whisperx (ctranslate2) only supports cuda or cpu + if torch.cuda.is_available(): + device = "cuda" + compute_type = "float16" + else: + device = "cpu" + compute_type = "int8" with tempfile.TemporaryDirectory() as output_dir: logger.info("Running whisperx via uvx...")