From e2d12b87ddca12f2846c5cb818815338c3e869ec Mon Sep 17 00:00:00 2001 From: kurt Date: Sat, 21 Mar 2026 08:41:57 +1000 Subject: [PATCH 1/3] Add ADTOF drum transcription backend and registry Introduce dedicated Automatic Drum Transcription (ADT) support using ADTOF-pytorch. Drums were previously routed through BasicPitch (a pitched-instrument model) which produced poor results on unpitched percussion. - Add DrumMidiSpec to model registry with 5-class ADTOF descriptor - Add drum_map.py: GM MIDI note mapping for ADT class outputs - Add adtof_backend.py: AdtofBackend with load/predict/evict lifecycle - Extend notes_to_midi with is_drum parameter (channel 10, fixed 60ms note duration, no pitch bend) - Add adtof-pytorch dependency --- models/registry.py | 66 +++++++++++++ pipelines/adtof_backend.py | 198 +++++++++++++++++++++++++++++++++++++ pyproject.toml | 3 +- utils/drum_map.py | 71 +++++++++++++ utils/midi_io.py | 25 ++++- uv.lock | 13 +++ 6 files changed, 371 insertions(+), 5 deletions(-) create mode 100644 pipelines/adtof_backend.py create mode 100644 utils/drum_map.py diff --git a/models/registry.py b/models/registry.py index 973d847..8b2df0d 100644 --- a/models/registry.py +++ b/models/registry.py @@ -231,6 +231,26 @@ class StableAudioSpec(ModelSpec): conditioning_keys: tuple[str, ...] +@dataclass(frozen=True, slots=True) +class DrumMidiSpec(ModelSpec): + """Descriptor for a drum transcription (ADT) model. + + Additional fields + ----------------- + class_count: + Number of drum output classes this model produces. + class_labels: + Human-readable label for each class, in model output order. + checkpoint_url: + Direct download URL for model weights; empty string means weights + are bundled in the pip package. + """ + + class_count: int = 0 + class_labels: tuple[str, ...] = () + checkpoint_url: str = "" + + @dataclass(frozen=True, slots=True) class RoformerSpec(ModelSpec): """Descriptor for a BS-Roformer / MelBand-Roformer separation model. 
@@ -742,6 +762,38 @@ def _register(spec: ModelSpec) -> ModelSpec: default_num_overlap=2, )) +# --------------------------------------------------------------------------- +# Drum transcription (ADT) models +# --------------------------------------------------------------------------- + +ADTOF_DRUMS = _register(DrumMidiSpec( + model_id="adtof-drums", + display_name="ADTOF Drums (5-class)", + version="0.1.0", + source="xavriley/ADTOF-pytorch", + device="auto", + gpu_capable=True, + device_fallback="cpu", + device_quirks="", + sample_rate=44_100, + hop_size=0, + chunk_size=0, + max_duration_seconds=0.0, + default_bpm=0.0, + default_key="", + default_time_signature="", + quantize_grid="none", + default_min_note_ms=0.0, + capabilities=frozenset({"transcribe", "drum_transcription", "gpu_acceleration"}), + cache_subdir="adtof", + description="Automatic drum transcription — 5-class (kick, snare, tom, hi-hat, cymbal).", + preprocessing="Mono 44.1 kHz; mel spectrogram; context window of 9 frames.", + postprocessing="Peak picking per class; GM note mapping; 60 ms fixed note duration.", + class_count=5, + class_labels=("kick", "snare", "tom", "hi_hat", "cymbal"), + checkpoint_url="", +)) + # --------------------------------------------------------------------------- # Convenience constants # --------------------------------------------------------------------------- @@ -854,6 +906,9 @@ def get_loader_kwargs(model_id: str) -> dict[str, Any]: "config_url": spec.config_url, } + if isinstance(spec, DrumMidiSpec): + return {"cache_dir": spec.cache_dir} + raise NotImplementedError( f"No loader kwargs defined for spec type {type(spec).__name__!r}." ) @@ -925,6 +980,9 @@ def get_pipeline_defaults(model_id: str) -> dict[str, Any]: "num_overlap": spec.default_num_overlap, } + if isinstance(spec, DrumMidiSpec): + return {"class_count": spec.class_count} + raise NotImplementedError( f"No pipeline defaults defined for spec type {type(spec).__name__!r}." 
) @@ -1004,6 +1062,14 @@ def get_gui_metadata(model_id: str) -> dict[str, Any]: "target_instrument": spec.target_instrument, } + if isinstance(spec, DrumMidiSpec): + return { + "class_count": spec.class_count, + "class_labels": list(spec.class_labels), + "model_choices": [s.model_id for s in list_specs(DrumMidiSpec)], + "tooltip": spec.description, + } + raise NotImplementedError( f"No GUI metadata defined for spec type {type(spec).__name__!r}." ) diff --git a/pipelines/adtof_backend.py b/pipelines/adtof_backend.py new file mode 100644 index 0000000..6421335 --- /dev/null +++ b/pipelines/adtof_backend.py @@ -0,0 +1,198 @@ +"""ADTOF-pytorch automatic drum transcription backend. + +Provides :class:`AdtofBackendProtocol` (structural interface for ADT backends) +and :class:`AdtofBackend` (ADTOF Frame-RNN implementation). + +Usage:: + + backend = AdtofBackend() + backend.load(device="cuda") + events = backend.predict(Path("drums.wav")) + backend.evict() +""" +from __future__ import annotations + +import logging +import pathlib +from typing import Protocol, runtime_checkable + +import soundfile as sf +import torch + +from utils.errors import InvalidInputError, ModelLoadError, PipelineExecutionError +from utils.midi_io import NoteEvent + +logger = logging.getLogger(__name__) + +NOTE_DURATION: float = 0.06 # 60ms — matches notes_to_midi() drum cap +DEFAULT_VELOCITY: int = 100 # Fixed velocity; amplitude-based is v2 (ECLASS-02) +_VALID_GM_NOTES: frozenset[int] = frozenset({35, 38, 42, 47, 49}) + + +def _peaks_to_note_events( + peaks: dict[int, list[float]], + duration: float = NOTE_DURATION, + velocity: int = DEFAULT_VELOCITY, +) -> list[NoteEvent]: + """Convert PeakPicker output to sorted NoteEvent list. + + Parameters + ---------- + peaks: + ``{gm_note: [onset_time_sec, ...]}`` from ``PeakPicker.pick()``. + duration: + Fixed note duration in seconds. + velocity: + Fixed MIDI velocity (1-127). + + Returns + ------- + list[NoteEvent] + Sorted by onset time ascending. 
+ """ + events: list[NoteEvent] = [] + for gm_note, times in peaks.items(): + gm_note_int = int(gm_note) + assert gm_note_int in _VALID_GM_NOTES, ( + f"Unexpected GM note {gm_note_int} from PeakPicker; " + f"expected one of {sorted(_VALID_GM_NOTES)}" + ) + for onset in times: + events.append((onset, onset + duration, gm_note_int, velocity)) + events.sort(key=lambda e: e[0]) + return events + + +@runtime_checkable +class AdtofBackendProtocol(Protocol): + """Structural interface for automatic drum transcription backends. + + Implementations are not required to inherit from this class. Any class + defining ``load``, ``predict``, and ``evict`` passes the runtime check + ``isinstance(obj, AdtofBackendProtocol)``; signatures are not verified. + """ + + def load(self, device: str = "cpu") -> None: + """Load model weights into memory on the given device.""" + ... + + def predict(self, audio_path: pathlib.Path) -> list[NoteEvent]: + """Transcribe drum audio; return NoteEvent tuples, GM notes only.""" + ... + + def evict(self) -> None: + """Release model weights from memory.""" + ... + + +class AdtofBackend: + """ADTOF Frame-RNN drum transcription backend. + + Implements :class:`AdtofBackendProtocol` via structural subtyping + (no inheritance required). + + Example:: + + backend = AdtofBackend() + backend.load(device="cuda") + events = backend.predict(Path("drums_stem.wav")) + backend.evict() + """ + + def __init__(self) -> None: + self._model: torch.nn.Module | None = None + self._device: str = "cpu" + + def load(self, device: str = "cpu") -> None: + """Load ADTOF Frame-RNN weights into memory. + + Imports adtof_pytorch functions at call time to preserve lazy loading. + Wraps any exception from weight loading in :class:`ModelLoadError`. + + Parameters + ---------- + device: + Torch device string, e.g. ``"cpu"``, ``"cuda"``, ``"cuda:0"``.
+ """ + try: + import adtof_pytorch + + n_bins = adtof_pytorch.calculate_n_bins() + model = adtof_pytorch.create_frame_rnn_model(n_bins) + model.eval() + weights_path = adtof_pytorch.get_default_weights_path() + model = adtof_pytorch.load_pytorch_weights(model, str(weights_path), strict=False) + model.to(device) + self._model = model + self._device = device + logger.info("ADTOF model loaded on %s", device) + except Exception as exc: + raise ModelLoadError(str(exc), model_name="adtof-drums") from exc + + def predict(self, audio_path: pathlib.Path) -> list[NoteEvent]: + """Transcribe drum audio and return NoteEvent tuples. + + Raises + ------ + InvalidInputError + If the audio sample rate is not 44100 Hz. + PipelineExecutionError + If the model forward pass or onset detection fails. + RuntimeError + If called before ``load()``. + """ + # --- Sample rate guard (ADT-02) --- + info = sf.info(str(audio_path)) + if info.samplerate != 44100: + raise InvalidInputError( + f"ADTOF requires 44100 Hz audio; got {info.samplerate} Hz: {audio_path}", + field="audio_path", + ) + + if self._model is None: + raise RuntimeError("Model not loaded — call load() first") + + try: + from adtof_pytorch import ( + load_audio_for_model, + PeakPicker, + FRAME_RNN_THRESHOLDS, + LABELS_5, + ) + + # --- Forward pass (ADT-01: zero disk writes) --- + x = load_audio_for_model(str(audio_path)) + x = x.to(self._device) + with torch.no_grad(): + pred = self._model(x).cpu().numpy() + + # --- Onset detection (ADT-04) --- + picker = PeakPicker(thresholds=FRAME_RNN_THRESHOLDS, fps=100) + picked = picker.pick(pred, labels=LABELS_5, label_offset=0)[0] + + # --- Convert to NoteEvent --- + events = _peaks_to_note_events(picked) + logger.info("ADTOF transcription: %d note events", len(events)) + return events + + except InvalidInputError: + raise + except Exception as exc: + raise PipelineExecutionError( + f"ADTOF prediction failed: {exc}", + pipeline_name="adtof_backend", + ) from exc + + def evict(self) -> 
None: + """Release model weights from GPU/CPU memory. + + Moves the model to CPU before clearing the reference so PyTorch can + free GPU memory. Calls ``torch.cuda.empty_cache()`` when the model + was loaded on a CUDA device. + """ + if self._model is not None: + self._model.cpu() + self._model = None + if self._device.startswith("cuda"): + torch.cuda.empty_cache() + logger.info("ADTOF model evicted") diff --git a/pyproject.toml b/pyproject.toml index b4efa29..a64cbe3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,8 @@ dependencies = [ # GitHub-pinned ML components (renamed distributions) "demucs @ git+https://github.com/facebookresearch/demucs.git@v4.0.1", - #"basic-pitch @ git+https://github.com/spotify/basic-pitch.git", + "adtof-pytorch @ git+https://github.com/xavriley/ADTOF-pytorch@85c192e", + #"basic-pitch @ git+https://github.com/spotify/basic-pitch.git", # Demucs deps "dora-search==0.1.12", diff --git a/utils/drum_map.py b/utils/drum_map.py new file mode 100644 index 0000000..bf5862e --- /dev/null +++ b/utils/drum_map.py @@ -0,0 +1,71 @@ +"""GM drum mapping constants for automatic drum transcription. + +Provides the canonical mapping between ADTOF model output class indices +and General MIDI percussion note numbers (channel 10). +""" +from __future__ import annotations + +from enum import IntEnum + + +class AdtofDrumClass(IntEnum): + """ADTOF 5-class drum output indices. + + Values match the model's output channel ordering from + ``xavriley/ADTOF-pytorch`` ``LABELS_5 = [35, 38, 47, 42, 49]``. + """ + KICK = 0 + SNARE = 1 + TOM = 2 + HI_HAT = 3 + CYMBAL = 4 + + +# Internal mapping: AdtofDrumClass -> GM note number. +# NON-SEQUENTIAL: index 2 (tom) maps to 47, index 3 (hi-hat) maps to 42. +# This ordering is preserved from ADTOF model training — do not sort numerically. 
+_ADTOF_GM: dict[AdtofDrumClass, int] = { + AdtofDrumClass.KICK: 35, # Acoustic Bass Drum + AdtofDrumClass.SNARE: 38, # Acoustic Snare + AdtofDrumClass.TOM: 47, # Mid Tom + AdtofDrumClass.HI_HAT: 42, # Closed Hi-Hat + AdtofDrumClass.CYMBAL: 49, # Crash Cymbal 1 +} + +# Flat {int_index: gm_note} dict for downstream consumers. +ADTOF_5CLASS_GM_NOTE: dict[int, int] = {int(k): v for k, v in _ADTOF_GM.items()} + +# Human-readable GM drum names for logging and Phase 4 UI. +GM_DRUM_NAMES: dict[int, str] = { + 35: "Acoustic Bass Drum", + 38: "Acoustic Snare", + 42: "Closed Hi-Hat", + 47: "Mid Tom", + 49: "Crash Cymbal 1", +} + +# --- 7-class expansion reference (ECLASS-01, v2) --- +# AdtofDrumClass would add: OPEN_HI_HAT = 5, RIDE = 6 +# GM notes: open hi-hat = 46, ride cymbal = 51 +# Crash/ride split requires retraining or post-hoc amplitude heuristics. + + +def gm_note(drum_class: AdtofDrumClass) -> int: + """Return the General MIDI note number for *drum_class*. + + Parameters + ---------- + drum_class: + An :class:`AdtofDrumClass` member. + + Returns + ------- + int + GM percussion note number (channel 10). + + Raises + ------ + KeyError + If *drum_class* is not a valid ADTOF class. + """ + return _ADTOF_GM[drum_class] diff --git a/utils/midi_io.py b/utils/midi_io.py index a0fcc7e..c40ec2c 100644 --- a/utils/midi_io.py +++ b/utils/midi_io.py @@ -86,6 +86,7 @@ def notes_to_midi( ticks_per_beat: int = 480, tempo_bpm: float = 120.0, lyrics: list[LyricEvent] | None = None, + is_drum: bool = False, ) -> MidiData: """Convert a list of note events to a MIDI data object. @@ -97,6 +98,13 @@ def notes_to_midi( MIDI ticks per quarter-note (resolution). tempo_bpm: Tempo in beats per minute for the generated MIDI file. + lyrics: + Optional list of ``(time_sec, word)`` lyric events to embed. + is_drum: + When ``True``, the instrument is routed to General MIDI channel 10 + (percussion channel) and note durations are clamped to 60 ms. 
+ When ``False`` (default), the pre-change behaviour is preserved: + Acoustic Grand Piano on a melodic channel. Returns ------- @@ -107,14 +115,23 @@ def notes_to_midi( resolution=int(ticks_per_beat), initial_tempo=float(tempo_bpm), ) - instrument = pretty_midi.Instrument( - program=pretty_midi.instrument_name_to_program("Acoustic Grand Piano"), - name="StemForge", - ) + if is_drum: + instrument = pretty_midi.Instrument( + program=0, + is_drum=True, + name="StemForge Drums", + ) + else: + instrument = pretty_midi.Instrument( + program=pretty_midi.instrument_name_to_program("Acoustic Grand Piano"), + name="StemForge", + ) for start, end, pitch, velocity in note_events: # Guard against degenerate notes (zero or negative duration). if end <= start: continue + if is_drum: + end = min(end, start + 0.06) # 60 ms hard cap for percussion note = pretty_midi.Note( velocity=max(1, min(127, int(velocity))), pitch=max(0, min(127, int(pitch))), diff --git a/uv.lock b/uv.lock index ca5d5db..ef33897 100644 --- a/uv.lock +++ b/uv.lock @@ -125,6 +125,17 @@ requires-dist = [ [package.metadata.requires-dev] dev = [] +[[package]] +name = "adtof-pytorch" +version = "0.1.0" +source = { git = "https://github.com/xavriley/ADTOF-pytorch?rev=85c192e#85c192e78f716ea0b111cc8a5ee4a8f6a3a4f8a9" } +dependencies = [ + { name = "librosa", marker = "sys_platform == 'linux'" }, + { name = "numpy", marker = "sys_platform == 'linux'" }, + { name = "pretty-midi", marker = "sys_platform == 'linux'" }, + { name = "torch", marker = "sys_platform == 'linux'" }, +] + [[package]] name = "ai-edge-litert" version = "2.1.2" @@ -2474,6 +2485,7 @@ source = { editable = "."
} dependencies = [ { name = "accelerate", marker = "sys_platform == 'linux'" }, { name = "ace-step", marker = "sys_platform == 'linux'" }, + { name = "adtof-pytorch", marker = "sys_platform == 'linux'" }, { name = "ai-edge-litert", marker = "sys_platform == 'linux'" }, { name = "audio-separator", marker = "sys_platform == 'linux'" }, { name = "bs-roformer", marker = "sys_platform == 'linux'" }, @@ -2544,6 +2556,7 @@ dev = [ requires-dist = [ { name = "accelerate", specifier = ">=0.30.0" }, { name = "ace-step", editable = "Ace-Step-Wrangler/vendor/ACE-Step-1.5" }, + { name = "adtof-pytorch", git = "https://github.com/xavriley/ADTOF-pytorch?rev=85c192e" }, { name = "ai-edge-litert", specifier = ">=1.1.0" }, { name = "audio-separator", directory = "vendor/python-audio-separator" }, { name = "bs-roformer", specifier = ">=0.3.2" }, From d480197fd0d6fc34adeb79ffa4acd63c085d5cad Mon Sep 17 00:00:00 2001 From: kurt Date: Sat, 21 Mar 2026 08:42:11 +1000 Subject: [PATCH 2/3] Wire drum routing into MIDI pipeline and API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Route drum stems through ADTOF instead of BasicPitch in the MIDI extraction pipeline. The existing FluidSynth render already handled channel 10 correctly — this closes the transcription gap. 
- Add ADTOF lazy loading and eviction to MidiModelLoader - Add drum detection branch to MidiPipeline with 3-stage progress - Extend ExtractRequest with adt_model parameter - Add adt_models list to /api/midi/gm-programs response --- backend/api/midi.py | 13 ++++++- models/midi_loader.py | 76 +++++++++++++++++++++++++++++++++++++- pipelines/midi_pipeline.py | 33 ++++++++++++++++- 3 files changed, 118 insertions(+), 4 deletions(-) diff --git a/backend/api/midi.py b/backend/api/midi.py index cea82a6..5b13dc3 100644 --- a/backend/api/midi.py +++ b/backend/api/midi.py @@ -14,6 +14,7 @@ from backend.services.job_manager import job_manager from backend.services.session_store import SessionStore, TrackState, get_user_session from backend.services import pipeline_manager +from models.registry import DrumMidiSpec, list_specs from utils.paths import MIDI_DIR router = APIRouter(prefix="/api/midi", tags=["midi"]) @@ -86,6 +87,14 @@ "Drums & percussion": True, } + +def _build_adt_model_list() -> list[dict]: + """Return a list of ADT model metadata dicts from the model registry.""" + return [ + {"model_id": s.model_id, "display_name": s.display_name, "tooltip": s.description} + for s in list_specs(DrumMidiSpec) + ] + # ─── SoundFont discovery & state ───────────────────────────────────────── _SF2_SEARCH_PATHS = [ @@ -119,6 +128,7 @@ class ExtractRequest(BaseModel): onset_threshold: float = 0.5 frame_threshold: float = 0.3 min_note_ms: float = 58.0 + adt_model: str = "adtof-drums" class RenderRequest(BaseModel): @@ -346,11 +356,12 @@ def get_midi_stems(session: SessionStore = Depends(get_user_session)) -> dict: @router.get("/gm-programs") def get_gm_programs() -> dict: - """Return list of all 128 GM program names and smart defaults per stem label.""" + """Return list of all 128 GM program names, smart defaults per stem label, and ADT models.""" return { "programs": GM_PROGRAMS, "defaults": STEM_DEFAULT_PROGRAM, "drum_stems": STEM_IS_DRUM, + "adt_models": _build_adt_model_list(), } 
diff --git a/models/midi_loader.py b/models/midi_loader.py index 112bcd7..254a533 100644 --- a/models/midi_loader.py +++ b/models/midi_loader.py @@ -56,6 +56,7 @@ def __init__(self) -> None: self._bp_loader = BasicPitchModelLoader() self._model: Any | None = None # BasicPitch TF model self._whisper_model: Any | None = None # faster-whisper WhisperModel + self._adtof_backend: Any | None = None # AdtofBackend (lazy-loaded) # ------------------------------------------------------------------ # Lifecycle @@ -93,12 +94,41 @@ def is_loaded(self) -> bool: return self._model is not None def evict(self) -> None: - """Release both models from memory and trigger GC.""" + """Release all models from memory and trigger GC.""" self._bp_loader.evict() self._model = None self._whisper_model = None + self.evict_drum_model() log.debug("MidiModelLoader: models evicted.") + # ------------------------------------------------------------------ + # Internal: ADTOF lazy loader + # ------------------------------------------------------------------ + + def _ensure_adtof(self) -> Any: + """Load the ADTOF drum transcription backend on first use.""" + if self._adtof_backend is not None: + return self._adtof_backend + try: + from pipelines.adtof_backend import AdtofBackend + except ImportError as exc: + raise ModelLoadError( + "ADTOF backend is not available.", + model_name="adtof-drums", + ) from exc + backend = AdtofBackend() + backend.load() + self._adtof_backend = backend + log.info("MidiModelLoader: ADTOF drum backend ready.") + return self._adtof_backend + + def evict_drum_model(self) -> None: + """Evict the ADTOF backend only, leaving BasicPitch intact.""" + if self._adtof_backend is not None: + self._adtof_backend.evict() + self._adtof_backend = None + log.debug("MidiModelLoader: ADTOF backend evicted.") + # ------------------------------------------------------------------ # Internal: Whisper lazy loader # ------------------------------------------------------------------ @@ -379,3 
+409,47 @@ def convert_vocal_to_midi( path.name, len(events), len(lyrics), len(segments), ) return events, lyrics + + def convert_drum_to_midi( + self, + path: pathlib.Path, + *, + duration: float = 0.0, + ) -> list[NoteEvent]: + """Transcribe a drum stem to GM note events via ADTOF. + + Parameters + ---------- + path: + Drum audio file (must be 44100 Hz). + duration: + If positive, clip events to this many seconds. + + Returns + ------- + list[NoteEvent] + GM channel-10 note events from ADTOF_5CLASS_GM_NOTE. + + Raises + ------ + :class:`~utils.errors.ModelLoadError` + If the ADTOF backend cannot be loaded. + :class:`~utils.errors.PipelineExecutionError` + If drum transcription fails. + """ + backend = self._ensure_adtof() + try: + events = backend.predict(path) + except PipelineExecutionError: + raise + except Exception as exc: + raise PipelineExecutionError( + f"Drum transcription failed for '{path.name}': {exc}", + pipeline_name="midi", + ) from exc + + if duration > 0.0: + events = [(s, min(e, duration), p, v) for s, e, p, v in events if s < duration] + + log.debug("convert_drum_to_midi: %s -> %d drum events", path.name, len(events)) + return events diff --git a/pipelines/midi_pipeline.py b/pipelines/midi_pipeline.py index 8558aa7..af6d6d7 100644 --- a/pipelines/midi_pipeline.py +++ b/pipelines/midi_pipeline.py @@ -53,6 +53,14 @@ "vocals", "Singing voice", }) +# Stem labels that represent drum / percussion tracks — routed to ADTOF +# drum transcription instead of BasicPitch. +# Must match STEM_IS_DRUM keys in backend/api/midi.py. 
+_DRUM_STEM_LABELS: frozenset[str] = frozenset({ + "drums", # Demucs htdemucs / mdx_extra output label + "Drums & percussion", # BS-Roformer jarredou-6stem / zfturbo-4stem output label +}) + # --------------------------------------------------------------------------- # Configuration @@ -319,6 +327,16 @@ def run(self, stems: dict[str, pathlib.Path]) -> "MidiResult": ) if lyrics: track_lyrics[label] = lyrics + + elif label in _DRUM_STEM_LABELS: + log.info("MidiPipeline: routing '%s' to drum ADT path.", label) + self._report(base_pct + 2.0) # stage 1: before loading + notes = self._loader.convert_drum_to_midi( + path, duration=cfg.duration_seconds, + ) + self._report(base_pct + 10.0) # stage 2: after loading/predict + self._report(base_pct + (1.0 / total) * 70.0) # stage 3: done + else: notes = self._loader.convert_audio_to_midi( path, @@ -335,10 +353,19 @@ def run(self, stems: dict[str, pathlib.Path]) -> "MidiResult": # Build per-stem MIDI object in memory (not written to disk yet) stem_lyrics = track_lyrics.get(label) - stem_midi_data[label] = self._build_stem_midi(label, notes, cfg, stem_lyrics) + stem_midi_data[label] = self._build_stem_midi( + label, notes, cfg, stem_lyrics, + is_drum=(label in _DRUM_STEM_LABELS), + ) self._report(base_pct + (1.0 / total) * 70.0) + # Evict ADTOF backend after processing all drum stems to free GPU memory. + # AdtofBackend.evict() handles torch.cuda.empty_cache() internally — + # no backend/api/midi.py modification needed. 
+ if any(label in _DRUM_STEM_LABELS for label in stems): + self._loader.evict_drum_model() + else: # Text-only: generate a diatonic chord progression self._report(20.0) @@ -418,10 +445,12 @@ def _build_stem_midi( notes: list[NoteEvent], cfg: MidiConfig, lyrics: list[LyricEvent] | None = None, + is_drum: bool = False, ) -> Any: """Build and return a PrettyMIDI object for *stem_name* (no disk write).""" return notes_to_midi( - notes, ticks_per_beat=_TICKS_PER_BEAT, tempo_bpm=cfg.bpm, lyrics=lyrics + notes, ticks_per_beat=_TICKS_PER_BEAT, tempo_bpm=cfg.bpm, + lyrics=lyrics, is_drum=is_drum, ) def write_stem_midi( From 520d8eb9b7c8957025521fd5691d6bf781af6f81 Mon Sep 17 00:00:00 2001 From: kurt Date: Sat, 21 Mar 2026 08:42:30 +1000 Subject: [PATCH 3/3] Add ADT model selector UI and comprehensive tests Show an ADT model dropdown in the MIDI panel when drum stems are selected. Add tests covering drum map, ADTOF backend protocol, registry spec, pipeline routing, loader integration, and API endpoint behavior. 
--- .gitignore | 3 + frontend/components/midi.js | 31 ++- tests/test_adtof_backend.py | 349 ++++++++++++++++++++++++++++ tests/test_drum_map.py | 58 +++++ tests/test_drum_midi_integration.py | 134 +++++++++++ tests/test_drum_midi_spec.py | 81 +++++++ tests/test_midi_loader_drum.py | 161 +++++++++++++ tests/test_midi_pipeline_routing.py | 212 +++++++++++++++++ tests/test_notes_to_midi_drum.py | 69 ++++++ 9 files changed, 1097 insertions(+), 1 deletion(-) create mode 100644 tests/test_adtof_backend.py create mode 100644 tests/test_drum_map.py create mode 100644 tests/test_drum_midi_integration.py create mode 100644 tests/test_drum_midi_spec.py create mode 100644 tests/test_midi_loader_drum.py create mode 100644 tests/test_midi_pipeline_routing.py create mode 100644 tests/test_notes_to_midi_drum.py diff --git a/.gitignore b/.gitignore index a987ae0..386a443 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,9 @@ # StemForge .gitignore # ============================================================================= +# GSD planning (local-only, not committed) +.planning/ + # ----------------------------------------------------------------------------- # Python # ----------------------------------------------------------------------------- diff --git a/frontend/components/midi.js b/frontend/components/midi.js index 94ef5a6..21239e5 100644 --- a/frontend/components/midi.js +++ b/frontend/components/midi.js @@ -15,6 +15,7 @@ function clearChildren(elem) { let gmPrograms = []; let stemDefaults = {}; let drumStems = {}; +let adtModels = []; /** LilyPond availability (checked on init). */ let _lilypondAvailable = false; @@ -45,6 +46,14 @@ export function initMidi() { ), ); + const adtGroup = el('div', { className: 'form-group hidden', id: 'midi-adt-group' }, + el('label', {}, 'ADT Model'), + el('select', { id: 'midi-adt-model' }), + el('p', { className: 'text-dim', style: { fontSize: '12px', margin: '4px 0 0' } }, + 'Best results with acoustic drums. 
Electronic/programmed drums may have lower accuracy.', + ), + ); + const keyGroup = el('div', { className: 'form-group' }, el('label', {}, 'Key'), el('select', { id: 'midi-key' }, @@ -104,7 +113,7 @@ export function initMidi() { const importInput = el('input', { type: 'file', accept: '.mid,.midi', style: { display: 'none' }, id: 'midi-import-input' }); const importBtn = el('button', { className: 'btn btn-sm', id: 'midi-import' }, 'Import MIDI file'); - left.append(stemSection, keyGroup, bpmGroup, tsGroup, onsetGroup, frameGroup, sf2Group, extractBtn, importInput, importBtn); + left.append(stemSection, adtGroup, keyGroup, bpmGroup, tsGroup, onsetGroup, frameGroup, sf2Group, extractBtn, importInput, importBtn); // ─── Right: results ─── const right = el('div', { className: 'col-right' }); @@ -158,6 +167,8 @@ export function initMidi() { document.getElementById('midi-start').disabled = false; }); + document.getElementById('midi-stems').addEventListener('change', syncAdtGroupVisibility); + // Load GM programs, current soundfont, and check LilyPond on init loadGmPrograms(); loadCurrentSoundfont(); @@ -170,6 +181,14 @@ async function loadGmPrograms() { gmPrograms = data.programs || []; stemDefaults = data.defaults || {}; drumStems = data.drum_stems || {}; + adtModels = data.adt_models || []; + const adtSel = document.getElementById('midi-adt-model'); + if (adtSel) { + clearChildren(adtSel); + for (const m of adtModels) { + adtSel.appendChild(el('option', { value: m.model_id, title: m.tooltip || '' }, m.display_name)); + } + } } catch { /* fail silently, will use defaults */ } } @@ -244,6 +263,15 @@ function populateStemCheckboxes(stemPaths) { ), ); } + syncAdtGroupVisibility(); +} + +function syncAdtGroupVisibility() { + const hasDrum = Array.from( + document.querySelectorAll('#midi-stems input[type="checkbox"]:checked') + ).some(cb => isDrumStem(cb.value)); + const adtGroup = document.getElementById('midi-adt-group'); + if (adtGroup) adtGroup.classList.toggle('hidden', 
!hasDrum); } async function startExtraction() { @@ -268,6 +296,7 @@ async function startExtraction() { time_signature: document.getElementById('midi-ts').value, onset_threshold: parseFloat(document.getElementById('midi-onset').value), frame_threshold: parseFloat(document.getElementById('midi-frame').value), + adt_model: document.getElementById('midi-adt-model')?.value || 'adtof-drums', }), }); diff --git a/tests/test_adtof_backend.py b/tests/test_adtof_backend.py new file mode 100644 index 0000000..25550d0 --- /dev/null +++ b/tests/test_adtof_backend.py @@ -0,0 +1,349 @@ +"""Tests for AdtofBackendProtocol and AdtofBackend load/evict lifecycle. + +Exercises the implementation in pipelines/adtof_backend.py. +""" +from __future__ import annotations + +import pathlib +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Protocol structural checks +# --------------------------------------------------------------------------- + + +def test_protocol_structural_check() -> None: + """A duck-typed class with load/predict/evict passes isinstance check.""" + from pipelines.adtof_backend import AdtofBackendProtocol + + class DummyBackend: + def load(self, device: str = "cpu") -> None: + ... + + def predict(self, audio_path: pathlib.Path) -> list: + return [] + + def evict(self) -> None: + ... + + assert isinstance(DummyBackend(), AdtofBackendProtocol) is True + + +def test_protocol_rejects_incomplete() -> None: + """A class missing predict() fails isinstance check against AdtofBackendProtocol.""" + from pipelines.adtof_backend import AdtofBackendProtocol + + class IncompleteBackend: + def load(self, device: str = "cpu") -> None: + ... + + def evict(self) -> None: + ...
+ + assert isinstance(IncompleteBackend(), AdtofBackendProtocol) is False + + +# --------------------------------------------------------------------------- +# Load / evict lifecycle +# --------------------------------------------------------------------------- + +_MOCK_PATCHES = [ + "adtof_pytorch.calculate_n_bins", + "adtof_pytorch.create_frame_rnn_model", + "adtof_pytorch.get_default_weights_path", + "adtof_pytorch.load_pytorch_weights", +] + + +def _make_mock_model() -> MagicMock: + mock_model = MagicMock() + mock_model.eval.return_value = mock_model + mock_model.to.return_value = mock_model + mock_model.cpu.return_value = mock_model + return mock_model + + +def test_load_creates_model() -> None: + """After load(), self._model is not None.""" + from pipelines.adtof_backend import AdtofBackend + + mock_model = _make_mock_model() + + with ( + patch("adtof_pytorch.calculate_n_bins", return_value=168), + patch("adtof_pytorch.create_frame_rnn_model", return_value=mock_model), + patch("adtof_pytorch.get_default_weights_path", return_value=pathlib.Path("/fake/weights.pth")), + patch("adtof_pytorch.load_pytorch_weights", return_value=mock_model), + ): + backend = AdtofBackend() + backend.load(device="cpu") + + assert backend._model is not None + + +def test_load_sets_eval_mode() -> None: + """After load(), model.eval() was called.""" + from pipelines.adtof_backend import AdtofBackend + + mock_model = _make_mock_model() + + with ( + patch("adtof_pytorch.calculate_n_bins", return_value=168), + patch("adtof_pytorch.create_frame_rnn_model", return_value=mock_model), + patch("adtof_pytorch.get_default_weights_path", return_value=pathlib.Path("/fake/weights.pth")), + patch("adtof_pytorch.load_pytorch_weights", return_value=mock_model), + ): + backend = AdtofBackend() + backend.load(device="cpu") + + assert mock_model.eval.called is True + + +def test_evict_clears_model() -> None: + """After load() then evict(), self._model is None.""" + from pipelines.adtof_backend import 
AdtofBackend + + mock_model = _make_mock_model() + + with ( + patch("adtof_pytorch.calculate_n_bins", return_value=168), + patch("adtof_pytorch.create_frame_rnn_model", return_value=mock_model), + patch("adtof_pytorch.get_default_weights_path", return_value=pathlib.Path("/fake/weights.pth")), + patch("adtof_pytorch.load_pytorch_weights", return_value=mock_model), + ): + backend = AdtofBackend() + backend.load(device="cpu") + backend.evict() + + assert backend._model is None + + +def test_load_after_evict() -> None: + """After evict(), calling load() again creates a new model. + + create_frame_rnn_model should be called twice (once per load call). + """ + from pipelines.adtof_backend import AdtofBackend + + mock_model = _make_mock_model() + mock_create = MagicMock(return_value=mock_model) + + with ( + patch("adtof_pytorch.calculate_n_bins", return_value=168), + patch("adtof_pytorch.create_frame_rnn_model", mock_create), + patch("adtof_pytorch.get_default_weights_path", return_value=pathlib.Path("/fake/weights.pth")), + patch("adtof_pytorch.load_pytorch_weights", return_value=mock_model), + ): + backend = AdtofBackend() + backend.load(device="cpu") + backend.evict() + backend.load(device="cpu") + + assert backend._model is not None + assert mock_create.call_count == 2 + + +def test_load_wraps_errors() -> None: + """If create_frame_rnn_model raises, load() wraps it in ModelLoadError.""" + from pipelines.adtof_backend import AdtofBackend + from utils.errors import ModelLoadError + + with ( + patch("adtof_pytorch.calculate_n_bins", return_value=168), + patch("adtof_pytorch.create_frame_rnn_model", side_effect=RuntimeError("bad weights")), + patch("adtof_pytorch.get_default_weights_path", return_value=pathlib.Path("/fake/weights.pth")), + ): + backend = AdtofBackend() + with pytest.raises(ModelLoadError): + backend.load(device="cpu") + + +# --------------------------------------------------------------------------- +# predict() — sample rate guard, NoteEvent conversion, 
correctness +# --------------------------------------------------------------------------- + +import builtins +import numpy as np +import torch + + +def _make_predict_backend() -> "AdtofBackend": + """Return an AdtofBackend with _model and _device pre-set (skip load()).""" + from pipelines.adtof_backend import AdtofBackend + + mock_model = MagicMock() + # forward pass returns (1, 10, 5) sigmoid activations — all 0.9 + mock_model.return_value = torch.full((1, 10, 5), 0.9) + backend = AdtofBackend() + backend._model = mock_model + backend._device = "cpu" + return backend + + +def _make_sf_info(samplerate: int = 44100) -> MagicMock: + info = MagicMock() + info.samplerate = samplerate + return info + + +def _make_audio_tensor() -> torch.Tensor: + return torch.zeros(1, 10, 168, 1) + + +def _make_peak_picker_mock(peaks: dict) -> MagicMock: + """Return a mock PeakPicker class whose instance .pick() returns [peaks].""" + mock_picker_instance = MagicMock() + mock_picker_instance.pick.return_value = [peaks] + mock_picker_cls = MagicMock(return_value=mock_picker_instance) + return mock_picker_cls + + +_KNOWN_PEAKS = {35: [0.5, 1.2], 38: [0.3], 47: [0.8], 42: [1.0, 1.5], 49: []} + + +def test_predict_rejects_wrong_sample_rate() -> None: + """predict() raises InvalidInputError when audio is not 44100 Hz.""" + from utils.errors import InvalidInputError + + backend = _make_predict_backend() + with patch("soundfile.info", return_value=_make_sf_info(22050)): + with pytest.raises(InvalidInputError) as exc_info: + backend.predict(pathlib.Path("/fake/drums.wav")) + assert "44100" in str(exc_info.value) + + +def test_predict_no_disk_writes(tmp_path: pathlib.Path, monkeypatch: pytest.MonkeyPatch) -> None: + """predict() with mocked model does not create any files.""" + backend = _make_predict_backend() + monkeypatch.chdir(tmp_path) + + # Track open() calls — raise if any write-mode open occurs + real_open = builtins.open + + def guarded_open(file, mode="r", *args, **kwargs): + if "w" in 
str(mode) or "x" in str(mode): + raise AssertionError(f"Unexpected write-mode open({file!r}, {mode!r})") + return real_open(file, mode, *args, **kwargs) + + mock_picker_cls = _make_peak_picker_mock(_KNOWN_PEAKS) + + with ( + patch("soundfile.info", return_value=_make_sf_info(44100)), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", mock_picker_cls), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + patch("builtins.open", side_effect=guarded_open), + ): + backend.predict(pathlib.Path("/fake/drums.wav")) + + assert list(tmp_path.iterdir()) == [] + + +def test_note_events_gm_notes_only() -> None: + """All NoteEvent pitches are exclusively in {35, 38, 42, 47, 49}.""" + backend = _make_predict_backend() + mock_picker_cls = _make_peak_picker_mock(_KNOWN_PEAKS) + + with ( + patch("soundfile.info", return_value=_make_sf_info()), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", mock_picker_cls), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + ): + events = backend.predict(pathlib.Path("/fake/drums.wav")) + + assert len(events) > 0 + assert all(e[2] in {35, 38, 42, 47, 49} for e in events) + + +def test_non_sequential_label_mapping() -> None: + """Tom (GM 47) and hi-hat (GM 42) are mapped correctly (non-sequential).""" + backend = _make_predict_backend() + # Only tom=47 at t=1.0 and hi-hat=42 at t=2.0 + peaks = {47: [1.0], 42: [2.0]} + mock_picker_cls = _make_peak_picker_mock(peaks) + + with ( + patch("soundfile.info", return_value=_make_sf_info()), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", mock_picker_cls), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + ): + events = 
backend.predict(pathlib.Path("/fake/drums.wav")) + + assert len(events) == 2 + tom_event = next(e for e in events if abs(e[0] - 1.0) < 1e-9) + hihat_event = next(e for e in events if abs(e[0] - 2.0) < 1e-9) + assert tom_event[2] == 47, f"Expected GM 47 (tom) at t=1.0, got {tom_event[2]}" + assert hihat_event[2] == 42, f"Expected GM 42 (hi-hat) at t=2.0, got {hihat_event[2]}" + + +def test_note_event_duration_and_velocity() -> None: + """Each NoteEvent has 60ms duration and velocity 100.""" + backend = _make_predict_backend() + mock_picker_cls = _make_peak_picker_mock({35: [0.5], 38: [1.0]}) + + with ( + patch("soundfile.info", return_value=_make_sf_info()), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", mock_picker_cls), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + ): + events = backend.predict(pathlib.Path("/fake/drums.wav")) + + assert len(events) == 2 + for e in events: + assert e[3] == 100, f"Expected velocity=100, got {e[3]}" + assert abs((e[1] - e[0]) - 0.06) < 1e-9, f"Expected duration=0.06s, got {e[1]-e[0]}" + + +def test_note_events_sorted_by_onset() -> None: + """Output NoteEvents are sorted by start time ascending.""" + backend = _make_predict_backend() + # Out-of-order across classes: kick at t=2.0, snare at t=0.5, hihat at t=1.0 + mock_picker_cls = _make_peak_picker_mock({35: [2.0], 38: [0.5], 42: [1.0]}) + + with ( + patch("soundfile.info", return_value=_make_sf_info()), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", mock_picker_cls), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + ): + events = backend.predict(pathlib.Path("/fake/drums.wav")) + + onsets = [e[0] for e in events] + assert onsets == sorted(onsets), f"Events not sorted by onset: {onsets}" + + +def 
test_predict_raises_pipeline_error() -> None: + """If model forward pass raises, predict() wraps it in PipelineExecutionError.""" + from utils.errors import PipelineExecutionError + + backend = _make_predict_backend() + backend._model = MagicMock(side_effect=RuntimeError("CUDA OOM")) + + with ( + patch("soundfile.info", return_value=_make_sf_info()), + patch("adtof_pytorch.load_audio_for_model", return_value=_make_audio_tensor()), + patch("adtof_pytorch.PeakPicker", _make_peak_picker_mock({})), + patch("adtof_pytorch.FRAME_RNN_THRESHOLDS", {}), + patch("adtof_pytorch.LABELS_5", [35, 38, 47, 42, 49]), + ): + with pytest.raises(PipelineExecutionError): + backend.predict(pathlib.Path("/fake/drums.wav")) + + +def test_peaks_to_note_events_empty() -> None: + """_peaks_to_note_events({}) returns empty list.""" + from pipelines.adtof_backend import _peaks_to_note_events + + result = _peaks_to_note_events({}) + assert result == [] diff --git a/tests/test_drum_map.py b/tests/test_drum_map.py new file mode 100644 index 0000000..5dc91e4 --- /dev/null +++ b/tests/test_drum_map.py @@ -0,0 +1,58 @@ +"""Unit tests for utils.drum_map — GM drum constants and gm_note(). + +These tests are written FIRST (TDD RED phase) against the not-yet-existing +utils/drum_map module. All tests should FAIL with ImportError until +utils/drum_map.py is created (Task 2). 
+""" +from __future__ import annotations + + +def test_adtof_drum_class_enum() -> None: + """AdtofDrumClass enum has exactly 5 members with correct integer values.""" + from utils.drum_map import AdtofDrumClass + + assert AdtofDrumClass.KICK == 0 + assert AdtofDrumClass.SNARE == 1 + assert AdtofDrumClass.TOM == 2 + assert AdtofDrumClass.HI_HAT == 3 + assert AdtofDrumClass.CYMBAL == 4 + assert len(AdtofDrumClass) == 5 + + +def test_drum_map_adtof_5class() -> None: + """ADTOF_5CLASS_GM_NOTE maps class indices to exact GM note numbers.""" + from utils.drum_map import ADTOF_5CLASS_GM_NOTE + + assert ADTOF_5CLASS_GM_NOTE == {0: 35, 1: 38, 2: 47, 3: 42, 4: 49} + + +def test_gm_note_non_sequential() -> None: + """gm_note() preserves the non-sequential ordering: tom(47) > hi-hat(42).""" + from utils.drum_map import AdtofDrumClass, gm_note + + # Tom (index 2) maps to GM note 47. + assert gm_note(AdtofDrumClass.TOM) == 47 + # Hi-hat (index 3) maps to GM note 42 — numerically lower than tom despite higher index. 
+ assert gm_note(AdtofDrumClass.HI_HAT) == 42 + + +def test_gm_note_all_classes() -> None: + """gm_note() returns the correct GM note number for every AdtofDrumClass member.""" + from utils.drum_map import AdtofDrumClass, gm_note + + assert gm_note(AdtofDrumClass.KICK) == 35 + assert gm_note(AdtofDrumClass.SNARE) == 38 + assert gm_note(AdtofDrumClass.TOM) == 47 + assert gm_note(AdtofDrumClass.HI_HAT) == 42 + assert gm_note(AdtofDrumClass.CYMBAL) == 49 + + +def test_gm_drum_names() -> None: + """GM_DRUM_NAMES maps GM note numbers to human-readable percussion names.""" + from utils.drum_map import GM_DRUM_NAMES + + assert GM_DRUM_NAMES[35] == "Acoustic Bass Drum" + assert GM_DRUM_NAMES[38] == "Acoustic Snare" + assert GM_DRUM_NAMES[47] == "Mid Tom" + assert GM_DRUM_NAMES[42] == "Closed Hi-Hat" + assert GM_DRUM_NAMES[49] == "Crash Cymbal 1" diff --git a/tests/test_drum_midi_integration.py b/tests/test_drum_midi_integration.py new file mode 100644 index 0000000..dd8fb6f --- /dev/null +++ b/tests/test_drum_midi_integration.py @@ -0,0 +1,134 @@ +"""Integration tests for drum MIDI channel routing and API endpoint. + +Tests: + - test_drum_stem_writes_channel_10: drum stem -> is_drum=True, survives round-trip + - test_non_drum_stem_writes_channel_1: non-drum stem -> is_drum=False, survives round-trip + - test_gm_programs_includes_adt_models: /api/midi/gm-programs returns adt_models list + +RED phase notes: + - test_drum_stem_writes_channel_10 and test_non_drum_stem_writes_channel_1 should + already pass (channel routing is in place from Phase 3). + - test_gm_programs_includes_adt_models MUST fail until backend/api/midi.py is + extended to include the adt_models key. 
+""" +from __future__ import annotations + +import pathlib +import tempfile + +import pretty_midi +import pytest +from unittest.mock import MagicMock + +from pipelines.midi_pipeline import MidiPipeline, MidiConfig + + +# --------------------------------------------------------------------------- +# Test infrastructure — reused from test_midi_pipeline_routing.py pattern +# --------------------------------------------------------------------------- + + +def _make_pipeline_with_mock_loader(): + """Create a MidiPipeline with a mock loader, bypassing load_model().""" + pipeline = MidiPipeline() + pipeline._config = MidiConfig() + mock_loader = MagicMock() + # convert_drum_to_midi returns one kick note event + mock_loader.convert_drum_to_midi.return_value = [(0.1, 0.16, 35, 100)] + # convert_audio_to_midi returns one pitched note + mock_loader.convert_audio_to_midi.return_value = [(0.1, 0.4, 60, 80)] + # convert_vocal_to_midi returns (notes, lyrics) + mock_loader.convert_vocal_to_midi.return_value = ([(0.1, 0.5, 60, 80)], []) + pipeline._loader = mock_loader + pipeline.is_loaded = True + return pipeline, mock_loader + + +# --------------------------------------------------------------------------- +# Channel routing integration tests (with round-trip serialization) +# --------------------------------------------------------------------------- + + +def test_drum_stem_writes_channel_10(tmp_path: pathlib.Path) -> None: + """Drum stem produces is_drum=True in-memory and after MIDI file round-trip.""" + pipeline, _ = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums.wav" + tmp_file.write_bytes(b"\x00") + + result = pipeline.run({"drums": tmp_file}) + + midi_obj = result.stem_midi_data["drums"] + assert len(midi_obj.instruments) > 0, "MIDI object should have at least one instrument" + + # In-memory assertion + assert midi_obj.instruments[0].is_drum is True, ( + "Drum stem MIDI should have is_drum=True on instrument[0] in memory" + ) + + # Round-trip: write to 
.mid file, read back, verify is_drum survives + mid_path = tmp_path / "drums_roundtrip.mid" + midi_obj.write(str(mid_path)) + reloaded = pretty_midi.PrettyMIDI(str(mid_path)) + assert len(reloaded.instruments) > 0, "Round-tripped MIDI should have at least one instrument" + assert reloaded.instruments[0].is_drum is True, ( + "Drum stem MIDI should have is_drum=True after MIDI file round-trip" + ) + + +def test_non_drum_stem_writes_channel_1(tmp_path: pathlib.Path) -> None: + """Non-drum stem produces is_drum=False in-memory and after MIDI file round-trip.""" + pipeline, _ = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "bass.wav" + tmp_file.write_bytes(b"\x00") + + result = pipeline.run({"bass": tmp_file}) + + midi_obj = result.stem_midi_data["bass"] + assert len(midi_obj.instruments) > 0, "MIDI object should have at least one instrument" + + # In-memory assertion + assert midi_obj.instruments[0].is_drum is False, ( + "Non-drum stem MIDI should have is_drum=False on instrument[0] in memory" + ) + + # Round-trip: write to .mid file, read back, verify is_drum survives + mid_path = tmp_path / "bass_roundtrip.mid" + midi_obj.write(str(mid_path)) + reloaded = pretty_midi.PrettyMIDI(str(mid_path)) + assert len(reloaded.instruments) > 0, "Round-tripped MIDI should have at least one instrument" + assert reloaded.instruments[0].is_drum is False, ( + "Non-drum stem MIDI should have is_drum=False after MIDI file round-trip" + ) + + +# --------------------------------------------------------------------------- +# API endpoint test: /api/midi/gm-programs must include adt_models +# --------------------------------------------------------------------------- + + +def test_gm_programs_includes_adt_models() -> None: + """GET /api/midi/gm-programs returns 200 with adt_models key containing adtof-drums.""" + from fastapi.testclient import TestClient + from backend.main import app + + client = TestClient(app) + response = client.get("/api/midi/gm-programs") + + assert 
response.status_code == 200, ( + f"Expected 200 from /api/midi/gm-programs, got {response.status_code}" + ) + + data = response.json() + assert "adt_models" in data, ( + f"Response missing 'adt_models' key. Keys present: {list(data.keys())}" + ) + + adt_models = data["adt_models"] + assert isinstance(adt_models, list), ( + f"Expected adt_models to be a list, got {type(adt_models)}" + ) + + model_ids = [entry.get("model_id") for entry in adt_models] + assert "adtof-drums" in model_ids, ( + f"Expected 'adtof-drums' in adt_models list. Got: {model_ids}" + ) diff --git a/tests/test_drum_midi_spec.py b/tests/test_drum_midi_spec.py new file mode 100644 index 0000000..6224773 --- /dev/null +++ b/tests/test_drum_midi_spec.py @@ -0,0 +1,81 @@ +""" +Tests for DrumMidiSpec registry entry and ADTOF drum model registration. + +Phase 01 Plan 02 — RED phase. +""" +from __future__ import annotations + + +def test_drum_midi_spec_importable() -> None: + """DrumMidiSpec is importable from models.registry and is a ModelSpec subclass.""" + from models.registry import DrumMidiSpec, ModelSpec + + assert isinstance(DrumMidiSpec, type) + assert issubclass(DrumMidiSpec, ModelSpec) + + +def test_adtof_registry_entry() -> None: + """list_specs(DrumMidiSpec) returns an entry for 'adtof-drums' with correct fields.""" + from models.registry import DrumMidiSpec, list_specs + + specs = list_specs(DrumMidiSpec) + assert len(specs) >= 1 + + match = next((s for s in specs if s.model_id == "adtof-drums"), None) + assert match is not None, "No entry with model_id='adtof-drums' found" + + spec = match # type: ignore[assignment] + assert spec.class_count == 5 + assert spec.class_labels == ("kick", "snare", "tom", "hi_hat", "cymbal") + assert spec.sample_rate == 44_100 + + +def test_adtof_capabilities() -> None: + """ADTOF registry entry has the correct capability tags.""" + from models.registry import get_spec + + spec = get_spec("adtof-drums") + assert spec.capabilities == frozenset({"transcribe", 
"drum_transcription", "gpu_acceleration"}) + + +def test_adtof_cache_subdir() -> None: + """ADTOF registry entry uses cache_subdir='adtof'.""" + from models.registry import get_spec + + spec = get_spec("adtof-drums") + assert spec.cache_subdir == "adtof" + + +def test_adtof_checkpoint_url_empty() -> None: + """ADTOF checkpoint_url is empty (weights bundled in pip package).""" + from models.registry import get_spec + + spec = get_spec("adtof-drums") + assert spec.checkpoint_url == "" + + +def test_get_loader_kwargs_adtof() -> None: + """get_loader_kwargs('adtof-drums') returns a dict containing 'cache_dir' without raising.""" + from models.registry import get_loader_kwargs + + kw = get_loader_kwargs("adtof-drums") + assert isinstance(kw, dict) + assert "cache_dir" in kw + + +def test_get_pipeline_defaults_adtof() -> None: + """get_pipeline_defaults('adtof-drums') returns a dict containing 'class_count' without raising.""" + from models.registry import get_pipeline_defaults + + defaults = get_pipeline_defaults("adtof-drums") + assert isinstance(defaults, dict) + assert "class_count" in defaults + + +def test_get_gui_metadata_adtof() -> None: + """get_gui_metadata('adtof-drums') returns a dict containing 'tooltip' without raising.""" + from models.registry import get_gui_metadata + + meta = get_gui_metadata("adtof-drums") + assert isinstance(meta, dict) + assert "tooltip" in meta diff --git a/tests/test_midi_loader_drum.py b/tests/test_midi_loader_drum.py new file mode 100644 index 0000000..10e8351 --- /dev/null +++ b/tests/test_midi_loader_drum.py @@ -0,0 +1,161 @@ +"""Tests for MidiModelLoader ADTOF drum transcription extension. + +RED phase: Tests for _ensure_adtof(), convert_drum_to_midi(), evict_drum_model(). +All tests expected to fail until models/midi_loader.py is extended. 
+""" +from __future__ import annotations + +import pathlib +from unittest.mock import MagicMock, patch, sentinel + +import pytest + +from models.midi_loader import MidiModelLoader +from utils.errors import PipelineExecutionError + + +# --------------------------------------------------------------------------- +# Lazy loading +# --------------------------------------------------------------------------- + + +def test_adtof_lazy_not_loaded_at_init() -> None: + """_adtof_backend is None immediately after construction — ADTOF not loaded at init.""" + loader = MidiModelLoader() + assert loader._adtof_backend is None + + +def test_ensure_adtof_returns_backend() -> None: + """_ensure_adtof() returns an AdtofBackend instance after first call.""" + loader = MidiModelLoader() + + mock_backend_instance = MagicMock() + mock_backend_cls = MagicMock(return_value=mock_backend_instance) + + with patch("pipelines.adtof_backend.AdtofBackend", mock_backend_cls): + # Patch the import inside _ensure_adtof via the module path it imports from + with patch.dict("sys.modules", {}): + import pipelines.adtof_backend as _adtof_mod + orig_cls = getattr(_adtof_mod, "AdtofBackend", None) + _adtof_mod.AdtofBackend = mock_backend_cls # type: ignore[attr-defined] + try: + # Patch the deferred import inside the method + with patch("models.midi_loader.MidiModelLoader._ensure_adtof", wraps=None) as _mock: + pass + finally: + if orig_cls is not None: + _adtof_mod.AdtofBackend = orig_cls + + # Simpler approach: directly test by patching the deferred import + loader2 = MidiModelLoader() + mock_instance = MagicMock() + mock_cls = MagicMock(return_value=mock_instance) + + import sys + # Create a mock for pipelines.adtof_backend that has AdtofBackend + mock_adtof_module = MagicMock() + mock_adtof_module.AdtofBackend = mock_cls + with patch.dict(sys.modules, {"pipelines.adtof_backend": mock_adtof_module}): + result = loader2._ensure_adtof() + + assert result is mock_instance + 
mock_instance.load.assert_called_once() + + +def test_ensure_adtof_caches_instance() -> None: + """Second call to _ensure_adtof() returns the same object (cached).""" + loader = MidiModelLoader() + mock_instance = MagicMock() + mock_cls = MagicMock(return_value=mock_instance) + + import sys + mock_adtof_module = MagicMock() + mock_adtof_module.AdtofBackend = mock_cls + with patch.dict(sys.modules, {"pipelines.adtof_backend": mock_adtof_module}): + first = loader._ensure_adtof() + second = loader._ensure_adtof() + + assert first is second + # AdtofBackend() constructor called only once + assert mock_cls.call_count == 1 + + +# --------------------------------------------------------------------------- +# convert_drum_to_midi +# --------------------------------------------------------------------------- + + +def test_convert_drum_to_midi_calls_predict() -> None: + """convert_drum_to_midi(path) calls backend.predict(path) and returns its result.""" + loader = MidiModelLoader() + expected = [(0.1, 0.16, 35, 100)] + + mock_backend = MagicMock() + mock_backend.predict.return_value = expected + loader._adtof_backend = mock_backend + + result = loader.convert_drum_to_midi(pathlib.Path("dummy.wav")) + + mock_backend.predict.assert_called_once_with(pathlib.Path("dummy.wav")) + assert result == expected + + +def test_convert_drum_to_midi_wraps_exception() -> None: + """Non-PipelineExecutionError from predict() is wrapped in PipelineExecutionError.""" + loader = MidiModelLoader() + + mock_backend = MagicMock() + mock_backend.predict.side_effect = ValueError("boom") + loader._adtof_backend = mock_backend + + with pytest.raises(PipelineExecutionError) as exc_info: + loader.convert_drum_to_midi(pathlib.Path("dummy.wav")) + + assert "midi" in exc_info.value.pipeline_name # type: ignore[union-attr] + + +# --------------------------------------------------------------------------- +# evict +# --------------------------------------------------------------------------- + + +def 
test_evict_clears_adtof_backend() -> None: + """evict() sets _adtof_backend to None and calls backend.evict().""" + loader = MidiModelLoader() + mock_backend = MagicMock() + loader._adtof_backend = mock_backend + + loader.evict() + + assert loader._adtof_backend is None + mock_backend.evict.assert_called_once() + + +def test_evict_no_error_when_none() -> None: + """evict() with _adtof_backend=None does not raise any exception.""" + loader = MidiModelLoader() + assert loader._adtof_backend is None + + # Must not raise + loader.evict() + + +# --------------------------------------------------------------------------- +# evict_drum_model +# --------------------------------------------------------------------------- + + +def test_evict_drum_model_leaves_basicpitch() -> None: + """evict_drum_model() clears _adtof_backend only; _model (BasicPitch) is untouched.""" + loader = MidiModelLoader() + mock_backend = MagicMock() + loader._adtof_backend = mock_backend + # Simulate a loaded BasicPitch model + sentinel_bp = sentinel.basicpitch_model + loader._model = sentinel_bp # type: ignore[assignment] + + loader.evict_drum_model() + + assert loader._adtof_backend is None + mock_backend.evict.assert_called_once() + assert loader._model is sentinel_bp diff --git a/tests/test_midi_pipeline_routing.py b/tests/test_midi_pipeline_routing.py new file mode 100644 index 0000000..08be30b --- /dev/null +++ b/tests/test_midi_pipeline_routing.py @@ -0,0 +1,212 @@ +"""Tests for MidiPipeline drum stem routing branch. + +RED phase: Tests for _DRUM_STEM_LABELS, drum routing in run(), is_drum parameter +on _build_stem_midi(), 3-stage progress callbacks, and post-loop eviction. +All tests expected to fail until pipelines/midi_pipeline.py is extended. 
+""" +from __future__ import annotations + +import pathlib +from unittest.mock import MagicMock, call + +import pytest + +from pipelines.midi_pipeline import MidiPipeline, MidiConfig + + +# --------------------------------------------------------------------------- +# Test infrastructure +# --------------------------------------------------------------------------- + + +def _make_pipeline_with_mock_loader(): + """Create a MidiPipeline with a mock loader, bypassing load_model().""" + pipeline = MidiPipeline() + pipeline._config = MidiConfig() + mock_loader = MagicMock() + # convert_drum_to_midi returns one kick note event + mock_loader.convert_drum_to_midi.return_value = [(0.1, 0.16, 35, 100)] + # convert_audio_to_midi returns one pitched note + mock_loader.convert_audio_to_midi.return_value = [(0.1, 0.4, 60, 80)] + # convert_vocal_to_midi returns (notes, lyrics) + mock_loader.convert_vocal_to_midi.return_value = ([(0.1, 0.5, 60, 80)], []) + pipeline._loader = mock_loader + pipeline.is_loaded = True + return pipeline, mock_loader + + +# --------------------------------------------------------------------------- +# Routing tests +# --------------------------------------------------------------------------- + + +def test_drum_label_routes_to_adt(tmp_path: pathlib.Path) -> None: + """A stem labeled 'drums' calls convert_drum_to_midi(), not convert_audio_to_midi().""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"drums": tmp_file}) + + assert mock_loader.convert_drum_to_midi.called, ( + "convert_drum_to_midi() should be called for 'drums' label" + ) + assert not mock_loader.convert_audio_to_midi.called, ( + "convert_audio_to_midi() should NOT be called for 'drums' label" + ) + + +def test_roformer_drum_label_routes_to_adt(tmp_path: pathlib.Path) -> None: + """A stem labeled 'Drums & percussion' calls convert_drum_to_midi(), not convert_audio_to_midi().""" + pipeline, 
mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums_perc.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"Drums & percussion": tmp_file}) + + assert mock_loader.convert_drum_to_midi.called, ( + "convert_drum_to_midi() should be called for 'Drums & percussion' label" + ) + assert not mock_loader.convert_audio_to_midi.called, ( + "convert_audio_to_midi() should NOT be called for 'Drums & percussion' label" + ) + + +def test_vocal_label_not_routed_to_drum(tmp_path: pathlib.Path) -> None: + """A stem labeled 'vocals' calls convert_vocal_to_midi(), not convert_drum_to_midi().""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "vocals.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"vocals": tmp_file}) + + assert mock_loader.convert_vocal_to_midi.called, ( + "convert_vocal_to_midi() should be called for 'vocals' label" + ) + assert not mock_loader.convert_drum_to_midi.called, ( + "convert_drum_to_midi() should NOT be called for 'vocals' label" + ) + + +def test_bass_label_not_routed_to_drum(tmp_path: pathlib.Path) -> None: + """A stem labeled 'bass' calls convert_audio_to_midi(), not convert_drum_to_midi().""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "bass.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"bass": tmp_file}) + + assert mock_loader.convert_audio_to_midi.called, ( + "convert_audio_to_midi() should be called for 'bass' label" + ) + assert not mock_loader.convert_drum_to_midi.called, ( + "convert_drum_to_midi() should NOT be called for 'bass' label" + ) + + +# --------------------------------------------------------------------------- +# is_drum parameter tests +# --------------------------------------------------------------------------- + + +def test_drum_stem_midi_is_drum_true(tmp_path: pathlib.Path) -> None: + """Drum stem result has instruments[0].is_drum == True.""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + 
tmp_file = tmp_path / "drums.wav" + tmp_file.write_bytes(b"\x00") + + result = pipeline.run({"drums": tmp_file}) + + midi_obj = result.stem_midi_data["drums"] + assert len(midi_obj.instruments) > 0, "MIDI object should have at least one instrument" + assert midi_obj.instruments[0].is_drum is True, ( + "Drum stem MIDI should have is_drum=True on instrument[0]" + ) + + +def test_non_drum_stem_midi_is_drum_false(tmp_path: pathlib.Path) -> None: + """Non-drum stem result has instruments[0].is_drum == False.""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "bass.wav" + tmp_file.write_bytes(b"\x00") + + result = pipeline.run({"bass": tmp_file}) + + midi_obj = result.stem_midi_data["bass"] + assert len(midi_obj.instruments) > 0, "MIDI object should have at least one instrument" + assert midi_obj.instruments[0].is_drum is False, ( + "Non-drum stem MIDI should have is_drum=False on instrument[0]" + ) + + +# --------------------------------------------------------------------------- +# 3-stage progress callback test +# --------------------------------------------------------------------------- + + +def test_drum_path_reports_progress_3_stages(tmp_path: pathlib.Path) -> None: + """_report() is called at least 3 times specifically during drum stem processing. + + Per RESEARCH.md Pitfall 6, the drum branch must report at 3 stages: + (1) before _ensure_adtof() / convert_drum_to_midi() — before loading + (2) after convert_drum_to_midi() returns — after loading/predict + (3) after predict (done) — final per-stem marker + """ + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums.wav" + tmp_file.write_bytes(b"\x00") + + progress_callback = MagicMock() + pipeline._progress_callback = progress_callback + + pipeline.run({"drums": tmp_file}) + + # Overall pipeline reports: initial 2.0, base_pct at start of stem, plus + # at least 3 drum-specific stages, plus 80.0 and 100.0 at the end. 
+ # Total must be at least 5 (2 framing + 3 drum-specific). + assert progress_callback.call_count >= 5, ( + f"Expected at least 5 _report() calls during drum stem processing, " + f"got {progress_callback.call_count}" + ) + + # Extract all percentage values passed to the callback + reported_pcts = [c.args[0] for c in progress_callback.call_args_list] + + # The drum branch must fire 3 calls in quick succession inside the stem loop. + # All drum-branch calls occur after base_pct (5.0 for 1 stem) and before 80.0. + drum_branch_calls = [p for p in reported_pcts if 5.0 < p < 80.0] + assert len(drum_branch_calls) >= 3, ( + f"Expected at least 3 progress callbacks in the drum branch range (5.0 < pct < 80.0), " + f"got {len(drum_branch_calls)}: {drum_branch_calls}" + ) + + +# --------------------------------------------------------------------------- +# Eviction tests +# --------------------------------------------------------------------------- + + +def test_drum_evict_called_after_loop(tmp_path: pathlib.Path) -> None: + """evict_drum_model() is called after the stems loop when drum stems were present.""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "drums.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"drums": tmp_file}) + + assert mock_loader.evict_drum_model.called, ( + "evict_drum_model() should be called after processing drum stems" + ) + + +def test_no_drum_evict_when_no_drums(tmp_path: pathlib.Path) -> None: + """evict_drum_model() is NOT called when no drum stems in request.""" + pipeline, mock_loader = _make_pipeline_with_mock_loader() + tmp_file = tmp_path / "bass.wav" + tmp_file.write_bytes(b"\x00") + + pipeline.run({"bass": tmp_file}) + + assert not mock_loader.evict_drum_model.called, ( + "evict_drum_model() should NOT be called when no drum stems are present" + ) diff --git a/tests/test_notes_to_midi_drum.py b/tests/test_notes_to_midi_drum.py new file mode 100644 index 0000000..8d81fe9 --- /dev/null +++ 
b/tests/test_notes_to_midi_drum.py
@@ -0,0 +1,69 @@
+"""Unit tests for is_drum parameter and 60ms cap in utils.midi_io.notes_to_midi().
+
+Written test-first (TDD RED phase): the is_drum=True cases raise TypeError
+against any notes_to_midi() that lacks the is_drum parameter (added in Task 2).
+"""
+from __future__ import annotations
+
+import pytest
+
+
+def test_notes_to_midi_drum_channel() -> None:
+    """With is_drum=True, instruments[0] is flagged is_drum (MIDI channel 10) and named for drums."""
+    from utils.midi_io import notes_to_midi
+
+    result = notes_to_midi([(0.5, 1.5, 35, 100)], is_drum=True)
+
+    assert result.instruments[0].is_drum is True
+    assert result.instruments[0].name == "StemForge Drums"
+
+
+def test_notes_to_midi_regression() -> None:
+    """notes_to_midi() without is_drum produces identical output to pre-change behavior."""
+    from utils.midi_io import notes_to_midi
+
+    # Default call (no is_drum arg) — must match original Acoustic Grand Piano behavior.
+    result_default = notes_to_midi([(0.5, 1.5, 60, 100)])
+    assert result_default.instruments[0].is_drum is False
+    assert result_default.instruments[0].program == 0
+    assert result_default.instruments[0].name == "StemForge"
+
+    # Explicit is_drum=False — must be identical to the default.
+    result_explicit = notes_to_midi([(0.5, 1.5, 60, 100)], is_drum=False)
+    assert result_explicit.instruments[0].is_drum is False
+    assert result_explicit.instruments[0].program == 0
+    assert result_explicit.instruments[0].name == "StemForge"
+
+
+def test_drum_note_60ms_cap() -> None:
+    """Drum notes longer than 60ms are clamped to exactly 60ms duration."""
+    from utils.midi_io import notes_to_midi
+
+    # Note with start=1.0, end=2.0 (1000ms) should be clamped to end=1.06.
+    result = notes_to_midi([(1.0, 2.0, 35, 100)], is_drum=True)
+
+    assert len(result.instruments[0].notes) == 1
+    note = result.instruments[0].notes[0]
+    assert note.end == pytest.approx(1.06, abs=1e-6)
+
+
+def test_drum_note_short_passthrough() -> None:
+    """Drum notes already shorter than 60ms are not modified."""
+    from utils.midi_io import notes_to_midi
+
+    # Note with start=1.0, end=1.03 (30ms) should pass through unchanged.
+    result = notes_to_midi([(1.0, 1.03, 35, 100)], is_drum=True)
+
+    assert len(result.instruments[0].notes) == 1
+    note = result.instruments[0].notes[0]
+    assert note.end == pytest.approx(1.03, abs=1e-6)
+
+
+def test_drum_degenerate_note_filtered() -> None:
+    """Degenerate notes (end <= start) are filtered before the 60ms cap is applied."""
+    from utils.midi_io import notes_to_midi
+
+    # Note with start=1.0, end=0.5 — end < start, must be discarded entirely.
+    result = notes_to_midi([(1.0, 0.5, 35, 100)], is_drum=True)
+
+    assert len(result.instruments[0].notes) == 0