Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
# StemForge .gitignore
# =============================================================================

# GSD planning (local-only, not committed)
.planning/

# -----------------------------------------------------------------------------
# Python
# -----------------------------------------------------------------------------
Expand Down
13 changes: 12 additions & 1 deletion backend/api/midi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from backend.services.job_manager import job_manager
from backend.services.session_store import SessionStore, TrackState, get_user_session
from backend.services import pipeline_manager
from models.registry import DrumMidiSpec, list_specs
from utils.paths import MIDI_DIR

router = APIRouter(prefix="/api/midi", tags=["midi"])
Expand Down Expand Up @@ -86,6 +87,14 @@
"Drums & percussion": True,
}


def _build_adt_model_list() -> list[dict]:
"""Return a list of ADT model metadata dicts from the model registry."""
return [
{"model_id": s.model_id, "display_name": s.display_name, "tooltip": s.description}
for s in list_specs(DrumMidiSpec)
]

# ─── SoundFont discovery & state ─────────────────────────────────────────

_SF2_SEARCH_PATHS = [
Expand Down Expand Up @@ -119,6 +128,7 @@ class ExtractRequest(BaseModel):
onset_threshold: float = 0.5
frame_threshold: float = 0.3
min_note_ms: float = 58.0
adt_model: str = "adtof-drums"


class RenderRequest(BaseModel):
Expand Down Expand Up @@ -346,11 +356,12 @@ def get_midi_stems(session: SessionStore = Depends(get_user_session)) -> dict:

@router.get("/gm-programs")
def get_gm_programs() -> dict:
"""Return list of all 128 GM program names and smart defaults per stem label."""
"""Return list of all 128 GM program names, smart defaults per stem label, and ADT models."""
return {
"programs": GM_PROGRAMS,
"defaults": STEM_DEFAULT_PROGRAM,
"drum_stems": STEM_IS_DRUM,
"adt_models": _build_adt_model_list(),
}


Expand Down
31 changes: 30 additions & 1 deletion frontend/components/midi.js
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ function clearChildren(elem) {
let gmPrograms = [];
let stemDefaults = {};
let drumStems = {};
let adtModels = [];

/** LilyPond availability (checked on init). */
let _lilypondAvailable = false;
Expand Down Expand Up @@ -45,6 +46,14 @@ export function initMidi() {
),
);

const adtGroup = el('div', { className: 'form-group hidden', id: 'midi-adt-group' },
el('label', {}, 'ADT Model'),
el('select', { id: 'midi-adt-model' }),
el('p', { className: 'text-dim', style: { fontSize: '12px', margin: '4px 0 0' } },
'Best results with acoustic drums. Electronic/programmed drums may have lower accuracy.',
),
);

const keyGroup = el('div', { className: 'form-group' },
el('label', {}, 'Key'),
el('select', { id: 'midi-key' },
Expand Down Expand Up @@ -104,7 +113,7 @@ export function initMidi() {
const importInput = el('input', { type: 'file', accept: '.mid,.midi', style: { display: 'none' }, id: 'midi-import-input' });
const importBtn = el('button', { className: 'btn btn-sm', id: 'midi-import' }, 'Import MIDI file');

left.append(stemSection, keyGroup, bpmGroup, tsGroup, onsetGroup, frameGroup, sf2Group, extractBtn, importInput, importBtn);
left.append(stemSection, adtGroup, keyGroup, bpmGroup, tsGroup, onsetGroup, frameGroup, sf2Group, extractBtn, importInput, importBtn);

// ─── Right: results ───
const right = el('div', { className: 'col-right' });
Expand Down Expand Up @@ -158,6 +167,8 @@ export function initMidi() {
document.getElementById('midi-start').disabled = false;
});

document.getElementById('midi-stems').addEventListener('change', syncAdtGroupVisibility);

// Load GM programs, current soundfont, and check LilyPond on init
loadGmPrograms();
loadCurrentSoundfont();
Expand All @@ -170,6 +181,14 @@ async function loadGmPrograms() {
gmPrograms = data.programs || [];
stemDefaults = data.defaults || {};
drumStems = data.drum_stems || {};
adtModels = data.adt_models || [];
const adtSel = document.getElementById('midi-adt-model');
if (adtSel) {
clearChildren(adtSel);
for (const m of adtModels) {
adtSel.appendChild(el('option', { value: m.model_id, title: m.tooltip || '' }, m.display_name));
}
}
} catch { /* fail silently, will use defaults */ }
}

Expand Down Expand Up @@ -244,6 +263,15 @@ function populateStemCheckboxes(stemPaths) {
),
);
}
syncAdtGroupVisibility();
}

