diff --git a/.gitignore b/.gitignore index a987ae0..386a443 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ # StemForge .gitignore # ============================================================================= +# GSD planning (local-only, not committed) +.planning/ + # ----------------------------------------------------------------------------- # Python # ----------------------------------------------------------------------------- diff --git a/backend/api/midi.py b/backend/api/midi.py index cea82a6..5b13dc3 100644 --- a/backend/api/midi.py +++ b/backend/api/midi.py @@ -14,6 +14,7 @@ from backend.services.job_manager import job_manager from backend.services.session_store import SessionStore, TrackState, get_user_session from backend.services import pipeline_manager +from models.registry import DrumMidiSpec, list_specs from utils.paths import MIDI_DIR router = APIRouter(prefix="/api/midi", tags=["midi"]) @@ -86,6 +87,14 @@ "Drums & percussion": True, } + +def _build_adt_model_list() -> list[dict]: + """Return a list of ADT model metadata dicts from the model registry.""" + return [ + {"model_id": s.model_id, "display_name": s.display_name, "tooltip": s.description} + for s in list_specs(DrumMidiSpec) + ] + # ─── SoundFont discovery & state ───────────────────────────────────────── _SF2_SEARCH_PATHS = [ @@ -119,6 +128,7 @@ class ExtractRequest(BaseModel): onset_threshold: float = 0.5 frame_threshold: float = 0.3 min_note_ms: float = 58.0 + adt_model: str = "adtof-drums" class RenderRequest(BaseModel): @@ -346,11 +356,12 @@ def get_midi_stems(session: SessionStore = Depends(get_user_session)) -> dict: @router.get("/gm-programs") def get_gm_programs() -> dict: - """Return list of all 128 GM program names and smart defaults per stem label.""" + """Return list of all 128 GM program names, smart defaults per stem label, and ADT models.""" return { "programs": GM_PROGRAMS, "defaults": STEM_DEFAULT_PROGRAM, "drum_stems": STEM_IS_DRUM, + "adt_models": _build_adt_model_list(), } diff --git a/frontend/components/midi.js b/frontend/components/midi.js index 94ef5a6..21239e5 100644 --- a/frontend/components/midi.js +++ b/frontend/components/midi.js @@ -15,6 +15,7 @@ function clearChildren(elem) { let gmPrograms = []; let stemDefaults = {}; let drumStems = {}; +let adtModels = []; /** LilyPond availability (checked on init). */ let _lilypondAvailable = false; @@ -45,6 +46,14 @@ export function initMidi() { ), ); + const adtGroup = el('div', { className: 'form-group hidden', id: 'midi-adt-group' }, + el('label', {}, 'ADT Model'), + el('select', { id: 'midi-adt-model' }), + el('p', { className: 'text-dim', style: { fontSize: '12px', margin: '4px 0 0' } }, + 'Best results with acoustic drums. Electronic/programmed drums may have lower accuracy.', + ), + ); + const keyGroup = el('div', { className: 'form-group' }, el('label', {}, 'Key'), el('select', { id: 'midi-key' }, @@ -104,7 +113,7 @@ export function initMidi() { const importInput = el('input', { type: 'file', accept: '.mid,.midi', style: { display: 'none' }, id: 'midi-import-input' }); const importBtn = el('button', { className: 'btn btn-sm', id: 'midi-import' }, 'Import MIDI file'); - left.append(stemSection, keyGroup, bpmGroup, tsGroup, onsetGroup, frameGroup, sf2Group, extractBtn, importInput, importBtn); + left.append(stemSection, adtGroup, keyGroup, bpmGroup, tsGroup, onsetGroup, frameGroup, sf2Group, extractBtn, importInput, importBtn); // ─── Right: results ─── const right = el('div', { className: 'col-right' }); @@ -158,6 +167,8 @@ export function initMidi() { document.getElementById('midi-start').disabled = false; }); + document.getElementById('midi-stems').addEventListener('change', syncAdtGroupVisibility); + // Load GM programs, current soundfont, and check LilyPond on init loadGmPrograms(); loadCurrentSoundfont(); @@ -170,6 +181,14 @@ async function loadGmPrograms() { gmPrograms = data.programs || []; stemDefaults = data.defaults || {}; drumStems = data.drum_stems || {}; + adtModels = data.adt_models || []; + const adtSel = document.getElementById('midi-adt-model'); + if (adtSel) { + clearChildren(adtSel); + for (const m of adtModels) { + adtSel.appendChild(el('option', { value: m.model_id, title: m.tooltip || '' }, m.display_name)); + } + } } catch { /* fail silently, will use defaults */ } } @@ -244,6 +263,15 @@ function populateStemCheckboxes(stemPaths) { ), ); } + syncAdtGroupVisibility(); +} + +function syncAdtGroupVisibility() { + const hasDrum = Array.from( + document.querySelectorAll('#midi-stems input[type="checkbox"]:checked') + ).some(cb => isDrumStem(cb.value)); + const adtGroup = document.getElementById('midi-adt-group'); + if (adtGroup) adtGroup.classList.toggle('hidden', !hasDrum); } async function startExtraction() { @@ -268,6 +296,7 @@ async function startExtraction() { time_signature: document.getElementById('midi-ts').value, onset_threshold: parseFloat(document.getElementById('midi-onset').value), frame_threshold: parseFloat(document.getElementById('midi-frame').value), + adt_model: document.getElementById('midi-adt-model')?.value || 'adtof-drums', }), }); diff --git a/models/midi_loader.py b/models/midi_loader.py index 112bcd7..254a533 100644 --- a/models/midi_loader.py +++ b/models/midi_loader.py @@ -56,6 +56,7 @@ def __init__(self) -> None: self._bp_loader = BasicPitchModelLoader() self._model: Any | None = None # BasicPitch TF model self._whisper_model: Any | None = None # faster-whisper WhisperModel + self._adtof_backend: Any | None = None # AdtofBackend (lazy-loaded) # ------------------------------------------------------------------ # Lifecycle @@ -93,12 +94,41 @@ def is_loaded(self) -> bool: return self._model is not None def evict(self) -> None: - """Release both models from memory and trigger GC.""" + """Release all models from memory and trigger GC.""" self._bp_loader.evict() self._model = None self._whisper_model = None + self.evict_drum_model() log.debug("MidiModelLoader: models evicted.") + # ------------------------------------------------------------------ + # Internal: ADTOF lazy loader + # ------------------------------------------------------------------ + + def _ensure_adtof(self) -> Any: + """Load the ADTOF drum transcription backend on first use.""" + if self._adtof_backend is not None: + return self._adtof_backend + try: + from pipelines.adtof_backend import AdtofBackend + except ImportError as exc: + raise ModelLoadError( + "ADTOF backend is not available.", + model_name="adtof-drums", + ) from exc + backend = AdtofBackend() + backend.load() + self._adtof_backend = backend + log.info("MidiModelLoader: ADTOF drum backend ready.") + return self._adtof_backend + + def evict_drum_model(self) -> None: + """Evict the ADTOF backend only, leaving BasicPitch intact.""" + if self._adtof_backend is not None: + self._adtof_backend.evict() + self._adtof_backend = None + log.debug("MidiModelLoader: ADTOF backend evicted.") + # ------------------------------------------------------------------ # Internal: Whisper lazy loader # ------------------------------------------------------------------ @@ -379,3 +409,47 @@ def convert_vocal_to_midi( path.name, len(events), len(lyrics), len(segments), ) return events, lyrics + + def convert_drum_to_midi( + self, + path: pathlib.Path, + *, + duration: float = 0.0, + ) -> list[NoteEvent]: + """Transcribe a drum stem to GM note events via ADTOF. + + Parameters + ---------- + path: + Drum audio file (must be 44100 Hz). + duration: + If positive, clip events to this many seconds. + + Returns + ------- + list[NoteEvent] + GM channel-10 note events from ADTOF_5CLASS_GM_NOTE. + + Raises + ------ + :class:`~utils.errors.ModelLoadError` + If the ADTOF backend cannot be loaded. + :class:`~utils.errors.PipelineExecutionError` + If drum transcription fails. + """ + backend = self._ensure_adtof() + try: + events = backend.predict(path) + except PipelineExecutionError: + raise + except Exception as exc: + raise PipelineExecutionError( + f"Drum transcription failed for '{path.name}': {exc}", + pipeline_name="midi", + ) from exc + + if duration > 0.0: + events = [(s, min(e, duration), p, v) for s, e, p, v in events if s < duration] + + log.debug("convert_drum_to_midi: %s -> %d drum events", path.name, len(events)) + return events diff --git a/models/registry.py b/models/registry.py index 973d847..8b2df0d 100644 --- a/models/registry.py +++ b/models/registry.py @@ -231,6 +231,26 @@ class StableAudioSpec(ModelSpec): conditioning_keys: tuple[str, ...] +@dataclass(frozen=True, slots=True) +class DrumMidiSpec(ModelSpec): + """Descriptor for a drum transcription (ADT) model. + + Additional fields + ----------------- + class_count: + Number of drum output classes this model produces. + class_labels: + Human-readable label for each class, in model output order. + checkpoint_url: + Direct download URL for model weights; empty string means weights + are bundled in the pip package. + """ + + class_count: int = 0 + class_labels: tuple[str, ...] = () + checkpoint_url: str = "" + + @dataclass(frozen=True, slots=True) class RoformerSpec(ModelSpec): """Descriptor for a BS-Roformer / MelBand-Roformer separation model. @@ -742,6 +762,38 @@ def _register(spec: ModelSpec) -> ModelSpec: default_num_overlap=2, )) +# --------------------------------------------------------------------------- +# Drum transcription (ADT) models +# --------------------------------------------------------------------------- + +ADTOF_DRUMS = _register(DrumMidiSpec( + model_id="adtof-drums", + display_name="ADTOF Drums (5-class)", + version="0.1.0", + source="xavriley/ADTOF-pytorch", + device="auto", + gpu_capable=True, + device_fallback="cpu", + device_quirks="", + sample_rate=44_100, + hop_size=0, + chunk_size=0, + max_duration_seconds=0.0, + default_bpm=0.0, + default_key="", + default_time_signature="", + quantize_grid="none", + default_min_note_ms=0.0, + capabilities=frozenset({"transcribe", "drum_transcription", "gpu_acceleration"}), + cache_subdir="adtof", + description="Automatic drum transcription — 5-class (kick, snare, tom, hi-hat, cymbal).", + preprocessing="Mono 44.1 kHz; mel spectrogram; context window of 9 frames.", + postprocessing="Peak picking per class; GM note mapping; 60 ms fixed note duration.", + class_count=5, + class_labels=("kick", "snare", "tom", "hi_hat", "cymbal"), + checkpoint_url="", +)) + # --------------------------------------------------------------------------- # Convenience constants # --------------------------------------------------------------------------- @@ -854,6 +906,9 @@ def get_loader_kwargs(model_id: str) -> dict[str, Any]: "config_url": spec.config_url, } + if isinstance(spec, DrumMidiSpec): + return {"cache_dir": spec.cache_dir} + raise NotImplementedError( f"No loader kwargs defined for spec type {type(spec).__name__!r}." ) @@ -925,6 +980,9 @@ def get_pipeline_defaults(model_id: str) -> dict[str, Any]: "num_overlap": spec.default_num_overlap, } + if isinstance(spec, DrumMidiSpec): + return {"class_count": spec.class_count} + raise NotImplementedError( f"No pipeline defaults defined for spec type {type(spec).__name__!r}." ) @@ -1004,6 +1062,14 @@ def get_gui_metadata(model_id: str) -> dict[str, Any]: "target_instrument": spec.target_instrument, } + if isinstance(spec, DrumMidiSpec): + return { + "class_count": spec.class_count, + "class_labels": list(spec.class_labels), + "model_choices": [s.model_id for s in list_specs(DrumMidiSpec)], + "tooltip": spec.description, + } + raise NotImplementedError( f"No GUI metadata defined for spec type {type(spec).__name__!r}." ) diff --git a/pipelines/adtof_backend.py b/pipelines/adtof_backend.py new file mode 100644 index 0000000..6421335 --- /dev/null +++ b/pipelines/adtof_backend.py @@ -0,0 +1,198 @@ +"""ADTOF-pytorch automatic drum transcription backend. + +Provides :class:`AdtofBackendProtocol` (structural interface for ADT backends) +and :class:`AdtofBackend` (ADTOF Frame-RNN implementation). + +Usage:: + + backend = AdtofBackend() + backend.load(device="cuda") + events = backend.predict(Path("drums.wav")) + backend.evict() +""" +from __future__ import annotations + +import logging +import pathlib +from typing import Protocol, runtime_checkable + +import soundfile as sf +import torch + +from utils.errors import InvalidInputError, ModelLoadError, PipelineExecutionError +from utils.midi_io import NoteEvent + +logger = logging.getLogger(__name__) + +NOTE_DURATION: float = 0.06 # 60ms — matches notes_to_midi() drum cap +DEFAULT_VELOCITY: int = 100 # Fixed velocity; amplitude-based is v2 (ECLASS-02) +_VALID_GM_NOTES: frozenset[int] = frozenset({35, 38, 42, 47, 49}) + + +def _peaks_to_note_events( + peaks: dict[int, list[float]], + duration: float = NOTE_DURATION, + velocity: int = DEFAULT_VELOCITY, +) -> list[NoteEvent]: + """Convert PeakPicker output to sorted NoteEvent list. + + Parameters + ---------- + peaks: + ``{gm_note: [onset_time_sec, ...]}`` from ``PeakPicker.pick()``. + duration: + Fixed note duration in seconds. + velocity: + Fixed MIDI velocity (1-127). + + Returns + ------- + list[NoteEvent] + Sorted by onset time ascending. + """ + events: list[NoteEvent] = [] + for gm_note, times in peaks.items(): + gm_note_int = int(gm_note) + assert gm_note_int in _VALID_GM_NOTES, ( + f"Unexpected GM note {gm_note_int} from PeakPicker; " + f"expected one of {sorted(_VALID_GM_NOTES)}" + ) + for onset in times: + events.append((onset, onset + duration, gm_note_int, velocity)) + events.sort(key=lambda e: e[0]) + return events + + +@runtime_checkable +class AdtofBackendProtocol(Protocol): + """Structural interface for automatic drum transcription backends. + + Implementations are not required to inherit from this class. Any class + with matching ``load``, ``predict``, and ``evict`` method signatures will + satisfy ``isinstance(obj, AdtofBackendProtocol)`` checks at runtime. + """ + + def load(self, device: str = "cpu") -> None: + """Load model weights into memory on the given device.""" + ... + + def predict(self, audio_path: pathlib.Path) -> list[NoteEvent]: + """Transcribe drum audio; return NoteEvent tuples, GM notes only.""" + ... + + def evict(self) -> None: + """Release model weights from memory.""" + ... + + +class AdtofBackend: + """ADTOF Frame-RNN drum transcription backend. + + Implements :class:`AdtofBackendProtocol` via structural subtyping + (no inheritance required). + + Example:: + + backend = AdtofBackend() + backend.load(device="cuda") + events = backend.predict(Path("drums_stem.wav")) + backend.evict() + """ + + def __init__(self) -> None: + self._model: torch.nn.Module | None = None + self._device: str = "cpu" + + def load(self, device: str = "cpu") -> None: + """Load ADTOF Frame-RNN weights into memory. + + Imports adtof_pytorch functions at call time to preserve lazy loading. + Wraps any exception from weight loading in :class:`ModelLoadError`. + + Parameters + ---------- + device: + Torch device string, e.g. ``"cpu"``, ``"cuda"``, ``"cuda:0"``. + """ + try: + import adtof_pytorch + + n_bins = adtof_pytorch.calculate_n_bins() + model = adtof_pytorch.create_frame_rnn_model(n_bins) + model.eval() + weights_path = adtof_pytorch.get_default_weights_path() + model = adtof_pytorch.load_pytorch_weights(model, str(weights_path), strict=False) + model.to(device) + self._model = model + self._device = device + logger.info("ADTOF model loaded on %s", device) + except Exception as exc: + raise ModelLoadError(str(exc), model_name="adtof-drums") from exc + + def predict(self, audio_path: pathlib.Path) -> list[NoteEvent]: + """Transcribe drum audio and return NoteEvent tuples. + + Raises + ------ + InvalidInputError + If the audio sample rate is not 44100 Hz. + PipelineExecutionError + If the model forward pass or onset detection fails. + RuntimeError + If called before ``load()``. + """ + # --- Sample rate guard (ADT-02) --- + info = sf.info(str(audio_path)) + if info.samplerate != 44100: + raise InvalidInputError( + f"ADTOF requires 44100 Hz audio; got {info.samplerate} Hz: {audio_path}", + field="audio_path", + ) + + if self._model is None: + raise RuntimeError("Model not loaded — call load() first") + + try: + from adtof_pytorch import ( + load_audio_for_model, + PeakPicker, + FRAME_RNN_THRESHOLDS, + LABELS_5, + ) + + # --- Forward pass (ADT-01: zero disk writes) --- + x = load_audio_for_model(str(audio_path)) + x = x.to(self._device) + with torch.no_grad(): + pred = self._model(x).cpu().numpy() + + # --- Onset detection (ADT-04) --- + picker = PeakPicker(thresholds=FRAME_RNN_THRESHOLDS, fps=100) + picked = picker.pick(pred, labels=LABELS_5, label_offset=0)[0] + + # --- Convert to NoteEvent --- + events = _peaks_to_note_events(picked) + logger.info("ADTOF transcription: %d note events", len(events)) + return events + + except InvalidInputError: + raise + except Exception as exc: + raise PipelineExecutionError( + f"ADTOF prediction failed: {exc}", + pipeline_name="adtof_backend", + ) from exc + + def evict(self) -> None: + """Release model weights from GPU/CPU memory. + + Moves the model to CPU before clearing the reference so PyTorch can + free GPU memory. Calls ``torch.cuda.empty_cache()`` when the model + was loaded on a CUDA device. + """ + if self._model is not None: + self._model.cpu() + self._model = None + if self._device.startswith("cuda"): + torch.cuda.empty_cache() + logger.info("ADTOF model evicted") diff --git a/pipelines/midi_pipeline.py b/pipelines/midi_pipeline.py index 8558aa7..af6d6d7 100644 --- a/pipelines/midi_pipeline.py +++ b/pipelines/midi_pipeline.py @@ -53,6 +53,14 @@ "vocals", "Singing voice", }) +# Stem labels that represent drum / percussion tracks — routed to ADTOF +# drum transcription instead of BasicPitch. +# Must match STEM_IS_DRUM keys in backend/api/midi.py. +_DRUM_STEM_LABELS: frozenset[str] = frozenset({ + "drums", # Demucs htdemucs / mdx_extra output label + "Drums & percussion", # BS-Roformer jarredou-6stem / zfturbo-4stem output label +}) + # --------------------------------------------------------------------------- # Configuration @@ -319,6 +327,16 @@ def run(self, stems: dict[str, pathlib.Path]) -> "MidiResult": ) if lyrics: track_lyrics[label] = lyrics + + elif label in _DRUM_STEM_LABELS: + log.info("MidiPipeline: routing '%s' to drum ADT path.", label) + self._report(base_pct + 2.0) # stage 1: before loading + notes = self._loader.convert_drum_to_midi( + path, duration=cfg.duration_seconds, + ) + self._report(base_pct + 10.0) # stage 2: after loading/predict + self._report(base_pct + (1.0 / total) * 70.0) # stage 3: done + else: notes = self._loader.convert_audio_to_midi( path, @@ -335,10 +353,19 @@ def run(self, stems: dict[str, pathlib.Path]) -> "MidiResult": # Build per-stem MIDI object in memory (not written to disk yet) stem_lyrics = track_lyrics.get(label) - stem_midi_data[label] = self._build_stem_midi(label, notes, cfg, stem_lyrics) + stem_midi_data[label] = self._build_stem_midi( + label, notes, cfg, stem_lyrics, + is_drum=(label in _DRUM_STEM_LABELS), + ) self._report(base_pct + (1.0 / total) * 70.0) + # Evict ADTOF backend after processing all drum stems to free GPU memory. + # AdtofBackend.evict() handles torch.cuda.empty_cache() internally — + # no backend/api/midi.py modification needed. + if any(label in _DRUM_STEM_LABELS for label in stems): + self._loader.evict_drum_model() + else: # Text-only: generate a diatonic chord progression self._report(20.0) @@ -418,10 +445,12 @@ def _build_stem_midi( notes: list[NoteEvent], cfg: MidiConfig, lyrics: list[LyricEvent] | None = None, + is_drum: bool = False, ) -> Any: """Build and return a PrettyMIDI object for *stem_name* (no disk write).""" return notes_to_midi( - notes, ticks_per_beat=_TICKS_PER_BEAT, tempo_bpm=cfg.bpm, lyrics=lyrics + notes, ticks_per_beat=_TICKS_PER_BEAT, tempo_bpm=cfg.bpm, + lyrics=lyrics, is_drum=is_drum, ) def write_stem_midi( diff --git a/pyproject.toml b/pyproject.toml index b4efa29..a64cbe3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,8 @@ dependencies = [ # GitHub-pinned ML components (renamed distributions) "demucs @ git+https://github.com/facebookresearch/demucs.git@v4.0.1", - #"basic-pitch @ git+https://github.com/spotify/basic-pitch.git", + "adtof-pytorch @ git+https://github.com/xavriley/ADTOF-pytorch@85c192e", + #"basic-pitch @ git+https://github.com/spotify/basic-pitch.git", # Demucs deps "dora-search==0.1.12", diff --git a/tests/test_adtof_backend.py b/tests/test_adtof_backend.py new file mode 100644 index 0000000..25550d0 --- /dev/null +++ b/tests/test_adtof_backend.py @@ -0,0 +1,349 @@ +"""Tests for AdtofBackendProtocol and AdtofBackend load/evict lifecycle. + +RED phase: All tests expected to fail until pipelines/adtof_backend.py is implemented. +""" +from __future__ import annotations + +import pathlib +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Protocol structural checks +# --------------------------------------------------------------------------- + + +def test_protocol_structural_check() -> None: + """A duck-typed class with load/predict/evict passes isinstance check.""" + from pipelines.adtof_backend import AdtofBackendProtocol + + class DummyBackend: + def load(self, device: str = "cpu") -> None: + ... + + def predict(self, audio_path: pathlib.Path) -> list: + return [] + + def evict(self) -> None: + ... + + assert isinstance(DummyBackend(), AdtofBackendProtocol) is True + + +def test_protocol_rejects_incomplete() -> None: + """A class missing predict() fails isinstance check against AdtofBackendProtocol.""" + from pipelines.adtof_backend import AdtofBackendProtocol + + class IncompleteBackend: + def load(self, device: str = "cpu") -> None: + ... + + def evict(self) -> None: + ... + + assert isinstance(IncompleteBackend(), AdtofBackendProtocol) is False + + +# --------------------------------------------------------------------------- +# Load / evict lifecycle +# --------------------------------------------------------------------------- + +_MOCK_PATCHES = [ + "adtof_pytorch.calculate_n_bins", + "adtof_pytorch.create_frame_rnn_model", + "adtof_pytorch.get_default_weights_path", + "adtof_pytorch.load_pytorch_weights", +] + + +def _make_mock_model() -> MagicMock: + mock_model = MagicMock() + mock_model.eval.return_value = mock_model + mock_model.to.return_value = mock_model + mock_model.cpu.return_value = mock_model + return mock_model + + +def test_load_creates_model() -> None: + """After load(), self._model is not None.""" + from pipelines.adtof_backend import AdtofBackend + + mock_model = _make_mock_model() + + with ( + patch("adtof_pytorch.calculate_n_bins", return_value=168), + patch("adtof_pytorch.create_frame_rnn_model", return_value=mock_model), + patch("adtof_pytorch.get_default_weights_path", return_value=pathlib.Path("/fake/weights.pth")), + patch("adtof_pytorch.load_pytorch_weights", return_value=mock_model), + ): + backend = AdtofBackend() + backend.load(device="cpu") + + assert backend._model is not None + + +def test_load_sets_eval_mode() -> None: + """After load(), model.eval() was called.""" + from pipelines.adtof_backend import AdtofBackend + + mock_model = _make_mock_model() + + with ( + patch("adtof_pytorch.calculate_n_bins", return_value=168), + patch("adtof_pytorch.create_frame_rnn_model", return_value=mock_model), + patch("adtof_pytorch.get_default_weights_path", return_value=pathlib.Path("/fake/weights.pth")), + patch("adtof_pytorch.load_pytorch_weights", return_value=mock_model), + ): + backend = AdtofBackend() + backend.load(device="cpu") + + assert mock_model.eval.called is True + + +def test_evict_clears_model() -> None: + """After load() then evict(), self._model is None.""" + from pipelines.adtof_backend import AdtofBackend + + mock_model = _make_mock_model() + + with ( + patch("adtof_pytorch.calculate_n_bins", return_value=168), + patch("adtof_pytorch.create_frame_rnn_model", return_value=mock_model), + patch("adtof_pytorch.get_default_weights_path", return_value=pathlib.Path("/fake/weights.pth")), + patch("adtof_pytorch.load_pytorch_weights", return_value=mock_model), + ): + backend = AdtofBackend() + backend.load(device="cpu") + backend.evict() + + assert backend._model is None + + +def test_load_after_evict() -> None: + """After evict(), calling load() again creates a new model. + + create_frame_rnn_model should be called twice (once per load call). + """ + from pipelines.adtof_backend import AdtofBackend + + mock_model = _make_mock_model() + mock_create = MagicMock(return_value=mock_model) + + with ( + patch("adtof_pytorch.calculate_n_bins", return_value=168), + patch("adtof_pytorch.create_frame_rnn_model", mock_create), + patch("adtof_pytorch.get_default_weights_path", return_value=pathlib.Path("/fake/weights.pth")), + patch("adtof_pytorch.load_pytorch_weights", return_value=mock_model), + ): + backend = AdtofBackend() + backend.load(device="cpu") + backend.evict() + backend.load(device="cpu") + + assert backend._model is not None + assert mock_create.call_count == 2 + + +def test_load_wraps_errors() -> None: + """If create_frame_rnn_model raises, load() wraps it in ModelLoadError.""" + from pipelines.adtof_backend import AdtofBackend + from utils.errors import ModelLoadError + + with ( + patch("adtof_pytorch.calculate_n_bins", return_value=168), + patch("adtof_pytorch.create_frame_rnn_model", side_effect=RuntimeError("bad weights")), + patch("adtof_pytorch.get_default_weights_path", return_value=pathlib.Path("/fake/weights.pth")), + ): + backend = AdtofBackend() + with pytest.raises(ModelLoadError): + backend.load(device="cpu") + + +# --------------------------------------------------------------------------- +# predict() — sample rate guard, NoteEvent conversion, correctness +# --------------------------------------------------------------------------- + +import builtins +import numpy as np +import torch + + +def _make_predict_backend() -> "AdtofBackend": + """Return an AdtofBackend with _model and _device pre-set (skip load()).""" + from pipelines.adtof_backend import AdtofBackend + + mock_model = MagicMock() + # forward pass returns (1, 10, 5) sigmoid activations — all 0.9 + mock_model.return_value = torch.full((1, 10, 5), 0.9) + backend = AdtofBackend() + backend._model = mock_model + backend._device = "cpu" + return backend + + +def _make_sf_info(samplerate: int = 44100) -> MagicMock: + info = MagicMock() + info.samplerate = samplerate + return info + + +def _make_audio_tensor() -> torch.Tensor: + return torch.zeros(1, 10, 168, 1) + + +def _make_peak_picker_mock(peaks: dict) -> MagicMock: + """Return a mock PeakPicker class whose instance .pick() returns [peaks].""" + mock_picker_instance = MagicMock() + mock_picker_instance.pick.return_value = [peaks] + mock_picker_cls = MagicMock(return_value=mock_picker_instance) + return mock_picker_cls + + +_KNOWN_PEAKS = {35: [0.5, 1.2], 38: [0.3], 47: [0.8], 42: [1.0, 1.5], 49: []} + + +def test_predict_rejects_wrong_sample_rate() -> None: + """predict() raises InvalidInputError when audio is not 44100 Hz.""" + from utils.errors import InvalidInputError + + backend = _make_predict_backend() + with patch("soundfile.info", return_value=_make_sf_info(22050)): + with pytest.raises(InvalidInputError) as exc_info: + backend.predict(pathlib.Path("/fake/drums.wav")) + assert "44100" in str(exc_info.value) + + +def test_predict_no_disk_writes(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch) -> None: + """predict() with mocked model does not create any files.""" + backend = _make_predict_backend() + monkeypatch.chdir(tmp_path) + + # Track open() calls — raise if any write-mode open occurs + real_open = builtins.open + + def guarded_open(file, mode="r", *args, **kwargs): + if "w" in str(mode) or "x" in str(mode): + raise AssertionError(f"Unexpected write-mode open({file!r}, {mode!r})") + return real_open(file, mode, *args, **kwargs) + + mock_picker_cls = _make_peak_picker_mock(_KNOWN_PEAKS) + + with ( + patch("soundfile.info", return_value=_make_sf_info(44100)), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", mock_picker_cls), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + patch("builtins.open", side_effect=guarded_open), + ): + backend.predict(pathlib.Path("/fake/drums.wav")) + + assert list(tmp_path.iterdir()) == [] + + +def test_note_events_gm_notes_only() -> None: + """All NoteEvent pitches are exclusively in {35, 38, 42, 47, 49}.""" + backend = _make_predict_backend() + mock_picker_cls = _make_peak_picker_mock(_KNOWN_PEAKS) + + with ( + patch("soundfile.info", return_value=_make_sf_info()), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", mock_picker_cls), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + ): + events = backend.predict(pathlib.Path("/fake/drums.wav")) + + assert len(events) > 0 + assert all(e[2] in {35, 38, 42, 47, 49} for e in events) + + +def test_non_sequential_label_mapping() -> None: + """Tom (GM 47) and hi-hat (GM 42) are mapped correctly (non-sequential).""" + backend = _make_predict_backend() + # Only tom=47 at t=1.0 and hi-hat=42 at t=2.0 + peaks = {47: [1.0], 42: [2.0]} + mock_picker_cls = _make_peak_picker_mock(peaks) + + with ( + patch("soundfile.info", return_value=_make_sf_info()), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", mock_picker_cls), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + ): + events = backend.predict(pathlib.Path("/fake/drums.wav")) + + assert len(events) == 2 + tom_event = next(e for e in events if abs(e[0] - 1.0) < 1e-9) + hihat_event = next(e for e in events if abs(e[0] - 2.0) < 1e-9) + assert tom_event[2] == 47, f"Expected GM 47 (tom) at t=1.0, got {tom_event[2]}" + assert hihat_event[2] == 42, f"Expected GM 42 (hi-hat) at t=2.0, got {hihat_event[2]}" + + +def test_note_event_duration_and_velocity() -> None: + """Each NoteEvent has 60ms duration and velocity 100.""" + backend = _make_predict_backend() + mock_picker_cls = _make_peak_picker_mock({35: [0.5], 38: [1.0]}) + + with ( + patch("soundfile.info", return_value=_make_sf_info()), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", mock_picker_cls), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + ): + events = backend.predict(pathlib.Path("/fake/drums.wav")) + + assert len(events) == 2 + for e in events: + assert e[3] == 100, f"Expected velocity=100, got {e[3]}" + assert abs((e[1] - e[0]) - 0.06) < 1e-9, f"Expected duration=0.06s, got {e[1]-e[0]}" + + +def test_note_events_sorted_by_onset() -> None: + """Output NoteEvents are sorted by start time ascending.""" + backend = _make_predict_backend() + # Out-of-order across classes: kick at t=2.0, snare at t=0.5, hihat at t=1.0 + mock_picker_cls = _make_peak_picker_mock({35: [2.0], 38: [0.5], 42: [1.0]}) + + with ( + patch("soundfile.info", return_value=_make_sf_info()), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", mock_picker_cls), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + ): + events = backend.predict(pathlib.Path("/fake/drums.wav")) + + onsets = [e[0] for e in events] + assert onsets == sorted(onsets), f"Events not sorted by onset: {onsets}" + + +def test_predict_raises_pipeline_error() -> None: + """If model forward pass raises, predict() wraps it in PipelineExecutionError.""" + from utils.errors import PipelineExecutionError + + backend = _make_predict_backend() + backend._model = MagicMock(side_effect=RuntimeError("CUDA OOM")) + + with ( + patch("soundfile.info", return_value=_make_sf_info()), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", _make_peak_picker_mock({})), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + ): + with pytest.raises(PipelineExecutionError): + backend.predict(pathlib.Path("/fake/drums.wav")) + + +def test_peaks_to_note_events_empty() -> None: + """_peaks_to_note_events({}) returns empty list.""" + from pipelines.adtof_backend import _peaks_to_note_events + + result = _peaks_to_note_events({}) + assert result == [] diff --git a/tests/test_drum_map.py b/tests/test_drum_map.py new file mode 100644 index 0000000..5dc91e4 --- /dev/null +++ b/tests/test_drum_map.py @@ -0,0 +1,58 @@ +"""Unit tests for utils.drum_map — GM drum constants and gm_note(). + +These tests are written FIRST (TDD RED phase) against the not-yet-existing +utils/drum_map module. All tests should FAIL with ImportError until +utils/drum_map.py is created (Task 2). +""" +from __future__ import annotations + + +def test_adtof_drum_class_enum() -> None: + """AdtofDrumClass enum has exactly 5 members with correct integer values.""" + from utils.drum_map import AdtofDrumClass + + assert AdtofDrumClass.KICK == 0 + assert AdtofDrumClass.SNARE == 1 + assert AdtofDrumClass.TOM == 2 + assert AdtofDrumClass.HI_HAT == 3 + assert AdtofDrumClass.CYMBAL == 4 + assert len(AdtofDrumClass) == 5 + + +def test_drum_map_adtof_5class() -> None: + """ADTOF_5CLASS_GM_NOTE maps class indices to exact GM note numbers.""" + from utils.drum_map import ADTOF_5CLASS_GM_NOTE + + assert ADTOF_5CLASS_GM_NOTE == {0: 35, 1: 38, 2: 47, 3: 42, 4: 49} + + +def test_gm_note_non_sequential() -> None: + """gm_note() preserves the non-sequential ordering: tom(47) > hi-hat(42).""" + from utils.drum_map import AdtofDrumClass, gm_note + + # Tom (index 2) maps to GM note 47. + assert gm_note(AdtofDrumClass.TOM) == 47 + # Hi-hat (index 3) maps to GM note 42 — numerically lower than tom despite higher index. + assert gm_note(AdtofDrumClass.HI_HAT) == 42 + + +def test_gm_note_all_classes() -> None: + """gm_note() returns the correct GM note number for every AdtofDrumClass member.""" + from utils.drum_map import AdtofDrumClass, gm_note + + assert gm_note(AdtofDrumClass.KICK) == 35 + assert gm_note(AdtofDrumClass.SNARE) == 38 + assert gm_note(AdtofDrumClass.TOM) == 47 + assert gm_note(AdtofDrumClass.HI_HAT) == 42 + assert gm_note(AdtofDrumClass.CYMBAL) == 49 + + +def test_gm_drum_names() -> None: + """GM_DRUM_NAMES maps GM note numbers to human-readable percussion names.""" + from utils.drum_map import GM_DRUM_NAMES + + assert GM_DRUM_NAMES[35] == "Acoustic Bass Drum" + assert GM_DRUM_NAMES[38] == "Acoustic Snare" + assert GM_DRUM_NAMES[47] == "Mid Tom" + assert GM_DRUM_NAMES[42] == "Closed Hi-Hat" + assert GM_DRUM_NAMES[49] == "Crash Cymbal 1" diff --git a/tests/test_drum_midi_integration.py b/tests/test_drum_midi_integration.py new file mode 100644 index 0000000..dd8fb6f --- /dev/null +++ b/tests/test_drum_midi_integration.py @@ -0,0 +1,134 @@ +"""Integration tests for drum MIDI channel routing and API endpoint. + +Tests: + - test_drum_stem_writes_channel_10: drum stem -> is_drum=True, survives round-trip + - test_non_drum_stem_writes_channel_1: non-drum stem -> is_drum=False, survives round-trip + - test_gm_programs_includes_adt_models: /api/midi/gm-programs returns adt_models list + +RED phase notes: + - test_drum_stem_writes_channel_10 and test_non_drum_stem_writes_channel_1 should + already pass (channel routing is in place from Phase 3). + - test_gm_programs_includes_adt_models MUST fail until backend/api/midi.py is + extended to include the adt_models key. +""" +from __future__ import annotations + +import pathlib +import tempfile + +import pretty_midi +import pytest +from unittest.mock import MagicMock + +from pipelines.midi_pipeline import MidiPipeline, MidiConfig + + +# --------------------------------------------------------------------------- +# Test infrastructure — reused from test_midi_pipeline_routing.py pattern +# --------------------------------------------------------------------------- + + +def _make_pipeline_with_mock_loader(): + """Create a MidiPipeline with a mock loader, bypassing load_model().""" + pipeline = MidiPipeline() + pipeline._config = MidiConfig() + mock_loader = MagicMock() + # convert_drum_to_midi returns one kick note event + mock_loader.convert_drum_to_midi.return_value = [(0.1, 0.16, 35, 100)] + # convert_audio_to_midi returns one pitched note + mock_loader.convert_audio_to_midi.return_value = [(0.1, 0.4, 60, 80)] + # convert_vocal_to_midi returns (notes, lyrics) + mock_loader.convert_vocal_to_midi.return_value = ([(0.1, 0.5, 60, 80)], []) + pipeline._loader = mock_loader + pipeline.is_loaded = True + return pipeline, mock_loader + + +# --------------------------------------------------------------------------- +# Channel routing integration tests (with round-trip serialization) +# --------------------------------------------------------------------------- + + +def test_drum_stem_writes_channel_10(tmp_path: pathlib.Path) -> None: + """Drum stem produces is_drum=True in-memory and after MIDI file round-trip.""" + pipeline, _ = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums.wav" + tmp_file.write_bytes(b"\x00") + + result = pipeline.run({"drums": tmp_file}) + + midi_obj = result.stem_midi_data["drums"] + assert len(midi_obj.instruments) > 0, "MIDI object should have at least one instrument" + + # In-memory assertion + assert midi_obj.instruments[0].is_drum is True, ( + "Drum stem MIDI should have is_drum=True on instrument[0] in memory" + ) + + # Round-trip: write to .mid file, read back, verify is_drum survives + mid_path = tmp_path / "drums_roundtrip.mid" + midi_obj.write(str(mid_path)) + reloaded = pretty_midi.PrettyMIDI(str(mid_path)) + assert len(reloaded.instruments) > 0, "Round-tripped MIDI should have at least one instrument" + assert reloaded.instruments[0].is_drum is True, ( + "Drum stem MIDI should have is_drum=True after MIDI file round-trip" + ) + + +def test_non_drum_stem_writes_channel_1(tmp_path: pathlib.Path) -> None: + """Non-drum stem produces is_drum=False in-memory and after MIDI file round-trip.""" + pipeline, _ = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "bass.wav" + tmp_file.write_bytes(b"\x00") + + result = pipeline.run({"bass": tmp_file}) + + midi_obj = result.stem_midi_data["bass"] + assert len(midi_obj.instruments) > 0, "MIDI object should have at least one instrument" + + # In-memory assertion + assert midi_obj.instruments[0].is_drum is False, ( + "Non-drum stem MIDI should have is_drum=False on instrument[0] in memory" + ) + + # Round-trip: write to .mid file, read back, verify is_drum survives + mid_path = tmp_path / "bass_roundtrip.mid" + midi_obj.write(str(mid_path)) + reloaded = pretty_midi.PrettyMIDI(str(mid_path)) + assert len(reloaded.instruments) > 0, "Round-tripped MIDI should have at least one instrument" + assert reloaded.instruments[0].is_drum is False, ( + "Non-drum stem MIDI should have is_drum=False after MIDI file round-trip" + ) + + +# --------------------------------------------------------------------------- +# API endpoint test: /api/midi/gm-programs must include adt_models +# --------------------------------------------------------------------------- + + +def test_gm_programs_includes_adt_models() -> None: + """GET /api/midi/gm-programs returns 200 with adt_models key containing adtof-drums.""" + from fastapi.testclient import TestClient + from backend.main import app + + client = TestClient(app) + response = client.get("/api/midi/gm-programs") + + assert response.status_code == 200, ( + f"Expected 200 from /api/midi/gm-programs, got {response.status_code}" + ) + + data = response.json() + assert "adt_models" in data, ( + f"Response missing 'adt_models' key. Keys present: {list(data.keys())}" + ) + + adt_models = data["adt_models"] + assert isinstance(adt_models, list), ( + f"Expected adt_models to be a list, got {type(adt_models)}" + ) + + model_ids = [entry.get("model_id") for entry in adt_models] + assert "adtof-drums" in model_ids, ( + f"Expected 'adtof-drums' in adt_models list. Got: {model_ids}" + ) diff --git a/tests/test_drum_midi_spec.py b/tests/test_drum_midi_spec.py new file mode 100644 index 0000000..6224773 --- /dev/null +++ b/tests/test_drum_midi_spec.py @@ -0,0 +1,81 @@ +""" +Tests for DrumMidiSpec registry entry and ADTOF drum model registration. + +Phase 01 Plan 02 — RED phase. +""" +from __future__ import annotations + + +def test_drum_midi_spec_importable() -> None: + """DrumMidiSpec is importable from models.registry and is a ModelSpec subclass.""" + from models.registry import DrumMidiSpec, ModelSpec + + assert isinstance(DrumMidiSpec, type) + assert issubclass(DrumMidiSpec, ModelSpec) + + +def test_adtof_registry_entry() -> None: + """list_specs(DrumMidiSpec) returns an entry for 'adtof-drums' with correct fields.""" + from models.registry import DrumMidiSpec, list_specs + + specs = list_specs(DrumMidiSpec) + assert len(specs) >= 1 + + match = next((s for s in specs if s.model_id == "adtof-drums"), None) + assert match is not None, "No entry with model_id='adtof-drums' found" + + spec = match # type: ignore[assignment] + assert spec.class_count == 5 + assert spec.class_labels == ("kick", "snare", "tom", "hi_hat", "cymbal") + assert spec.sample_rate == 44_100 + + +def test_adtof_capabilities() -> None: + """ADTOF registry entry has the correct capability tags.""" + from models.registry import get_spec + + spec = get_spec("adtof-drums") + assert spec.capabilities == frozenset({"transcribe", "drum_transcription", "gpu_acceleration"}) + + +def test_adtof_cache_subdir() -> None: + """ADTOF registry entry uses cache_subdir='adtof'.""" + from models.registry import get_spec + + spec = get_spec("adtof-drums") + assert spec.cache_subdir == "adtof" + + +def test_adtof_checkpoint_url_empty() -> None: + """ADTOF checkpoint_url is empty (weights bundled in pip package).""" + from models.registry import get_spec + + spec = get_spec("adtof-drums") + assert spec.checkpoint_url == "" + + +def test_get_loader_kwargs_adtof() -> None: + """get_loader_kwargs('adtof-drums') returns a dict containing 'cache_dir' without raising.""" + from models.registry import get_loader_kwargs + + kw = get_loader_kwargs("adtof-drums") + assert isinstance(kw, dict) + assert "cache_dir" in kw + + +def test_get_pipeline_defaults_adtof() -> None: + """get_pipeline_defaults('adtof-drums') returns a dict containing 'class_count' without raising.""" + from models.registry import get_pipeline_defaults + + defaults = get_pipeline_defaults("adtof-drums") + assert isinstance(defaults, dict) + assert "class_count" in defaults + + +def test_get_gui_metadata_adtof() -> None: + """get_gui_metadata('adtof-drums') returns a dict containing 'tooltip' without raising.""" + from models.registry import get_gui_metadata + + meta = get_gui_metadata("adtof-drums") + assert isinstance(meta, dict) + assert "tooltip" in meta diff --git a/tests/test_midi_loader_drum.py b/tests/test_midi_loader_drum.py new file mode 100644 index 0000000..10e8351 --- /dev/null +++ b/tests/test_midi_loader_drum.py @@ -0,0 +1,161 @@ +"""Tests for MidiModelLoader ADTOF drum transcription extension. + +RED phase: Tests for _ensure_adtof(), convert_drum_to_midi(), evict_drum_model(). +All tests expected to fail until models/midi_loader.py is extended. +""" +from __future__ import annotations + +import pathlib +from unittest.mock import MagicMock, patch, sentinel + +import pytest + +from models.midi_loader import MidiModelLoader +from utils.errors import PipelineExecutionError + + +# --------------------------------------------------------------------------- +# Lazy loading +# --------------------------------------------------------------------------- + + +def test_adtof_lazy_not_loaded_at_init() -> None: + """_adtof_backend is None immediately after construction — ADTOF not loaded at init.""" + loader = MidiModelLoader() + assert loader._adtof_backend is None + + +def test_ensure_adtof_returns_backend() -> None: + """_ensure_adtof() returns an AdtofBackend instance after first call.""" + loader = MidiModelLoader() + + mock_backend_instance = MagicMock() + mock_backend_cls = MagicMock(return_value=mock_backend_instance) + + with patch("pipelines.adtof_backend.AdtofBackend", mock_backend_cls): + # Patch the import inside _ensure_adtof via the module path it imports from + with patch.dict("sys.modules", {}): + import pipelines.adtof_backend as _adtof_mod + orig_cls = getattr(_adtof_mod, "AdtofBackend", None) + _adtof_mod.AdtofBackend = mock_backend_cls # type: ignore[attr-defined] + try: + # Patch the deferred import inside the method + with patch("models.midi_loader.MidiModelLoader._ensure_adtof", wraps=None) as _mock: + pass + finally: + if orig_cls is not None: + _adtof_mod.AdtofBackend = orig_cls + + # Simpler approach: directly test by patching the deferred import + loader2 = MidiModelLoader() + mock_instance = MagicMock() + mock_cls = MagicMock(return_value=mock_instance) + + import sys + # Create a mock for pipelines.adtof_backend that has AdtofBackend + mock_adtof_module = MagicMock() + mock_adtof_module.AdtofBackend = mock_cls + with patch.dict(sys.modules, {"pipelines.adtof_backend": mock_adtof_module}): + result = loader2._ensure_adtof() + + assert result is mock_instance + mock_instance.load.assert_called_once() + + +def test_ensure_adtof_caches_instance() -> None: + """Second call to _ensure_adtof() returns the same object (cached).""" + loader = MidiModelLoader() + mock_instance = MagicMock() + mock_cls = MagicMock(return_value=mock_instance) + + import sys + mock_adtof_module = MagicMock() + mock_adtof_module.AdtofBackend = mock_cls + with patch.dict(sys.modules, {"pipelines.adtof_backend": mock_adtof_module}): + first = loader._ensure_adtof() + second = loader._ensure_adtof() + + assert first is second + # AdtofBackend() constructor called only once + assert mock_cls.call_count == 1 + + +# --------------------------------------------------------------------------- +# convert_drum_to_midi +# --------------------------------------------------------------------------- + + +def test_convert_drum_to_midi_calls_predict() -> None: + """convert_drum_to_midi(path) calls backend.predict(path) and returns its result.""" + loader = MidiModelLoader() + expected = [(0.1, 0.16, 35, 100)] + + mock_backend = MagicMock() + mock_backend.predict.return_value = expected + loader._adtof_backend = mock_backend + + result = loader.convert_drum_to_midi(pathlib.Path("dummy.wav")) + + mock_backend.predict.assert_called_once_with(pathlib.Path("dummy.wav")) + assert result == expected + + +def test_convert_drum_to_midi_wraps_exception() -> None: + """Non-PipelineExecutionError from predict() is wrapped in PipelineExecutionError.""" + loader = MidiModelLoader() + + mock_backend = MagicMock() + mock_backend.predict.side_effect = ValueError("boom") + loader._adtof_backend = mock_backend + + with pytest.raises(PipelineExecutionError) as exc_info: + loader.convert_drum_to_midi(pathlib.Path("dummy.wav")) + + assert "midi" in exc_info.value.pipeline_name # type: ignore[union-attr] + + +# --------------------------------------------------------------------------- +# evict +# --------------------------------------------------------------------------- + + +def test_evict_clears_adtof_backend() -> None: + """evict() sets _adtof_backend to None and calls backend.evict().""" + loader = MidiModelLoader() + mock_backend = MagicMock() + loader._adtof_backend = mock_backend + + loader.evict() + + assert loader._adtof_backend is None + mock_backend.evict.assert_called_once() + + +def test_evict_no_error_when_none() -> None: + """evict() with _adtof_backend=None does not raise any exception.""" + loader = MidiModelLoader() + assert loader._adtof_backend is None + + # Must not raise + loader.evict() + + +# --------------------------------------------------------------------------- +# evict_drum_model +# --------------------------------------------------------------------------- + + +def test_evict_drum_model_leaves_basicpitch() -> None: + """evict_drum_model() clears _adtof_backend only; _model (BasicPitch) is untouched.""" + loader = MidiModelLoader() + mock_backend = MagicMock() + loader._adtof_backend = mock_backend + # Simulate a loaded BasicPitch model + sentinel_bp = sentinel.basicpitch_model + loader._model = sentinel_bp # type: ignore[assignment] + + loader.evict_drum_model() + + assert loader._adtof_backend is None + mock_backend.evict.assert_called_once() + assert loader._model is sentinel_bp diff --git a/tests/test_midi_pipeline_routing.py b/tests/test_midi_pipeline_routing.py new file mode 100644 index 0000000..08be30b --- /dev/null +++ b/tests/test_midi_pipeline_routing.py @@ -0,0 +1,212 @@ +"""Tests for MidiPipeline drum stem routing branch. + +RED phase: Tests for _DRUM_STEM_LABELS, drum routing in run(), is_drum parameter +on _build_stem_midi(), 3-stage progress callbacks, and post-loop eviction. +All tests expected to fail until pipelines/midi_pipeline.py is extended. +""" +from __future__ import annotations + +import pathlib +from unittest.mock import MagicMock, call + +import pytest + +from pipelines.midi_pipeline import MidiPipeline, MidiConfig + + +# --------------------------------------------------------------------------- +# Test infrastructure +# --------------------------------------------------------------------------- + + +def _make_pipeline_with_mock_loader(): + """Create a MidiPipeline with a mock loader, bypassing load_model().""" + pipeline = MidiPipeline() + pipeline._config = MidiConfig() + mock_loader = MagicMock() + # convert_drum_to_midi returns one kick note event + mock_loader.convert_drum_to_midi.return_value = [(0.1, 0.16, 35, 100)] + # convert_audio_to_midi returns one pitched note + mock_loader.convert_audio_to_midi.return_value = [(0.1, 0.4, 60, 80)] + # convert_vocal_to_midi returns (notes, lyrics) + mock_loader.convert_vocal_to_midi.return_value = ([(0.1, 0.5, 60, 80)], []) + pipeline._loader = mock_loader + pipeline.is_loaded = True + return pipeline, mock_loader + + +# --------------------------------------------------------------------------- +# Routing tests +# --------------------------------------------------------------------------- + + +def test_drum_label_routes_to_adt(tmp_path: pathlib.Path) -> None: + """A stem labeled 'drums' calls convert_drum_to_midi(), not convert_audio_to_midi().""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"drums": tmp_file}) + + assert mock_loader.convert_drum_to_midi.called, ( + "convert_drum_to_midi() should be called for 'drums' label" + ) + assert not mock_loader.convert_audio_to_midi.called, ( + "convert_audio_to_midi() should NOT be called for 'drums' label" + ) + + +def test_roformer_drum_label_routes_to_adt(tmp_path: pathlib.Path) -> None: + """A stem labeled 'Drums & percussion' calls convert_drum_to_midi(), not convert_audio_to_midi().""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums_perc.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"Drums & percussion": tmp_file}) + + assert mock_loader.convert_drum_to_midi.called, ( + "convert_drum_to_midi() should be called for 'Drums & percussion' label" + ) + assert not mock_loader.convert_audio_to_midi.called, ( + "convert_audio_to_midi() should NOT be called for 'Drums & percussion' label" + ) + + +def test_vocal_label_not_routed_to_drum(tmp_path: pathlib.Path) -> None: + """A stem labeled 'vocals' calls convert_vocal_to_midi(), not convert_drum_to_midi().""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "vocals.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"vocals": tmp_file}) + + assert mock_loader.convert_vocal_to_midi.called, ( + "convert_vocal_to_midi() should be called for 'vocals' label" + ) + assert not mock_loader.convert_drum_to_midi.called, ( + "convert_drum_to_midi() should NOT be called for 'vocals' label" + ) + + +def test_bass_label_not_routed_to_drum(tmp_path: pathlib.Path) -> None: + """A stem labeled 'bass' calls convert_audio_to_midi(), not convert_drum_to_midi().""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "bass.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"bass": tmp_file}) + + assert mock_loader.convert_audio_to_midi.called, ( + "convert_audio_to_midi() should be called for 'bass' label" + ) + assert not mock_loader.convert_drum_to_midi.called, ( + "convert_drum_to_midi() should NOT be called for 'bass' label" + ) + + +# --------------------------------------------------------------------------- +# is_drum parameter tests +# --------------------------------------------------------------------------- + + +def test_drum_stem_midi_is_drum_true(tmp_path: pathlib.Path) -> None: + """Drum stem result has instruments[0].is_drum == True.""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums.wav" + tmp_file.write_bytes(b"\x00") + + result = pipeline.run({"drums": tmp_file}) + + midi_obj = result.stem_midi_data["drums"] + assert len(midi_obj.instruments) > 0, "MIDI object should have at least one instrument" + assert midi_obj.instruments[0].is_drum is True, ( + "Drum stem MIDI should have is_drum=True on instrument[0]" + ) + + +def test_non_drum_stem_midi_is_drum_false(tmp_path: pathlib.Path) -> None: + """Non-drum stem result has instruments[0].is_drum == False.""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "bass.wav" + tmp_file.write_bytes(b"\x00") + + result = pipeline.run({"bass": tmp_file}) + + midi_obj = result.stem_midi_data["bass"] + assert len(midi_obj.instruments) > 0, "MIDI object should have at least one instrument" + assert midi_obj.instruments[0].is_drum is False, ( + "Non-drum stem MIDI should have is_drum=False on instrument[0]" + ) + + +# --------------------------------------------------------------------------- +# 3-stage progress callback test +# --------------------------------------------------------------------------- + + +def test_drum_path_reports_progress_3_stages(tmp_path: pathlib.Path) -> None: + """_report() is called at least 3 times specifically during drum stem processing. + + Per RESEARCH.md Pitfall 6, the drum branch must report at 3 stages: + (1) before _ensure_adtof() / convert_drum_to_midi() — before loading + (2) after convert_drum_to_midi() returns — after loading/predict + (3) after predict (done) — final per-stem marker + """ + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums.wav" + tmp_file.write_bytes(b"\x00") + + progress_callback = MagicMock() + pipeline._progress_callback = progress_callback + + pipeline.run({"drums": tmp_file}) + + # Overall pipeline reports: initial 2.0, base_pct at start of stem, plus + # at least 3 drum-specific stages, plus 80.0 and 100.0 at the end. + # Total must be at least 5 (2 framing + 3 drum-specific). + assert progress_callback.call_count >= 5, ( + f"Expected at least 5 _report() calls during drum stem processing, " + f"got {progress_callback.call_count}" + ) + + # Extract all percentage values passed to the callback + reported_pcts = [c.args[0] for c in progress_callback.call_args_list] + + # The drum branch must fire 3 calls in quick succession inside the stem loop. + # All drum-branch calls occur after base_pct (5.0 for 1 stem) and before 80.0. + drum_branch_calls = [p for p in reported_pcts if 5.0 < p < 80.0] + assert len(drum_branch_calls) >= 3, ( + f"Expected at least 3 progress callbacks in the drum branch range (5.0 < pct < 80.0), " + f"got {len(drum_branch_calls)}: {drum_branch_calls}" + ) + + +# --------------------------------------------------------------------------- +# Eviction tests +# --------------------------------------------------------------------------- + + +def test_drum_evict_called_after_loop(tmp_path: pathlib.Path) -> None: + """evict_drum_model() is called after the stems loop when drum stems were present.""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"drums": tmp_file}) + + assert mock_loader.evict_drum_model.called, ( + "evict_drum_model() should be called after processing drum stems" + ) + + +def test_no_drum_evict_when_no_drums(tmp_path: pathlib.Path) -> None: + """evict_drum_model() is NOT called when no drum stems in request.""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "bass.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"bass": tmp_file}) + + assert not mock_loader.evict_drum_model.called, ( + "evict_drum_model() should NOT be called when no drum stems are present" + ) diff --git a/tests/test_notes_to_midi_drum.py b/tests/test_notes_to_midi_drum.py new file mode 100644 index 0000000..8d81fe9 --- /dev/null +++ b/tests/test_notes_to_midi_drum.py @@ -0,0 +1,69 @@ +"""Unit tests for is_drum parameter and 60ms cap in utils.midi_io.notes_to_midi(). + +Written FIRST (TDD RED phase). Tests for is_drum=True behavior will fail with +TypeError until notes_to_midi() is extended with the is_drum parameter (Task 2). +""" +from __future__ import annotations + +import pytest + + +def test_notes_to_midi_drum_channel() -> None: + """notes_to_midi(events, is_drum=True) produces an instrument routed to channel 10.""" + from utils.midi_io import notes_to_midi + + result = notes_to_midi([(0.5, 1.5, 35, 100)], is_drum=True) + + assert result.instruments[0].is_drum is True + assert result.instruments[0].name == "StemForge Drums" + + +def test_notes_to_midi_regression() -> None: + """notes_to_midi() without is_drum produces identical output to pre-change behavior.""" + from utils.midi_io import notes_to_midi + + # Default call (no is_drum arg) — must match original Acoustic Grand Piano behavior. + result_default = notes_to_midi([(0.5, 1.5, 60, 100)]) + assert result_default.instruments[0].is_drum is False + assert result_default.instruments[0].program == 0 + assert result_default.instruments[0].name == "StemForge" + + # Explicit is_drum=False — must be identical to the default. + result_explicit = notes_to_midi([(0.5, 1.5, 60, 100)], is_drum=False) + assert result_explicit.instruments[0].is_drum is False + assert result_explicit.instruments[0].program == 0 + assert result_explicit.instruments[0].name == "StemForge" + + +def test_drum_note_60ms_cap() -> None: + """Drum notes longer than 60ms are clamped to exactly 60ms duration.""" + from utils.midi_io import notes_to_midi + + # Note with start=1.0, end=2.0 (1000ms) should be clamped to end=1.06. + result = notes_to_midi([(1.0, 2.0, 35, 100)], is_drum=True) + + assert len(result.instruments[0].notes) == 1 + note = result.instruments[0].notes[0] + assert note.end == pytest.approx(1.06, abs=1e-6) + + +def test_drum_note_short_passthrough() -> None: + """Drum notes already shorter than 60ms are not modified.""" + from utils.midi_io import notes_to_midi + + # Note with start=1.0, end=1.03 (30ms) should pass through unchanged. + result = notes_to_midi([(1.0, 1.03, 35, 100)], is_drum=True) + + assert len(result.instruments[0].notes) == 1 + note = result.instruments[0].notes[0] + assert note.end == pytest.approx(1.03, abs=1e-6) + + +def test_drum_degenerate_note_filtered() -> None: + """Degenerate notes (end <= start) are filtered before the 60ms cap is applied.""" + from utils.midi_io import notes_to_midi + + # Note with start=1.0, end=0.5 — end < start, must be discarded entirely. + result = notes_to_midi([(1.0, 0.5, 35, 100)], is_drum=True) + + assert len(result.instruments[0].notes) == 0 diff --git a/utils/drum_map.py b/utils/drum_map.py new file mode 100644 index 0000000..bf5862e --- /dev/null +++ b/utils/drum_map.py @@ -0,0 +1,71 @@ +"""GM drum mapping constants for automatic drum transcription. + +Provides the canonical mapping between ADTOF model output class indices +and General MIDI percussion note numbers (channel 10). +""" +from __future__ import annotations + +from enum import IntEnum + + +class AdtofDrumClass(IntEnum): + """ADTOF 5-class drum output indices. + + Values match the model's output channel ordering from + ``xavriley/ADTOF-pytorch`` ``LABELS_5 = [35, 38, 47, 42, 49]``. + """ + KICK = 0 + SNARE = 1 + TOM = 2 + HI_HAT = 3 + CYMBAL = 4 + + +# Internal mapping: AdtofDrumClass -> GM note number. +# NON-SEQUENTIAL: index 2 (tom) maps to 47, index 3 (hi-hat) maps to 42. +# This ordering is preserved from ADTOF model training — do not sort numerically. +_ADTOF_GM: dict[AdtofDrumClass, int] = { + AdtofDrumClass.KICK: 35, # Acoustic Bass Drum + AdtofDrumClass.SNARE: 38, # Acoustic Snare + AdtofDrumClass.TOM: 47, # Mid Tom + AdtofDrumClass.HI_HAT: 42, # Closed Hi-Hat + AdtofDrumClass.CYMBAL: 49, # Crash Cymbal 1 +} + +# Flat {int_index: gm_note} dict for downstream consumers. +ADTOF_5CLASS_GM_NOTE: dict[int, int] = {int(k): v for k, v in _ADTOF_GM.items()} + +# Human-readable GM drum names for logging and Phase 4 UI. +GM_DRUM_NAMES: dict[int, str] = { + 35: "Acoustic Bass Drum", + 38: "Acoustic Snare", + 42: "Closed Hi-Hat", + 47: "Mid Tom", + 49: "Crash Cymbal 1", +} + +# --- 7-class expansion reference (ECLASS-01, v2) --- +# AdtofDrumClass would add: OPEN_HI_HAT = 5, RIDE = 6 +# GM notes: open hi-hat = 46, ride cymbal = 51 +# Crash/ride split requires retraining or post-hoc amplitude heuristics. + + +def gm_note(drum_class: AdtofDrumClass) -> int: + """Return the General MIDI note number for *drum_class*. + + Parameters + ---------- + drum_class: + An :class:`AdtofDrumClass` member. + + Returns + ------- + int + GM percussion note number (channel 10). + + Raises + ------ + KeyError + If *drum_class* is not a valid ADTOF class. + """ + return _ADTOF_GM[drum_class] diff --git a/utils/midi_io.py b/utils/midi_io.py index a0fcc7e..c40ec2c 100644 --- a/utils/midi_io.py +++ b/utils/midi_io.py @@ -86,6 +86,7 @@ def notes_to_midi( ticks_per_beat: int = 480, tempo_bpm: float = 120.0, lyrics: list[LyricEvent] | None = None, + is_drum: bool = False, ) -> MidiData: """Convert a list of note events to a MIDI data object. @@ -97,6 +98,13 @@ def notes_to_midi( MIDI ticks per quarter-note (resolution). tempo_bpm: Tempo in beats per minute for the generated MIDI file. + lyrics: + Optional list of ``(time_sec, word)`` lyric events to embed. + is_drum: + When ``True``, the instrument is routed to General MIDI channel 10 + (percussion channel) and note durations are clamped to 60 ms. + When ``False`` (default), behavior is identical to pre-change + behaviour: Acoustic Grand Piano on a melodic channel. Returns ------- @@ -107,14 +115,23 @@ def notes_to_midi( resolution=int(ticks_per_beat), initial_tempo=float(tempo_bpm), ) - instrument = pretty_midi.Instrument( - program=pretty_midi.instrument_name_to_program("Acoustic Grand Piano"), - name="StemForge", - ) + if is_drum: + instrument = pretty_midi.Instrument( + program=0, + is_drum=True, + name="StemForge Drums", + ) + else: + instrument = pretty_midi.Instrument( + program=pretty_midi.instrument_name_to_program("Acoustic Grand Piano"), + name="StemForge", + ) for start, end, pitch, velocity in note_events: # Guard against degenerate notes (zero or negative duration). if end <= start: continue + if is_drum: + end = min(end, start + 0.06) # 60 ms hard cap for percussion note = pretty_midi.Note( velocity=max(1, min(127, int(velocity))), pitch=max(0, min(127, int(pitch))), diff --git a/uv.lock b/uv.lock index ca5d5db..ef33897 100644 --- a/uv.lock +++ b/uv.lock @@ -125,6 +125,17 @@ requires-dist = [ [package.metadata.requires-dev] dev = [] +[[package]] +name = "adtof-pytorch" +version = "0.1.0" +source = { git = "https://github.com/xavriley/ADTOF-pytorch?rev=85c192e#85c192e78f716ea0b111cc8a5ee4a8f6a3a4f8a9" } +dependencies = [ + { name = "librosa", marker = "sys_platform == 'linux'" }, + { name = "numpy", marker = "sys_platform == 'linux'" }, + { name = "pretty-midi", marker = "sys_platform == 'linux'" }, + { name = "torch", marker = "sys_platform == 'linux'" }, +] + [[package]] name = "ai-edge-litert" version = "2.1.2" @@ -2474,6 +2485,7 @@ source = { editable = "." } dependencies = [ { name = "accelerate", marker = "sys_platform == 'linux'" }, { name = "ace-step", marker = "sys_platform == 'linux'" }, + { name = "adtof-pytorch", marker = "sys_platform == 'linux'" }, { name = "ai-edge-litert", marker = "sys_platform == 'linux'" }, { name = "audio-separator", marker = "sys_platform == 'linux'" }, { name = "bs-roformer", marker = "sys_platform == 'linux'" }, @@ -2544,6 +2556,7 @@ dev = [ requires-dist = [ { name = "accelerate", specifier = ">=0.30.0" }, { name = "ace-step", editable = "Ace-Step-Wrangler/vendor/ACE-Step-1.5" }, + { name = "adtof-pytorch", git = "https://github.com/xavriley/ADTOF-pytorch?rev=85c192e" }, { name = "ai-edge-litert", specifier = ">=1.1.0" }, { name = "audio-separator", directory = "vendor/python-audio-separator" }, { name = "bs-roformer", specifier = ">=0.3.2" },