10 changes: 10 additions & 0 deletions src/musicagent/config.py
@@ -38,6 +38,16 @@ def vocab_unified(self) -> Path:
"""Path to unified (melody + chord) vocabulary."""
return self.data_processed / "vocab_unified.json"

@property
def vocab_melody(self) -> Path:
"""Path to melody-only vocabulary."""
return self.data_processed / "vocab_melody.json"

@property
def vocab_chord(self) -> Path:
"""Path to chord-only vocabulary."""
return self.data_processed / "vocab_chord.json"


@dataclass
class OfflineConfig:
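A minimal usage sketch for the two new path properties; the import path and default construction of DataConfig are assumptions, not shown in the diff, while the "token_to_id" key is confirmed by the loader in base.py below:

import json
from musicagent.config import DataConfig  # import path assumed

cfg = DataConfig()  # assuming default construction works
with open(cfg.vocab_melody) as f:   # <data_processed>/vocab_melody.json
    melody_vocab = json.load(f)
with open(cfg.vocab_chord) as f:    # <data_processed>/vocab_chord.json
    chord_vocab = json.load(f)
print(len(melody_vocab["token_to_id"]), len(chord_vocab["token_to_id"]))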
91 changes: 84 additions & 7 deletions src/musicagent/data/base.py
@@ -26,8 +26,8 @@ def __init__(self, config: DataConfig, split: str = "train"):
self.config = config
self.split = split

# In the current pipeline we require a unified vocabulary produced by
# ``scripts/preprocess.py``.
# Load unified vocabulary (still needed for transposition tables and
# backward compatibility with data stored in unified ID space).
if not config.vocab_unified.exists():
raise FileNotFoundError(
f"Unified vocabulary not found at {config.vocab_unified}. "
@@ -42,16 +42,59 @@ def __init__(self, config: DataConfig, split: str = "train"):
self.unified_token_to_id: dict[str, int] = token_to_id
self.unified_id_to_token: dict[int, str] = {idx: tok for tok, idx in token_to_id.items()}

# Expose unified vocab size for models that operate directly in this
# space. We use 1 + max(ID) rather than len(token_to_id) to remain
# robust to any sparse ID layouts created at preprocessing time.
# Expose unified vocab size. We use 1 + max(ID) rather than len(token_to_id)
# to remain robust to any sparse ID layouts created at preprocessing time.
if token_to_id:
self.unified_vocab_size: int = max(token_to_id.values()) + 1
else:
self.unified_vocab_size = 0

# Derive melody / chord views from the unified vocabulary based on
# token naming conventions used in preprocessing.
# Load separate melody and chord vocabularies for models that use
# separate embedding tables.
if not config.vocab_melody.exists():
raise FileNotFoundError(
f"Melody vocabulary not found at {config.vocab_melody}. "
"Please preprocess the dataset with scripts/preprocess.py."
)
if not config.vocab_chord.exists():
raise FileNotFoundError(
f"Chord vocabulary not found at {config.vocab_chord}. "
"Please preprocess the dataset with scripts/preprocess.py."
)

with open(config.vocab_melody) as f:
melody_vocab = json.load(f)
with open(config.vocab_chord) as f:
chord_vocab = json.load(f)

self.melody_token_to_id: dict[str, int] = melody_vocab.get("token_to_id", {})
self.chord_token_to_id: dict[str, int] = chord_vocab.get("token_to_id", {})

self.melody_id_to_token: dict[int, str] = {
idx: tok for tok, idx in self.melody_token_to_id.items()
}
self.chord_id_to_token: dict[int, str] = {
idx: tok for tok, idx in self.chord_token_to_id.items()
}

# Compute vocab sizes dynamically from max ID + 1 to handle sparse layouts.
if self.melody_token_to_id:
self.melody_vocab_size: int = max(self.melody_token_to_id.values()) + 1
else:
self.melody_vocab_size = 0

if self.chord_token_to_id:
self.chord_vocab_size: int = max(self.chord_token_to_id.values()) + 1
else:
self.chord_vocab_size = 0

# Compute offset for converting unified chord IDs to chord vocab IDs.
# In unified vocab, chord tokens are offset by melody_size.
# unified_chord_id = chord_vocab_id + _melody_offset
self._melody_offset = self._compute_melody_offset()

# Derive melody / chord views from the unified vocabulary for
# transposition table construction.
#
# - Melody tokens: "pitch_{midi}_on" / "pitch_{midi}_hold"
# - Chord tokens: "{Root}:{quality}/{inv}_on" / "_hold"
@@ -69,6 +112,40 @@ def __init__(self, config: DataConfig, split: str = "train"):

self._build_transposition_tables()

def _compute_melody_offset(self) -> int:
"""Compute offset between unified and chord vocab IDs.

In the unified vocabulary, chord tokens are assigned IDs starting after
all melody tokens. This offset is: unified_chord_id - chord_vocab_id.
"""
# Find a chord token that exists in both vocabs to compute offset
for token, chord_id in self.chord_token_to_id.items():
# Skip special tokens (they have same ID in both vocabs)
if token.startswith("<") or token == "rest":
continue
unified_id = self.unified_token_to_id.get(token)
if unified_id is not None:
return unified_id - chord_id
return 0 # fallback if no chord tokens found

def _unified_to_chord_id(self, unified_id: int) -> int:
"""Convert a unified vocab ID to chord vocab ID.

Special tokens (0-3) have the same ID in both vocabs.
Chord tokens are offset: chord_id = unified_id - _melody_offset.
"""
if unified_id < 4: # Special tokens: pad, sos, eos, rest
return unified_id
return unified_id - self._melody_offset

def _unified_to_melody_id(self, unified_id: int) -> int:
"""Convert a unified vocab ID to melody vocab ID.

Melody tokens have the same ID in unified and melody vocabs,
so this is an identity mapping.
"""
return unified_id

def _build_transposition_tables(self) -> None:
"""Pre-compute transposition tables for melody and chords."""
max_transpose = self.config.max_transpose
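To make the offset arithmetic concrete, a small worked example; the vocabulary sizes, token names, and IDs below are hypothetical, not taken from the repository:

# Hypothetical layout: 4 shared specials, 128 melody tokens (IDs 4..131),
# so the first chord token holds unified ID 132 but chord-vocab ID 4.
chord_token_to_id = {"<pad>": 0, "<sos>": 1, "<eos>": 2, "rest": 3, "C:maj/0_on": 4}
unified_token_to_id = {"C:maj/0_on": 132}

# _compute_melody_offset skips the specials and finds "C:maj/0_on" in both tables:
offset = unified_token_to_id["C:maj/0_on"] - chord_token_to_id["C:maj/0_on"]  # 128
# _unified_to_chord_id(132) -> 132 - 128 = 4; IDs below 4 pass through unchanged.
# The max(ID) + 1 sizing also survives sparse layouts: IDs {0, 1, 5} give size 6.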
22 changes: 19 additions & 3 deletions src/musicagent/data/offline.py
@@ -78,17 +78,33 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
chord_frames = chord_frames[start : start + max_frames]

# On-the-fly random transposition in [-max_transpose, max_transpose].
semitones = random.randint(-self.config.max_transpose, self.config.max_transpose)

# Transposition operates in unified ID space.
semitones = random.randint(
-self.config.max_transpose, self.config.max_transpose
)

melody_frames = self._transpose_melody(melody_frames, semitones)
chord_frames = self._transpose_chord(chord_frames, semitones)
else:
# Validation/Test: Just truncate (no augmentation)
melody_frames = melody_frames[:max_frames]
chord_frames = chord_frames[:max_frames]

# Convert chord frames from unified ID space to chord vocab space.
# Melody frames stay as-is (same IDs in unified and melody vocab).
chord_frames_converted = np.array(
[self._unified_to_chord_id(int(x)) for x in chord_frames], dtype=np.int64
)

# Re-add SOS and EOS to ensure proper sequence structure
src_seq = np.concatenate([[self.config.sos_id], melody_frames, [self.config.eos_id]])
tgt_seq = np.concatenate([[self.config.sos_id], chord_frames, [self.config.eos_id]])
# SOS/EOS have same ID (1, 2) in both vocab spaces
src_seq = np.concatenate(
[[self.config.sos_id], melody_frames, [self.config.eos_id]]
)
tgt_seq = np.concatenate(
[[self.config.sos_id], chord_frames_converted, [self.config.eos_id]]
)

return {
"src": torch.tensor(src_seq, dtype=torch.long),
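A toy trace of the new target construction, reusing the hypothetical offset of 128 from the example above; sos_id = 1 and eos_id = 2 follow the comment in the diff:

import numpy as np

melody_frames = np.array([4, 4, 7])       # melody IDs pass through unchanged
chord_frames = np.array([132, 132, 135])  # unified chord IDs as stored on disk

chord_converted = chord_frames - 128                # -> [4, 4, 7] in chord vocab space
src = np.concatenate([[1], melody_frames, [2]])     # [1, 4, 4, 7, 2]
tgt = np.concatenate([[1], chord_converted, [2]])   # [1, 4, 4, 7, 2]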
112 changes: 66 additions & 46 deletions src/musicagent/data/online.py
@@ -89,25 +89,10 @@ def __init__(self, config: DataConfig, split: str = "train"):
"check preprocessing and DataConfig settings."
)

# Build unified vocabulary views (melody tokens + chord tokens).
#
# `BaseDataset` has already loaded the unified vocabulary created at
# preprocessing time, exposing:
# - `unified_vocab_size`
# - `vocab_melody` / `vocab_chord` (excluding specials)
#
# For convenience (training logs, tests) we expose logical per-track
# vocab sizes that *include* the shared special tokens.
special_ids = {
self.config.pad_id,
self.config.sos_id,
self.config.eos_id,
self.config.rest_id,
}
num_specials = len(special_ids)

self.melody_vocab_size = len(self.vocab_melody) + num_specials
self.chord_vocab_size = len(self.vocab_chord) + num_specials
# `BaseDataset` has already loaded the separate melody and chord
# vocabularies, exposing `melody_vocab_size` and `chord_vocab_size`
# computed from the separate vocab files. We use those inherited values
# directly rather than recalculating from the unified vocab views.

def _melody_to_unified(self, token_id: int) -> int:
"""Convert melody token ID to unified vocab ID.
@@ -133,35 +118,48 @@ def _interleave(
self,
melody_seq: np.ndarray,
chord_seq: np.ndarray,
) -> np.ndarray:
) -> tuple[np.ndarray, np.ndarray]:
"""Interleave melody and chord sequences: [SOS, y₁, x₁, y₂, x₂, ...].

We prepend an SOS token (as a chord) to the sequence.
Then we alternate chord (y) and melody (x).

Sequence:
0: SOS (Chord)
1: y₁ (Chord)
2: x₁ (Melody)
3: y₂ (Chord)
4: x₂ (Melody)
0: SOS (Chord) - in chord vocab space
1: y₁ (Chord) - in chord vocab space
2: x₁ (Melody) - in melody vocab space
3: y₂ (Chord) - in chord vocab space
4: x₂ (Melody) - in melody vocab space
...

Returns:
interleaved: Token IDs where chord positions use chord vocab IDs
and melody positions use melody vocab IDs.
is_melody: Boolean mask, True for melody positions (even indices > 0).
"""
seq_len = len(melody_seq)
total_len = seq_len * 2 + 1

# Add space for SOS at the beginning
interleaved = np.zeros(seq_len * 2 + 1, dtype=np.int64)
interleaved = np.zeros(total_len, dtype=np.int64)
is_melody = np.zeros(total_len, dtype=np.bool_)

# Prepend SOS (as chord token)
# Use the SOS ID from config, mapped to unified chord vocab space
interleaved[0] = self._chord_to_unified(self.config.sos_id)
# Prepend SOS (as chord token in chord vocab space)
# SOS has same ID (1) in both vocab spaces
interleaved[0] = self.config.sos_id
is_melody[0] = False # SOS is treated as chord position

for t in range(seq_len):
# y_t goes to 2*t + 1
interleaved[2 * t + 1] = self._chord_to_unified(chord_seq[t])
# x_t goes to 2*t + 2
interleaved[2 * t + 2] = self._melody_to_unified(melody_seq[t])
# y_t (chord) goes to position 2*t + 1, convert to chord vocab ID
chord_unified_id = int(chord_seq[t])
interleaved[2 * t + 1] = self._unified_to_chord_id(chord_unified_id)
is_melody[2 * t + 1] = False

# x_t (melody) goes to position 2*t + 2, stays in melody vocab ID
interleaved[2 * t + 2] = int(melody_seq[t])
is_melody[2 * t + 2] = True

return interleaved
return interleaved, is_melody

def __len__(self) -> int:
"""Number of sequences with at least one usable frame."""
@@ -252,12 +250,14 @@ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
tgt_seq = chord_frames[:usable_len]

# Interleave: [SOS, y₁, x₁, y₂, x₂, ...]
interleaved = self._interleave(src_seq, tgt_seq)
# Returns (interleaved_ids, is_melody_mask)
interleaved, is_melody = self._interleave(src_seq, tgt_seq)

# For language modeling, input is [:-1] and target is [1:]
# But we return the full sequence; the training loop handles the shift
return {
"input_ids": torch.tensor(interleaved, dtype=torch.long),
"is_melody": torch.tensor(is_melody, dtype=torch.bool),
}
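Worked through with the same hypothetical IDs as above, a two-frame item interleaves as:

melody_seq = [4, 7]      # melody vocab IDs (identity mapping)
chord_seq = [132, 135]   # unified chord IDs, converted per position

# position:      0      1      2      3      4
# interleaved:  [1,     4,     4,     7,     7]
#                SOS    y1     x1     y2     x2
# is_melody:    [False, False, True,  False, True]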


@@ -268,34 +268,54 @@ def make_online_collate_fn(pad_id: int = 0):
pad_id: Token ID used for padding (should match DataConfig.pad_id).

Returns:
A collate function that pads sequences to uniform length.
A collate function that pads sequences to uniform length and returns
a dictionary with both input_ids and is_melody tensors.
"""

def collate_fn(batch: list) -> torch.Tensor:
"""Collate function for online dataset (single interleaved sequence).
def collate_fn(batch: list) -> dict[str, torch.Tensor]:
"""Collate function for online dataset (interleaved sequence + mask).

Online sequences are variable-length (due to cropping at the frame level),
so we right-pad them to a common length for batching.
so we right-pad them to a common length for batching. The is_melody mask
is padded with False (padding positions are treated as chord positions
but will be ignored via attention mask).
"""
sequences = [x["input_ids"] for x in batch]
masks = [x["is_melody"] for x in batch]

if not sequences:
return torch.empty(0, 0, dtype=torch.long)
return {
"input_ids": torch.empty(0, 0, dtype=torch.long),
"is_melody": torch.empty(0, 0, dtype=torch.bool),
}

max_len = max(seq.size(0) for seq in sequences)
device = sequences[0].device
dtype = sequences[0].dtype

padded = torch.full(
# Pad input_ids with pad_id
padded_ids = torch.full(
(len(sequences), max_len),
fill_value=pad_id,
dtype=dtype,
dtype=torch.long,
device=device,
)

for i, seq in enumerate(sequences):
# Pad is_melody with False (pad positions treated as chord for embedding,
# but will be masked out in attention anyway)
padded_mask = torch.zeros(
(len(sequences), max_len),
dtype=torch.bool,
device=device,
)

for i, (seq, mask) in enumerate(zip(sequences, masks)):
length = seq.size(0)
padded[i, :length] = seq
padded_ids[i, :length] = seq
padded_mask[i, :length] = mask

return padded
return {
"input_ids": padded_ids,
"is_melody": padded_mask,
}

return collate_fn
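A usage sketch for the updated collate function; the tensor values continue the toy example above:

import torch

collate = make_online_collate_fn(pad_id=0)
batch = collate([
    {"input_ids": torch.tensor([1, 4, 4]),
     "is_melody": torch.tensor([False, False, True])},
    {"input_ids": torch.tensor([1, 4, 4, 7, 7]),
     "is_melody": torch.tensor([False, False, True, False, True])},
])
# batch["input_ids"] -> shape (2, 5); row 0 is right-padded with pad_id 0
# batch["is_melody"] -> shape (2, 5); padded positions remain False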
26 changes: 16 additions & 10 deletions src/musicagent/eval/offline.py
@@ -100,21 +100,27 @@ def evaluate_offline(

dataset = getattr(test_loader, "dataset", None)

# In the unified pipeline, sequences on disk already use the unified ID
# space produced by preprocessing. If explicit mappings are not provided, we
# always decode via this unified vocabulary.
# With separate vocabularies, melody IDs remain in unified/melody space
# (same IDs), but chord IDs are now in chord vocab space. We must use
# the appropriate mapping for each.
if id_to_melody is None or id_to_chord is None:
if dataset is None:
raise ValueError(
"id_to_melody/id_to_chord not provided and test_loader has no dataset."
)

if hasattr(dataset, "unified_id_to_token"):
if hasattr(dataset, "melody_id_to_token") and hasattr(dataset, "chord_id_to_token"):
id_to_melody = dataset.melody_id_to_token # type: ignore[assignment]
id_to_chord = dataset.chord_id_to_token # type: ignore[assignment]
elif hasattr(dataset, "unified_id_to_token"):
# Fallback for legacy datasets without separate vocab files
unified_map: dict[int, str] = dataset.unified_id_to_token # type: ignore[assignment]
id_to_melody = unified_map
id_to_chord = unified_map
else:
raise ValueError("Dataset does not expose unified_id_to_token for decoding.")
raise ValueError(
"Dataset does not expose melody_id_to_token/chord_id_to_token for decoding."
)

# --- 1. Test loss / perplexity (teacher-forced) ---
criterion = nn.CrossEntropyLoss(ignore_index=d_cfg.pad_id)
@@ -289,14 +295,14 @@ def main():
collate_fn=collate,
)

# Load model (unified ID space)
vocab_size = test_ds.unified_vocab_size
chord_token_ids = sorted(test_ds.vocab_chord.values())
# Load model with separate vocab sizes
melody_vocab_size = test_ds.melody_vocab_size
chord_vocab_size = test_ds.chord_vocab_size
model = OfflineTransformer(
m_cfg,
d_cfg,
vocab_size=vocab_size,
chord_token_ids=chord_token_ids,
melody_vocab_size=melody_vocab_size,
chord_vocab_size=chord_vocab_size,
).to(device)
state_dict = safe_load_state_dict(args.checkpoint, map_location=device)
model.load_state_dict(state_dict)
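A minimal sketch of the per-track decoding this change enables; the IDs and token strings are hypothetical:

# Chord predictions decode through the chord table; melody through its own.
chord_ids = [1, 4, 5, 7, 2]
tokens = [test_ds.chord_id_to_token.get(i, "<unk>") for i in chord_ids]
# e.g. ["<sos>", "C:maj/0_on", "C:maj/0_hold", "G:maj/0_on", "<eos>"]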