function syncAdtGroupVisibility() {
const hasDrum = Array.from(
document.querySelectorAll('#midi-stems input[type="checkbox"]:checked')
).some(cb => isDrumStem(cb.value));
const adtGroup = document.getElementById('midi-adt-group');
if (adtGroup) adtGroup.classList.toggle('hidden', !hasDrum);
}

async function startExtraction() {
Expand All @@ -268,6 +296,7 @@ async function startExtraction() {
time_signature: document.getElementById('midi-ts').value,
onset_threshold: parseFloat(document.getElementById('midi-onset').value),
frame_threshold: parseFloat(document.getElementById('midi-frame').value),
adt_model: document.getElementById('midi-adt-model')?.value || 'adtof-drums',
}),
});

Expand Down
76 changes: 75 additions & 1 deletion models/midi_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(self) -> None:
self._bp_loader = BasicPitchModelLoader()
self._model: Any | None = None # BasicPitch TF model
self._whisper_model: Any | None = None # faster-whisper WhisperModel
self._adtof_backend: Any | None = None # AdtofBackend (lazy-loaded)

# ------------------------------------------------------------------
# Lifecycle
Expand Down Expand Up @@ -93,12 +94,41 @@ def is_loaded(self) -> bool:
return self._model is not None

def evict(self) -> None:
"""Release both models from memory and trigger GC."""
"""Release all models from memory and trigger GC."""
self._bp_loader.evict()
self._model = None
self._whisper_model = None
self.evict_drum_model()
log.debug("MidiModelLoader: models evicted.")

# ------------------------------------------------------------------
# Internal: ADTOF lazy loader
# ------------------------------------------------------------------

def _ensure_adtof(self) -> Any:
"""Load the ADTOF drum transcription backend on first use."""
if self._adtof_backend is not None:
return self._adtof_backend
try:
from pipelines.adtof_backend import AdtofBackend
except ImportError as exc:
raise ModelLoadError(
"ADTOF backend is not available.",
model_name="adtof-drums",
) from exc
backend = AdtofBackend()
backend.load()
self._adtof_backend = backend
log.info("MidiModelLoader: ADTOF drum backend ready.")
return self._adtof_backend

def evict_drum_model(self) -> None:
"""Evict the ADTOF backend only, leaving BasicPitch intact."""
if self._adtof_backend is not None:
self._adtof_backend.evict()
self._adtof_backend = None
log.debug("MidiModelLoader: ADTOF backend evicted.")

# ------------------------------------------------------------------
# Internal: Whisper lazy loader
# ------------------------------------------------------------------
Expand Down Expand Up @@ -379,3 +409,47 @@ def convert_vocal_to_midi(
path.name, len(events), len(lyrics), len(segments),
)
return events, lyrics

def convert_drum_to_midi(
self,
path: pathlib.Path,
*,
duration: float = 0.0,
) -> list[NoteEvent]:
"""Transcribe a drum stem to GM note events via ADTOF.

Parameters
----------
path:
Drum audio file (must be 44100 Hz).
duration:
If positive, clip events to this many seconds.

Returns
-------
list[NoteEvent]
GM channel-10 note events from ADTOF_5CLASS_GM_NOTE.

Raises
------
:class:`~utils.errors.ModelLoadError`
If the ADTOF backend cannot be loaded.
:class:`~utils.errors.PipelineExecutionError`
If drum transcription fails.
"""
backend = self._ensure_adtof()
try:
events = backend.predict(path)
except PipelineExecutionError:
raise
except Exception as exc:
raise PipelineExecutionError(
f"Drum transcription failed for '{path.name}': {exc}",
pipeline_name="midi",
) from exc

