-
Notifications
You must be signed in to change notification settings - Fork 0
W2v2 convdebug #3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: orig-w2v2
Are you sure you want to change the base?
Changes from all commits
7e6f3cf
6a03e5c
433ca76
19535a2
a500587
b49f103
73e2f3b
8802a31
dc60592
be3ca6a
30f3761
af6d4d6
9ce9909
7cd1be0
9b98663
91cdca2
2c59cec
5565622
45bde79
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -10,6 +10,7 @@ | |
|
|
||
| from fairseq import metrics, utils | ||
| from fairseq.criterions import FairseqCriterion, register_criterion | ||
| from fairseq.utils import index_put, is_xla_tensor | ||
|
|
||
|
|
||
| @register_criterion('wav2vec') | ||
|
|
@@ -41,9 +42,18 @@ def forward(self, model, sample, reduce=True, log_pred=False): | |
| 3) logging outputs to display while training | ||
| """ | ||
| net_output = model(**sample['net_input']) | ||
|
|
||
| logits = model.get_logits(net_output).float() | ||
| target = model.get_targets(sample, net_output) | ||
|
|
||
| if is_xla_tensor(logits): | ||
| # tpu-comment: since dynamic shapes lead to recompilations on xla, | ||
| # we don't shrink tensors using mask_indices. | ||
| # Instead, we do the following when computing loss: | ||
| mi = sample['net_input']['mask_indices'].reshape(logits.size(0)) | ||
| target = index_put(target, ~mi, -1) | ||
|
|
||
| # XXX: handle weights on xla. | ||
| weights = None | ||
| if hasattr(model, 'get_target_weights') and not self.infonce: | ||
| weights = model.get_target_weights(target, net_output) | ||
|
|
@@ -53,11 +63,24 @@ def forward(self, model, sample, reduce=True, log_pred=False): | |
| losses = [] | ||
|
|
||
| if self.infonce: | ||
| loss = F.cross_entropy(logits, target, reduction="sum" if reduce else "none",) | ||
| loss = F.cross_entropy( | ||
| logits, target, reduction="sum" if reduce else "none", | ||
| ignore_index=-1, | ||
| ) | ||
| else: | ||
| loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",) | ||
| loss = F.binary_cross_entropy_with_logits( | ||
| logits, target.float(), weights, | ||
| reduction="sum" if reduce else "none", | ||
| ignore_index=-1, | ||
| ) | ||
|
|
||
| sample_size = target.numel() if self.infonce else target.long().sum().item() | ||
| if 'sample_size' in sample and self.infonce: | ||
| sample_size = sample['sample_size'] | ||
| elif 'mask_indices' in sample['net_input'] and self.infonce: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe remove the "and self.infonce" part? |
||
| sample_size = sample['net_input']['mask_indices'].sum() | ||
| else: | ||
| # XXX: if not self.infonce, is xla path working correctly? | ||
| sample_size = target.numel() if self.infonce else target.long().sum() | ||
| losses.append(loss) | ||
|
|
||
| if self.loss_weights is not None: | ||
|
|
@@ -75,19 +98,22 @@ def forward(self, model, sample, reduce=True, log_pred=False): | |
| losses.append(p) | ||
|
|
||
| logging_output = { | ||
| 'loss': loss.item() if reduce else loss, | ||
| 'loss': loss.detach(), | ||
| 'ntokens': sample_size, | ||
| 'nsentences': sample['id'].numel(), | ||
| 'sample_size': sample_size, | ||
| } | ||
|
|
||
| for lk in self.log_keys: | ||
| if lk in net_output: | ||
| logging_output[lk] = float((net_output[lk])) | ||
| value = net_output[lk] | ||
| if not is_xla_tensor(value): | ||
| value = float(value) | ||
| logging_output[lk] = value | ||
|
|
||
| if len(losses) > 1: | ||
| for i, l in enumerate(losses): | ||
| logging_output[f'loss_{i}'] = l.item() | ||
| logging_output[f'loss_{i}'] = l.detach() | ||
|
|
||
| if self.infonce: | ||
| with torch.no_grad(): | ||
|
|
@@ -98,14 +124,20 @@ def forward(self, model, sample, reduce=True, log_pred=False): | |
| assert logits.dim() > 1, logits.shape | ||
| max = logits.argmax(-1) == 0 | ||
| min = logits.argmin(-1) == 0 | ||
| both = max & min | ||
| corr = max.long().sum().item() - both.long().sum().item() | ||
| count = max.numel() | ||
| if is_xla_tensor(logits): | ||
| max, min = max * mi, min * mi | ||
| both = max & min | ||
| corr = max.long().sum() - both.long().sum() | ||
| count = mi.sum() | ||
| else: | ||
| both = max & min | ||
| corr = max.long().sum().item() - both.long().sum().item() | ||
| count = float(max.numel()) | ||
|
|
||
| logging_output["correct"] = corr | ||
| logging_output["count"] = count | ||
|
|
||
| if log_pred: | ||
| if log_pred and not is_xla_tensor(logits): | ||
| logging_output['logits'] = logits.cpu().numpy() | ||
| logging_output['target'] = target.cpu().numpy() | ||
| return loss, sample_size, logging_output | ||
|
|
@@ -132,7 +164,7 @@ def reduce_metrics(logging_outputs) -> None: | |
| if total > 0: | ||
| metrics.log_derived( | ||
| "accuracy", | ||
| lambda meters: round(meters["_correct"].sum / meters["_total"].sum, 5) | ||
| lambda meters: meters["_correct"].sum / meters["_total"].sum | ||
| if meters["_total"].sum > 0 | ||
| else float("nan"), | ||
| ) | ||
|
|
@@ -154,4 +186,4 @@ def logging_outputs_can_be_summed() -> bool: | |
| across workers prior to calling `reduce_metrics`. Setting this | ||
| to True will improves distributed training speed. | ||
| """ | ||
| return False | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If you keep it at False, do you still see larger accuracy, etc.? (I know we tried 1 node, but still.) |
||
| return True | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,7 +12,8 @@ | |
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| from .. import FairseqDataset | ||
| from .. import FairseqDataset, BaseWrapperDataset | ||
| from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -27,6 +28,8 @@ def __init__( | |
| min_length=0, | ||
| pad=False, | ||
| normalize=False, | ||
| compute_mask_indices=False, | ||
| args=None, | ||
| ): | ||
| super().__init__() | ||
|
|
||
|
|
@@ -40,6 +43,12 @@ def __init__( | |
| self.pad = pad | ||
| self.shuffle = shuffle | ||
| self.normalize = normalize | ||
| self.compute_mask_indices = compute_mask_indices | ||
| if self.compute_mask_indices: | ||
| self.args = args | ||
| self._features_size_map = {} | ||
| self._C = self.args.encoder_embed_dim | ||
| self._conv_feature_layers = eval(self.args.conv_feature_layers) | ||
|
|
||
| def __getitem__(self, index): | ||
| raise NotImplementedError() | ||
|
|
@@ -71,6 +80,45 @@ def crop_to_max_size(self, wav, target_size): | |
| end = size - diff + start | ||
| return wav[start:end] | ||
|
|
||
| def _compute_mask_indices(self, dims, padding_mask): | ||
| B, T, C = dims | ||
| mask_indices, mask_channel_indices = None, None | ||
| if self.args.mask_prob > 0: | ||
| mask_indices = compute_mask_indices( | ||
| (B, T), | ||
| padding_mask, | ||
| self.args.mask_prob, | ||
| self.args.mask_length, | ||
| self.args.mask_selection, | ||
| self.args.mask_other, | ||
| min_masks=2, | ||
| no_overlap=self.args.no_mask_overlap, | ||
| min_space=self.args.mask_min_space, | ||
| ) | ||
| mask_indices = torch.from_numpy(mask_indices) | ||
| if self.args.mask_channel_prob > 0: | ||
| mask_channel_indices = compute_mask_indices( | ||
| (B, C), | ||
| None, | ||
| self.args.mask_channel_prob, | ||
| self.args.mask_channel_length, | ||
| self.args.mask_channel_selection, | ||
| self.args.mask_channel_other, | ||
| no_overlap=self.args.no_mask_channel_overlap, | ||
| min_space=self.args.mask_channel_min_space, | ||
| ) | ||
| mask_channel_indices = ( | ||
| torch.from_numpy(mask_channel_indices) | ||
| .unsqueeze(1) | ||
| .expand(-1, T, -1) | ||
| ) | ||
|
|
||
| return mask_indices, mask_channel_indices | ||
|
|
||
| @staticmethod | ||
| def _bucket_tensor(tensor, num_pad, value): | ||
| return F.pad(tensor, (0, num_pad), value=value) | ||
|
|
||
| def collater(self, samples): | ||
| samples = [ | ||
| s | ||
|
|
@@ -106,9 +154,131 @@ def collater(self, samples): | |
| collated_sources[i] = self.crop_to_max_size(source, target_size) | ||
|
|
||
| input = {"source": collated_sources} | ||
| out = {"id": torch.LongTensor([s["id"] for s in samples])} | ||
| if self.pad: | ||
| input["padding_mask"] = padding_mask | ||
| return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input} | ||
|
|
||
| if hasattr(self, 'num_buckets') and self.num_buckets > 0: | ||
| assert self.pad, "Cannot bucket without padding first." | ||
| bucket = max(self._bucketed_sizes[s['id']] for s in samples) | ||
| num_pad = bucket - collated_sources.size(-1) | ||
| if num_pad: | ||
| input['source'] = self._bucket_tensor( | ||
| collated_sources, num_pad, 0 | ||
| ) | ||
| input['padding_mask'] = self._bucket_tensor( | ||
| padding_mask, num_pad, True | ||
| ) | ||
|
|
||
| if self.compute_mask_indices: | ||
| B = input['source'].size(0) | ||
| T = self._get_mask_indices_dims(input['source'].size(-1)) | ||
| padding_mask_reshaped = input['padding_mask'].clone() | ||
| extra = padding_mask_reshaped.size(1) % T | ||
| if extra > 0: | ||
| padding_mask_reshaped = padding_mask_reshaped[:, :-extra] | ||
| padding_mask_reshaped = padding_mask_reshaped.view( | ||
| padding_mask_reshaped.size(0), T, -1 | ||
| ) | ||
| padding_mask_reshaped = padding_mask_reshaped.all(-1) | ||
| mask_indices, mask_channel_indices = self._compute_mask_indices( | ||
| (B, T, self._C), padding_mask_reshaped, | ||
| ) | ||
| input["mask_indices"] = mask_indices | ||
| import pdb | ||
| pdb.set_trace() | ||
| # FIXME: implement | ||
| #input["neg_idxs"] = self._get_neg_idxs(mask_indices) | ||
| input["tszs_after_mask"] = mask_indices.sum(-1).tolist() | ||
| input["mask_channel_indices"] = mask_channel_indices | ||
| out['sample_size'] = mask_indices.sum().item() | ||
|
|
||
| out["net_input"] = input | ||
| return out | ||
|
|
||
| def _get_neg_idxs(self, mask_indices): | ||
| return | ||
|
|
||
| def _sample_negatives(self, mask_indices): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Where is this actually used? |
||
| """ | ||
| Sampling negatives during model's forward is problematic on XLA. | ||
| That's why we do it here during data prep when run on XLA. | ||
| """ | ||
| self.n_negatives = self.args.num_negatives | ||
| self.cross_sample_negatives = self.args.cross_sample_negatives | ||
| if self.n_negatives == 0 and self.cross_sample_negatives == 0: | ||
| # FIXME: handle this | ||
| return y.new(0) | ||
|
|
||
| (bsz, tsz), fsz = mask_indices.shape, self.args.final_dim | ||
| high = mask_indices.sum(-1).max().item() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This means that sometimes you will sample negatives from masked timesteps for examples that are shorter than the longest one. Why not sample separately for each example in the batch and use the correct high for each example? |
||
| cross_high = high * bsz | ||
|
|
||
| with torch.no_grad(): | ||
| assert high > 1, f"{bsz,tsz,fsz}" | ||
|
|
||
| if self.n_negatives > 0: | ||
| tszs = ( | ||
| torch.arange(tsz) | ||
| .unsqueeze(-1) | ||
| .expand(-1, self.n_negatives) | ||
| .flatten() | ||
| ) | ||
|
|
||
| ts = torch.arange(tsz) | ||
|
|
||
| neg_idxs = torch.randint( | ||
| low=0, high=high-1, size=(bsz, tsz, self.n_negatives) | ||
| ) | ||
| neg_idxs = torch.stack([ | ||
| ts[mask_indices[j]][ni] for j, ni in enumerate(neg_idxs) | ||
| ]).reshape(bsz, -1) | ||
| neg_idxs[neg_idxs >= tszs] += 1 | ||
| import pdb | ||
| pdb.set_trace() | ||
|
|
||
| if self.cross_sample_negatives > 0: | ||
| raise NotImplementedError('Implement for XLA.') | ||
| tszs = ( | ||
| torch.arange(num) | ||
| .unsqueeze(-1) | ||
| .expand(-1, self.cross_sample_negatives) | ||
| .flatten() | ||
| ) | ||
|
|
||
| cross_neg_idxs = torch.randint( | ||
| low=0, | ||
| high=cross_high - 1, | ||
| size=(bsz, self.cross_sample_negatives * num), | ||
| ) | ||
| cross_neg_idxs[cross_neg_idxs >= tszs] += 1 | ||
|
|
||
| if self.n_negatives > 0: | ||
| for i in range(1, bsz): | ||
| neg_idxs[i] += i * high | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This is problematic because previously it assumed "high" is the number of timesteps, but we've redefined high above to be something smaller. You need to do neg_idxs[i] += i * tsz here. |
||
| else: | ||
| neg_idxs = cross_neg_idxs | ||
|
|
||
| if self.cross_sample_negatives > 0 and self.n_negatives > 0: | ||
| neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) | ||
|
|
||
| negs = y[neg_idxs.view(-1)] | ||
| negs = negs.view( | ||
| bsz, tsz, self.n_negatives + self.cross_sample_negatives, fsz | ||
| ).permute( | ||
| 2, 0, 1, 3 | ||
| ) # to NxBxTxC | ||
| return negs, neg_idxs | ||
|
|
||
| def _get_mask_indices_dims(self, size, padding=0, dilation=1): | ||
| if size not in self._features_size_map: | ||
| L_in = size | ||
| for (_, kernel_size, stride) in self._conv_feature_layers: | ||
| L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1 | ||
| L_out = 1 + L_out // stride | ||
| L_in = L_out | ||
| self._features_size_map[size] = L_out | ||
| return self._features_size_map[size] | ||
|
|
||
| def num_tokens(self, index): | ||
| return self.size(index) | ||
|
|
@@ -144,6 +314,9 @@ def __init__( | |
| min_length=0, | ||
| pad=False, | ||
| normalize=False, | ||
| compute_mask_indices=False, | ||
| args=None, | ||
| num_buckets=0, | ||
| ): | ||
| super().__init__( | ||
| sample_rate=sample_rate, | ||
|
|
@@ -153,6 +326,8 @@ def __init__( | |
| min_length=min_length, | ||
| pad=pad, | ||
| normalize=normalize, | ||
| compute_mask_indices=compute_mask_indices, | ||
| args=args, | ||
| ) | ||
|
|
||
| self.fnames = [] | ||
|
|
@@ -169,8 +344,26 @@ def __init__( | |
| continue | ||
| self.fnames.append(items[0]) | ||
| self.sizes.append(sz) | ||
| self.set_bucket_info(num_buckets) | ||
| logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") | ||
|
|
||
| def set_bucket_info(self, num_buckets): | ||
| self.num_buckets = num_buckets | ||
| if self.num_buckets > 0: | ||
| self._collated_sizes = np.minimum( | ||
| np.array(self.sizes), self.max_sample_size, | ||
| ) | ||
| self.buckets = get_buckets( | ||
| self._collated_sizes, self.num_buckets, | ||
| ) | ||
| self._bucketed_sizes = get_bucketed_sizes( | ||
| self._collated_sizes, self.buckets | ||
| ) | ||
| logger.info( | ||
| f"{len(self.buckets)} bucket(s) for the audio dataset: " | ||
| f"{self.buckets}" | ||
| ) | ||
|
|
||
| def __getitem__(self, index): | ||
| import soundfile as sf | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
This won't actually work because binary cross entropy with logits does not have an ignore index. You need to use reduction="none" here and then zero out the loss coming from unmasked states.