Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 44 additions & 12 deletions fairseq/criterions/wav2vec_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.utils import index_put, is_xla_tensor


@register_criterion('wav2vec')
Expand Down Expand Up @@ -41,9 +42,18 @@ def forward(self, model, sample, reduce=True, log_pred=False):
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])

logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)

if is_xla_tensor(logits):
# tpu-comment: since dynamic shapes lead to recompilations on xla,
# we don't shrink tensors using mask_indices.
# Instead, we do the following when computing loss:
mi = sample['net_input']['mask_indices'].reshape(logits.size(0))
target = index_put(target, ~mi, -1)

# XXX: handle weights on xla.
weights = None
if hasattr(model, 'get_target_weights') and not self.infonce:
weights = model.get_target_weights(target, net_output)
Expand All @@ -53,11 +63,24 @@ def forward(self, model, sample, reduce=True, log_pred=False):
losses = []

if self.infonce:
loss = F.cross_entropy(logits, target, reduction="sum" if reduce else "none",)
loss = F.cross_entropy(
logits, target, reduction="sum" if reduce else "none",
ignore_index=-1,
)
else:
loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",)
loss = F.binary_cross_entropy_with_logits(
logits, target.float(), weights,
reduction="sum" if reduce else "none",
ignore_index=-1,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This won't actually work because `binary_cross_entropy_with_logits` does not have an `ignore_index` argument. You need to use reduction="none" here and then zero out the loss coming from the unmasked states.

)

sample_size = target.numel() if self.infonce else target.long().sum().item()
if 'sample_size' in sample and self.infonce:
sample_size = sample['sample_size']
elif 'mask_indices' in sample['net_input'] and self.infonce:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe remove the `and self.infonce` part?

sample_size = sample['net_input']['mask_indices'].sum()
else:
# XXX: if not self.infonce, is xla path working correctly?
sample_size = target.numel() if self.infonce else target.long().sum()
losses.append(loss)

if self.loss_weights is not None:
Expand All @@ -75,19 +98,22 @@ def forward(self, model, sample, reduce=True, log_pred=False):
losses.append(p)

logging_output = {
'loss': loss.item() if reduce else loss,
'loss': loss.detach(),
'ntokens': sample_size,
'nsentences': sample['id'].numel(),
'sample_size': sample_size,
}

for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float((net_output[lk]))
value = net_output[lk]
if not is_xla_tensor(value):
value = float(value)
logging_output[lk] = value

if len(losses) > 1:
for i, l in enumerate(losses):
logging_output[f'loss_{i}'] = l.item()
logging_output[f'loss_{i}'] = l.detach()

if self.infonce:
with torch.no_grad():
Expand All @@ -98,14 +124,20 @@ def forward(self, model, sample, reduce=True, log_pred=False):
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
if is_xla_tensor(logits):
max, min = max * mi, min * mi
both = max & min
corr = max.long().sum() - both.long().sum()
count = mi.sum()
else:
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = float(max.numel())

logging_output["correct"] = corr
logging_output["count"] = count

if log_pred:
if log_pred and not is_xla_tensor(logits):
logging_output['logits'] = logits.cpu().numpy()
logging_output['target'] = target.cpu().numpy()
return loss, sample_size, logging_output
Expand All @@ -132,7 +164,7 @@ def reduce_metrics(logging_outputs) -> None:
if total > 0:
metrics.log_derived(
"accuracy",
lambda meters: round(meters["_correct"].sum / meters["_total"].sum, 5)
lambda meters: meters["_correct"].sum / meters["_total"].sum
if meters["_total"].sum > 0
else float("nan"),
)
Expand All @@ -154,4 +186,4 @@ def logging_outputs_can_be_summed() -> bool:
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return False
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you keep it at false, do you still see larger accuracy etc (i know we tried 1 node, but still)

return True
197 changes: 195 additions & 2 deletions fairseq/data/audio/raw_audio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import torch
import torch.nn.functional as F

from .. import FairseqDataset
from .. import FairseqDataset, BaseWrapperDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes

logger = logging.getLogger(__name__)

Expand All @@ -27,6 +28,8 @@ def __init__(
min_length=0,
pad=False,
normalize=False,
compute_mask_indices=False,
args=None,
):
super().__init__()