if duration > 0.0:
events = [(s, min(e, duration), p, v) for s, e, p, v in events if s < duration]

log.debug("convert_drum_to_midi: %s -> %d drum events", path.name, len(events))
return events
66 changes: 66 additions & 0 deletions models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,26 @@ class StableAudioSpec(ModelSpec):
conditioning_keys: tuple[str, ...]


@dataclass(frozen=True, slots=True)
class DrumMidiSpec(ModelSpec):
"""Descriptor for a drum transcription (ADT) model.

Additional fields
-----------------
class_count:
Number of drum output classes this model produces.
class_labels:
Human-readable label for each class, in model output order.
checkpoint_url:
Direct download URL for model weights; empty string means weights
are bundled in the pip package.
"""

class_count: int = 0
class_labels: tuple[str, ...] = ()
checkpoint_url: str = ""


@dataclass(frozen=True, slots=True)
class RoformerSpec(ModelSpec):
"""Descriptor for a BS-Roformer / MelBand-Roformer separation model.
Expand Down Expand Up @@ -742,6 +762,38 @@ def _register(spec: ModelSpec) -> ModelSpec:
default_num_overlap=2,
))

# ---------------------------------------------------------------------------
# Drum transcription (ADT) models
# ---------------------------------------------------------------------------

ADTOF_DRUMS = _register(DrumMidiSpec(
model_id="adtof-drums",
display_name="ADTOF Drums (5-class)",
version="0.1.0",
source="xavriley/ADTOF-pytorch",
device="auto",
gpu_capable=True,
device_fallback="cpu",
device_quirks="",
sample_rate=44_100,
hop_size=0,
chunk_size=0,
max_duration_seconds=0.0,
default_bpm=0.0,
default_key="",
default_time_signature="",
quantize_grid="none",
default_min_note_ms=0.0,
capabilities=frozenset({"transcribe", "drum_transcription", "gpu_acceleration"}),
cache_subdir="adtof",
description="Automatic drum transcription — 5-class (kick, snare, tom, hi-hat, cymbal).",
preprocessing="Mono 44.1 kHz; mel spectrogram; context window of 9 frames.",
postprocessing="Peak picking per class; GM note mapping; 60 ms fixed note duration.",
class_count=5,
class_labels=("kick", "snare", "tom", "hi_hat", "cymbal"),
checkpoint_url="",
))

# ---------------------------------------------------------------------------
# Convenience constants
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -854,6 +906,9 @@ def get_loader_kwargs(model_id: str) -> dict[str, Any]:
"config_url": spec.config_url,
}

if isinstance(spec, DrumMidiSpec):
return {"cache_dir": spec.cache_dir}

raise NotImplementedError(
f"No loader kwargs defined for spec type {type(spec).__name__!r}."
)
Expand Down Expand Up @@ -925,6 +980,9 @@ def get_pipeline_defaults(model_id: str) -> dict[str, Any]:
"num_overlap": spec.default_num_overlap,
}

if isinstance(spec, DrumMidiSpec):
return {"class_count": spec.class_count}

raise NotImplementedError(
f"No pipeline defaults defined for spec type {type(spec).__name__!r}."
)
Expand Down Expand Up @@ -1004,6 +1062,14 @@ def get_gui_metadata(model_id: str) -> dict[str, Any]:
"target_instrument": spec.target_instrument,
}

if isinstance(spec, DrumMidiSpec):
return {
"class_count": spec.class_count,
"class_labels": list(spec.class_labels),
"model_choices": [s.model_id for s in list_specs(DrumMidiSpec)],
"tooltip": spec.description,
}

raise NotImplementedError(
f"No GUI metadata defined for spec type {type(spec).__name__!r}."
)
Loading