22 changes: 22 additions & 0 deletions README.md
@@ -49,6 +49,28 @@ pip install -e ".[plotting]"
pip install -e ".[training]"
```

## Running on Apple Silicon (macOS)

TRIBE v2 works on Apple Silicon Macs (M1–M4). The model runs on CPU for inference (MPS support is limited for some operations). A few notes:

- **whisperx** (used for audio transcription) does not support MPS, so it automatically falls back to CPU with `int8` compute.
- **Multiprocessing**: macOS defaults to the `spawn` start method, which re-imports your script in every PyTorch DataLoader worker and can crash it. Switch to `fork` and guard your entry point with `if __name__ == "__main__":`:

```python
import multiprocessing
from tribev2 import TribeModel

if __name__ == "__main__":
multiprocessing.set_start_method("fork")

model = TribeModel.from_pretrained("facebook/tribev2", cache_folder="./cache", device="cpu")
df = model.get_events_dataframe(video_path="path/to/video.mp4")
preds, segments = model.predict(events=df)
print(preds.shape) # (n_timesteps, n_vertices)
```

- **First run** is slow (feature extraction for LLaMA 3.2, V-JEPA2, Wav2Vec-BERT). Extracted features are cached in `cache_folder`, so subsequent runs on the same video are near-instant.
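
The caching described above is a load-or-compute pattern. Here is a minimal stdlib sketch of the idea (the `load_or_extract` helper is hypothetical, for illustration only; it is not TRIBE v2's actual API):

```python
import hashlib
import pickle
from pathlib import Path

def load_or_extract(video_path: str, cache_folder: str = "./cache") -> dict:
    """Cache-on-first-run sketch: the expensive extraction happens once
    per video, and later calls load the cached result instead."""
    cache = Path(cache_folder)
    cache.mkdir(parents=True, exist_ok=True)
    # Key the cache entry on the video path
    key = hashlib.sha256(video_path.encode()).hexdigest()[:16]
    cached = cache / f"{key}.pkl"
    if cached.exists():
        with cached.open("rb") as f:
            return pickle.load(f)  # fast path: reuse cached features
    # Stand-in for the slow feature extraction step
    features = {"video": video_path, "dim": 768}
    with cached.open("wb") as f:
        pickle.dump(features, f)
    return features
```

The second call for the same `video_path` and `cache_folder` skips the extraction step entirely, which is why only the first run on a video is slow.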

## Training a model from scratch

### 1. Set environment variables
6 changes: 6 additions & 0 deletions tribev2/demo_utils.py
@@ -220,6 +220,12 @@ def from_pretrained(
config["cache_folder"] = (
str(cache_folder) if cache_folder is not None else "./cache"
)
if device in ("cpu", "mps"): # mps not supported by neuralset extractors
# Override all extractor devices to cpu when cuda is unavailable
for modality in ["text", "audio"]:
config[f"data.{modality}_feature.device"] = "cpu"
config["data.image_feature.image.device"] = "cpu"
config["data.video_feature.image.device"] = "cpu"
if config_update is not None:
config.update(config_update)
xp = cls(**config)
9 changes: 7 additions & 2 deletions tribev2/eventstransforms.py
@@ -104,8 +104,13 @@ def _get_transcript_from_audio(wav_filename: Path, language: str) -> pd.DataFram
if language not in language_codes:
raise ValueError(f"Language {language} not supported")

-device = "cuda" if torch.cuda.is_available() else "cpu"
-compute_type = "float16"
+# whisperx (ctranslate2) only supports cuda or cpu
+if torch.cuda.is_available():
+    device = "cuda"
+    compute_type = "float16"
+else:
+    device = "cpu"
+    compute_type = "int8"

with tempfile.TemporaryDirectory() as output_dir:
logger.info("Running whisperx via uvx...")