Expand All @@ -40,6 +43,12 @@ def __init__(
self.pad = pad
self.shuffle = shuffle
self.normalize = normalize
self.compute_mask_indices = compute_mask_indices
if self.compute_mask_indices:
self.args = args
self._features_size_map = {}
self._C = self.args.encoder_embed_dim
self._conv_feature_layers = eval(self.args.conv_feature_layers)

def __getitem__(self, index):
raise NotImplementedError()
Expand Down Expand Up @@ -71,6 +80,45 @@ def crop_to_max_size(self, wav, target_size):
end = size - diff + start
return wav[start:end]

def _compute_mask_indices(self, dims, padding_mask):
    """Pre-compute time and channel mask indices for a collated batch.

    Args:
        dims: tuple (B, T, C) — batch size, number of feature-frame
            timesteps, and encoder embedding dimension.
        padding_mask: boolean padding mask over (B, T), or None.

    Returns:
        (mask_indices, mask_channel_indices) — each a torch tensor, or
        None when the corresponding mask probability is zero.
    """
    B, T, C = dims
    time_mask = None
    channel_mask = None

    if self.args.mask_prob > 0:
        time_mask = torch.from_numpy(
            compute_mask_indices(
                (B, T),
                padding_mask,
                self.args.mask_prob,
                self.args.mask_length,
                self.args.mask_selection,
                self.args.mask_other,
                min_masks=2,
                no_overlap=self.args.no_mask_overlap,
                min_space=self.args.mask_min_space,
            )
        )

    if self.args.mask_channel_prob > 0:
        channel_mask = torch.from_numpy(
            compute_mask_indices(
                (B, C),
                None,
                self.args.mask_channel_prob,
                self.args.mask_channel_length,
                self.args.mask_channel_selection,
                self.args.mask_channel_other,
                no_overlap=self.args.no_mask_channel_overlap,
                min_space=self.args.mask_channel_min_space,
            )
        )
        # Broadcast the per-channel mask across all T timesteps.
        channel_mask = channel_mask.unsqueeze(1).expand(-1, T, -1)

    return time_mask, channel_mask

@staticmethod
def _bucket_tensor(tensor, num_pad, value):
    # Right-pad the last dimension with `num_pad` copies of `value`;
    # used by `collater` to grow sources / padding_mask to bucket size.
    return F.pad(tensor, (0, num_pad), value=value)

def collater(self, samples):
samples = [
s
Expand Down Expand Up @@ -106,9 +154,131 @@ def collater(self, samples):
collated_sources[i] = self.crop_to_max_size(source, target_size)

input = {"source": collated_sources}
out = {"id": torch.LongTensor([s["id"] for s in samples])}
if self.pad:
input["padding_mask"] = padding_mask
return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input}

if hasattr(self, 'num_buckets') and self.num_buckets > 0:
assert self.pad, "Cannot bucket without padding first."
bucket = max(self._bucketed_sizes[s['id']] for s in samples)
num_pad = bucket - collated_sources.size(-1)
if num_pad:
input['source'] = self._bucket_tensor(
collated_sources, num_pad, 0
)
input['padding_mask'] = self._bucket_tensor(
padding_mask, num_pad, True
)

if self.compute_mask_indices:
B = input['source'].size(0)
T = self._get_mask_indices_dims(input['source'].size(-1))
padding_mask_reshaped = input['padding_mask'].clone()
extra = padding_mask_reshaped.size(1) % T
if extra > 0:
padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
padding_mask_reshaped = padding_mask_reshaped.view(
padding_mask_reshaped.size(0), T, -1
)
padding_mask_reshaped = padding_mask_reshaped.all(-1)
mask_indices, mask_channel_indices = self._compute_mask_indices(
(B, T, self._C), padding_mask_reshaped,
)
input["mask_indices"] = mask_indices
import pdb
pdb.set_trace()
# FIXME: implement
#input["neg_idxs"] = self._get_neg_idxs(mask_indices)
input["tszs_after_mask"] = mask_indices.sum(-1).tolist()
input["mask_channel_indices"] = mask_channel_indices
out['sample_size'] = mask_indices.sum().item()

out["net_input"] = input
return out

def _get_neg_idxs(self, mask_indices):
    # Unimplemented stub: the call site in `collater` is still commented
    # out (see the FIXME there). Currently returns None.
    return

def _sample_negatives(self, mask_indices):
    """
    Sample negative-example indices for contrastive training.

    Sampling negatives during model's forward is problematic on XLA.
    That's why we do it here during data prep when run on XLA.

    NOTE(review): this helper is not referenced anywhere visible in this
    file — confirm it is actually wired up before relying on it.
    """
    self.n_negatives = self.args.num_negatives
    self.cross_sample_negatives = self.args.cross_sample_negatives
    if self.n_negatives == 0 and self.cross_sample_negatives == 0:
        # FIXME: `y` is not defined in this scope — handle this case.
        return y.new(0)

    (bsz, tsz), fsz = mask_indices.shape, self.args.final_dim
    # Highest number of masked timesteps over the batch.
    # NOTE(review): using the batch-wide max means shorter examples may
    # draw negatives from unmasked timesteps; sampling per example with
    # the correct per-example `high` would be more accurate.
    high = mask_indices.sum(-1).max().item()
    cross_high = high * bsz

    with torch.no_grad():
        assert high > 1, f"{bsz,tsz,fsz}"

        if self.n_negatives > 0:
            tszs = (
                torch.arange(tsz)
                .unsqueeze(-1)
                .expand(-1, self.n_negatives)
                .flatten()
            )

            ts = torch.arange(tsz)

            # Draw among the masked positions, then map back to
            # absolute timestep indices for each example.
            neg_idxs = torch.randint(
                low=0, high=high - 1, size=(bsz, tsz, self.n_negatives)
            )
            neg_idxs = torch.stack([
                ts[mask_indices[j]][ni] for j, ni in enumerate(neg_idxs)
            ]).reshape(bsz, -1)
            # Skip over the positive itself.
            neg_idxs[neg_idxs >= tszs] += 1

        if self.cross_sample_negatives > 0:
            raise NotImplementedError('Implement for XLA.')
            # Dead code below: `num` is undefined.
            tszs = (
                torch.arange(num)
                .unsqueeze(-1)
                .expand(-1, self.cross_sample_negatives)
                .flatten()
            )

            cross_neg_idxs = torch.randint(
                low=0,
                high=cross_high - 1,
                size=(bsz, self.cross_sample_negatives * num),
            )
            cross_neg_idxs[cross_neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            for i in range(1, bsz):
                # Offset each example into the flattened (bsz * tsz)
                # feature matrix. Must use tsz, not `high`: neg_idxs
                # holds absolute timestep indices in [0, tsz).
                neg_idxs[i] += i * tsz
        else:
            neg_idxs = cross_neg_idxs

        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

    # FIXME: `y` (the feature matrix to gather negatives from) is not
    # defined in this scope — it must be passed in or fetched here.
    negs = y[neg_idxs.view(-1)]
    negs = negs.view(
        bsz, tsz, self.n_negatives + self.cross_sample_negatives, fsz
    ).permute(
        2, 0, 1, 3
    ) # to NxBxTxC
    return negs, neg_idxs

def _get_mask_indices_dims(self, size, padding=0, dilation=1):
if size not in self._features_size_map:
L_in = size
for (_, kernel_size, stride) in self._conv_feature_layers:
L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1
L_out = 1 + L_out // stride
L_in = L_out
self._features_size_map[size] = L_out
return self._features_size_map[size]

def num_tokens(self, index):
    # The dataset's "token" count for batching is simply the example's
    # size — delegate to size().
    return self.size(index)
Expand Down Expand Up @@ -144,6 +314,9 @@ def __init__(
min_length=0,
pad=False,
normalize=False,
compute_mask_indices=False,
args=None,
num_buckets=0,
):
super().__init__(
sample_rate=sample_rate,
Expand All @@ -153,6 +326,8 @@ def __init__(
min_length=min_length,
pad=pad,
normalize=normalize,
compute_mask_indices=compute_mask_indices,
args=args,
)

self.fnames = []
Expand All @@ -169,8 +344,26 @@ def __init__(
continue
self.fnames.append(items[0])
self.sizes.append(sz)
self.set_bucket_info(num_buckets)
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

def set_bucket_info(self, num_buckets):
    """Assign every example to a length bucket.

    A no-op (beyond storing `num_buckets`) when `num_buckets <= 0`.
    Sizes are capped at `max_sample_size` before bucketing, matching
    the cropping done at collation time.
    """
    self.num_buckets = num_buckets
    if num_buckets <= 0:
        return

    capped = np.minimum(np.array(self.sizes), self.max_sample_size)
    self._collated_sizes = capped
    self.buckets = get_buckets(capped, num_buckets)
    self._bucketed_sizes = get_bucketed_sizes(capped, self.buckets)
    logger.info(
        f"{len(self.buckets)} bucket(s) for the audio dataset: "
        f"{self.buckets}"
    )

def __getitem__(self, index):
import soundfile as sf

Expand Down
Loading