diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py
index 019db62249..364c716300 100644
--- a/fairseq/criterions/wav2vec_criterion.py
+++ b/fairseq/criterions/wav2vec_criterion.py
@@ -10,6 +10,7 @@
 from fairseq import metrics, utils
 from fairseq.criterions import FairseqCriterion, register_criterion
+from fairseq.utils import index_put, is_xla_tensor


 @register_criterion('wav2vec')
@@ -41,9 +42,18 @@ def forward(self, model, sample, reduce=True, log_pred=False):
         3) logging outputs to display while training
         """
         net_output = model(**sample['net_input'])
+        logits = model.get_logits(net_output).float()
         target = model.get_targets(sample, net_output)
+
+        if is_xla_tensor(logits):
+            # tpu-comment: since dynamic shapes lead to recompilations on xla,
+            # we don't shrink tensors using mask_indices.
+            # Instead, we do the following when computing loss:
+            mi = sample['net_input']['mask_indices'].reshape(logits.size(0))
+            target = index_put(target, ~mi, -1)
+
+        # XXX: handle weights on xla.
         weights = None
         if hasattr(model, 'get_target_weights') and not self.infonce:
             weights = model.get_target_weights(target, net_output)
@@ -53,11 +63,24 @@ def forward(self, model, sample, reduce=True, log_pred=False):
         losses = []

         if self.infonce:
-            loss = F.cross_entropy(logits, target, reduction="sum" if reduce else "none",)
+            loss = F.cross_entropy(
+                logits, target, reduction="sum" if reduce else "none",
+                ignore_index=-1,
+            )
         else:
-            loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",)
+            # note: F.binary_cross_entropy_with_logits has no ignore_index
+            # argument; only the infonce path supports the -1 convention above.
+            loss = F.binary_cross_entropy_with_logits(
+                logits, target.float(), weights,
+                reduction="sum" if reduce else "none",
+            )

-        sample_size = target.numel() if self.infonce else target.long().sum().item()
+        if 'sample_size' in sample and self.infonce:
+            sample_size = sample['sample_size']
+        elif 'mask_indices' in sample['net_input'] and self.infonce:
+            sample_size = sample['net_input']['mask_indices'].sum()
+        else:
+            # XXX: if not self.infonce, is the xla path working correctly?
+            sample_size = target.numel() if self.infonce else target.long().sum()
         losses.append(loss)

         if self.loss_weights is not None:
@@ -75,7 +98,7 @@ def forward(self, model, sample, reduce=True, log_pred=False):
             losses.append(p)

         logging_output = {
-            'loss': loss.item() if reduce else loss,
+            'loss': loss.detach(),
             'ntokens': sample_size,
             'nsentences': sample['id'].numel(),
             'sample_size': sample_size,
@@ -83,11 +106,14 @@ def forward(self, model, sample, reduce=True, log_pred=False):
         for lk in self.log_keys:
             if lk in net_output:
-                logging_output[lk] = float((net_output[lk]))
+                value = net_output[lk]
+                if not is_xla_tensor(value):
+                    value = float(value)
+                logging_output[lk] = value

         if len(losses) > 1:
             for i, l in enumerate(losses):
-                logging_output[f'loss_{i}'] = l.item()
+                logging_output[f'loss_{i}'] = l.detach()

         if self.infonce:
             with torch.no_grad():
@@ -98,14 +124,20 @@ def forward(self, model, sample, reduce=True, log_pred=False):
                 assert logits.dim() > 1, logits.shape
                 max = logits.argmax(-1) == 0
                 min = logits.argmin(-1) == 0
-                both = max & min
-                corr = max.long().sum().item() - both.long().sum().item()
-                count = max.numel()
+                if is_xla_tensor(logits):
+                    max, min = max * mi, min * mi
+                    both = max & min
+                    corr = max.long().sum() - both.long().sum()
+                    count = mi.sum()
+                else:
+                    both = max & min
+                    corr = max.long().sum().item() - both.long().sum().item()
+                    count = float(max.numel())
                 logging_output["correct"] = corr
                 logging_output["count"] = count

-        if log_pred:
+        if log_pred and not is_xla_tensor(logits):
             logging_output['logits'] = logits.cpu().numpy()
             logging_output['target'] = target.cpu().numpy()
         return loss, sample_size, logging_output
@@ -132,7 +164,7 @@ def reduce_metrics(logging_outputs) -> None:
         if total > 0:
             metrics.log_derived(
                 "accuracy",
-                lambda meters: round(meters["_correct"].sum / meters["_total"].sum, 5)
+                lambda meters: meters["_correct"].sum / meters["_total"].sum
+                if meters["_total"].sum > 0 else float("nan"),
             )
@@ -154,4 +186,4 @@ def logging_outputs_can_be_summed() -> bool:
         across workers prior to calling `reduce_metrics`. Setting this
         to True will improve distributed training speed.
         """
-        return False
+        return True
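The -1 target trick above keeps tensor shapes static for XLA while producing the same loss that the eager path gets by shrinking tensors with the boolean mask. A minimal, self-contained sketch of that equivalence (shapes and names are illustrative, not fairseq's):

    import torch
    import torch.nn.functional as F

    B, num_classes = 8, 5
    logits = torch.randn(B, num_classes)
    target = torch.zeros(B, dtype=torch.long)               # positive = class 0
    mask_indices = torch.tensor([True, False] * (B // 2))   # masked timesteps

    # XLA-friendly: static shapes, non-masked positions marked with -1
    masked_target = torch.where(mask_indices, target, torch.full_like(target, -1))
    loss = F.cross_entropy(logits, masked_target, reduction="sum", ignore_index=-1)

    # dynamic-shape equivalent (what the eager/GPU path effectively computes)
    loss_ref = F.cross_entropy(logits[mask_indices], target[mask_indices], reduction="sum")
    assert torch.allclose(loss, loss_ref)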
diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py
index 675b095647..4d9b60c372 100644
--- a/fairseq/data/audio/raw_audio_dataset.py
+++ b/fairseq/data/audio/raw_audio_dataset.py
@@ -12,7 +12,8 @@
 import torch
 import torch.nn.functional as F

-from .. import FairseqDataset
+from .. import FairseqDataset, BaseWrapperDataset
+from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes

 logger = logging.getLogger(__name__)
@@ -27,6 +28,8 @@ def __init__(
         min_length=0,
         pad=False,
         normalize=False,
+        compute_mask_indices=False,
+        args=None,
     ):
         super().__init__()
@@ -40,6 +43,12 @@ def __init__(
         self.pad = pad
         self.shuffle = shuffle
         self.normalize = normalize
+        self.compute_mask_indices = compute_mask_indices
+        if self.compute_mask_indices:
+            self.args = args
+            self._features_size_map = {}
+            self._C = self.args.encoder_embed_dim
+            self._conv_feature_layers = eval(self.args.conv_feature_layers)

     def __getitem__(self, index):
         raise NotImplementedError()
@@ -71,6 +80,45 @@ def crop_to_max_size(self, wav, target_size):
         end = size - diff + start
         return wav[start:end]

+    def _compute_mask_indices(self, dims, padding_mask):
+        B, T, C = dims
+        mask_indices, mask_channel_indices = None, None
+        if self.args.mask_prob > 0:
+            mask_indices = compute_mask_indices(
+                (B, T),
+                padding_mask,
+                self.args.mask_prob,
+                self.args.mask_length,
+                self.args.mask_selection,
+                self.args.mask_other,
+                min_masks=2,
+                no_overlap=self.args.no_mask_overlap,
+                min_space=self.args.mask_min_space,
+            )
+            mask_indices = torch.from_numpy(mask_indices)
+        if self.args.mask_channel_prob > 0:
+            mask_channel_indices = compute_mask_indices(
+                (B, C),
+                None,
+                self.args.mask_channel_prob,
+                self.args.mask_channel_length,
+                self.args.mask_channel_selection,
+                self.args.mask_channel_other,
+                no_overlap=self.args.no_mask_channel_overlap,
+                min_space=self.args.mask_channel_min_space,
+            )
+            mask_channel_indices = (
+                torch.from_numpy(mask_channel_indices)
+                .unsqueeze(1)
+                .expand(-1, T, -1)
+            )
+
+        return mask_indices, mask_channel_indices
+
+    @staticmethod
+    def _bucket_tensor(tensor, num_pad, value):
+        return F.pad(tensor, (0, num_pad), value=value)
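For illustration, this is what `_bucket_tensor` does to one batch in the collater below: after ordinary collation to the longest sample, both the waveforms and the padding mask are padded up to the batch's bucket, so each bucket yields exactly one tensor shape (and one XLA graph). Numbers are illustrative:

    import torch
    import torch.nn.functional as F

    sources = torch.randn(2, 45000)                     # collated batch
    padding_mask = torch.zeros(2, 45000, dtype=torch.bool)
    bucket = 48000                  # smallest bucket >= every sample size
    num_pad = bucket - sources.size(-1)
    if num_pad:
        sources = F.pad(sources, (0, num_pad), value=0)
        padding_mask = F.pad(padding_mask, (0, num_pad), value=True)
    assert sources.shape == padding_mask.shape == (2, 48000)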

     def collater(self, samples):
         samples = [
             s
@@ -106,9 +154,131 @@ def collater(self, samples):
                 collated_sources[i] = self.crop_to_max_size(source, target_size)

         input = {"source": collated_sources}
+        out = {"id": torch.LongTensor([s["id"] for s in samples])}
         if self.pad:
             input["padding_mask"] = padding_mask
-        return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input}
+
+        if hasattr(self, 'num_buckets') and self.num_buckets > 0:
+            assert self.pad, "Cannot bucket without padding first."
+            bucket = max(self._bucketed_sizes[s['id']] for s in samples)
+            num_pad = bucket - collated_sources.size(-1)
+            if num_pad:
+                input['source'] = self._bucket_tensor(
+                    collated_sources, num_pad, 0
+                )
+                input['padding_mask'] = self._bucket_tensor(
+                    padding_mask, num_pad, True
+                )
+
+        if self.compute_mask_indices:
+            B = input['source'].size(0)
+            T = self._get_mask_indices_dims(input['source'].size(-1))
+            padding_mask_reshaped = input['padding_mask'].clone()
+            extra = padding_mask_reshaped.size(1) % T
+            if extra > 0:
+                padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
+            padding_mask_reshaped = padding_mask_reshaped.view(
+                padding_mask_reshaped.size(0), T, -1
+            )
+            padding_mask_reshaped = padding_mask_reshaped.all(-1)
+            mask_indices, mask_channel_indices = self._compute_mask_indices(
+                (B, T, self._C), padding_mask_reshaped,
+            )
+            input["mask_indices"] = mask_indices
+            # FIXME: implement negative sampling here as well
+            # input["neg_idxs"] = self._get_neg_idxs(mask_indices)
+            input["tszs_after_mask"] = mask_indices.sum(-1).tolist()
+            input["mask_channel_indices"] = mask_channel_indices
+            out['sample_size'] = mask_indices.sum().item()
+
+        out["net_input"] = input
+        return out
+
+    def _get_neg_idxs(self, mask_indices):
+        return
+
+    def _sample_negatives(self, mask_indices):
+        """
+        Sampling negatives during the model's forward is problematic on XLA.
+        That's why, when running on XLA, we do it here during data prep.
+        """
+        # FIXME: work in progress; `y`, `num` and the cross-sample branch
+        # below are not wired up yet.
+        self.n_negatives = self.args.num_negatives
+        self.cross_sample_negatives = self.args.cross_sample_negatives
+        if self.n_negatives == 0 and self.cross_sample_negatives == 0:
+            # FIXME: handle this
+            return y.new(0)
+
+        (bsz, tsz), fsz = mask_indices.shape, self.args.final_dim
+        high = mask_indices.sum(-1).max().item()
+        cross_high = high * bsz
+
+        with torch.no_grad():
+            assert high > 1, f"{bsz,tsz,fsz}"
+
+            if self.n_negatives > 0:
+                tszs = (
+                    torch.arange(tsz)
+                    .unsqueeze(-1)
+                    .expand(-1, self.n_negatives)
+                    .flatten()
+                )
+
+                ts = torch.arange(tsz)
+
+                neg_idxs = torch.randint(
+                    low=0, high=high - 1, size=(bsz, tsz, self.n_negatives)
+                )
+                neg_idxs = torch.stack([
+                    ts[mask_indices[j]][ni] for j, ni in enumerate(neg_idxs)
+                ]).reshape(bsz, -1)
+                neg_idxs[neg_idxs >= tszs] += 1
+
+            if self.cross_sample_negatives > 0:
+                raise NotImplementedError('Implement for XLA.')
+                tszs = (
+                    torch.arange(num)
+                    .unsqueeze(-1)
+                    .expand(-1, self.cross_sample_negatives)
+                    .flatten()
+                )
+
+                cross_neg_idxs = torch.randint(
+                    low=0,
+                    high=cross_high - 1,
+                    size=(bsz, self.cross_sample_negatives * num),
+                )
+                cross_neg_idxs[cross_neg_idxs >= tszs] += 1
+
+        if self.n_negatives > 0:
+            for i in range(1, bsz):
+                neg_idxs[i] += i * high
+        else:
+            neg_idxs = cross_neg_idxs
+
+        if self.cross_sample_negatives > 0 and self.n_negatives > 0:
+            neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)
+
+        negs = y[neg_idxs.view(-1)]
+        negs = negs.view(
+            bsz, tsz, self.n_negatives + self.cross_sample_negatives, fsz
+        ).permute(
+            2, 0, 1, 3
+        )  # to NxBxTxC
+        return negs, neg_idxs
+
+    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
+        if size not in self._features_size_map:
+            L_in = size
+            for (_, kernel_size, stride) in self._conv_feature_layers:
+                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
+                L_out = 1 + L_out // stride
+                L_in = L_out
+            self._features_size_map[size] = L_out
+        return self._features_size_map[size]
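`_get_mask_indices_dims` is just the standard Conv1d output-length formula applied per extractor layer. A worked example, assuming the default wav2vec 2.0 extractor config `"[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2"` (an assumption here, taken from the model defaults rather than this diff):

    def conv_out_len(L_in, layers, padding=0, dilation=1):
        # (dim, kernel_size, stride) per layer; dim is unused for the length
        for _, kernel_size, stride in layers:
            L_in = 1 + (L_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride
        return L_in

    layers = [(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2
    print(conv_out_len(16000, layers))  # 49 frames for 1 s of 16 kHz audio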

     def num_tokens(self, index):
         return self.size(index)
@@ -144,6 +314,9 @@ def __init__(
         min_length=0,
         pad=False,
         normalize=False,
+        compute_mask_indices=False,
+        args=None,
+        num_buckets=0,
     ):
         super().__init__(
             sample_rate=sample_rate,
@@ -153,6 +326,8 @@ def __init__(
             min_length=min_length,
             pad=pad,
             normalize=normalize,
+            compute_mask_indices=compute_mask_indices,
+            args=args,
         )

         self.fnames = []
@@ -169,8 +344,26 @@ def __init__(
                 continue
             self.fnames.append(items[0])
             self.sizes.append(sz)
+        self.set_bucket_info(num_buckets)
         logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

+    def set_bucket_info(self, num_buckets):
+        self.num_buckets = num_buckets
+        if self.num_buckets > 0:
+            self._collated_sizes = np.minimum(
+                np.array(self.sizes), self.max_sample_size,
+            )
+            self.buckets = get_buckets(
+                self._collated_sizes, self.num_buckets,
+            )
+            self._bucketed_sizes = get_bucketed_sizes(
+                self._collated_sizes, self.buckets
+            )
+            logger.info(
+                f"{len(self.buckets)} bucket(s) for the audio dataset: "
+                f"{self.buckets}"
+            )
+
     def __getitem__(self, index):
         import soundfile as sf
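How `get_buckets`/`get_bucketed_sizes` (added to data_utils.py below) quantize lengths, with a small numeric example:

    import numpy as np

    sizes = np.array([100, 150, 300, 320, 800, 1000])
    buckets = np.unique(
        np.percentile(sizes, np.linspace(0, 100, 3 + 1), interpolation='lower')[1:]
    )
    # buckets == [150, 320, 1000]; every size is then rounded *up* to its
    # bucket (100, 150 -> 150; 300, 320 -> 320; 800, 1000 -> 1000), so at most
    # len(buckets) distinct padded widths, and hence XLA graphs, can occur.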
diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py
index 6f53d01188..e4ce3c40ef 100644
--- a/fairseq/data/bucket_pad_length_dataset.py
+++ b/fairseq/data/bucket_pad_length_dataset.py
@@ -7,6 +7,7 @@
 import torch.nn.functional as F

 from fairseq.data import BaseWrapperDataset
+from fairseq.data.data_utils import get_buckets, get_bucketed_sizes


 class BucketPadLengthDataset(BaseWrapperDataset):
@@ -30,42 +31,43 @@ def __init__(
         num_buckets,
         pad_idx,
         left_pad,
+        tensor_key=None,
     ):
         super().__init__(dataset)
         self.pad_idx = pad_idx
         self.left_pad = left_pad

         assert num_buckets > 0
-        self.buckets = np.unique(
-            np.percentile(
-                sizes,
-                np.linspace(0, 100, num_buckets + 1),
-                interpolation='lower',
-            )[1:]
-        )
+        self.buckets = get_buckets(sizes, num_buckets)
+        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
+        self._tensor_key = tensor_key

-        def get_bucketed_sizes(orig_sizes, buckets):
-            sizes = np.copy(orig_sizes)
-            assert np.min(sizes) >= 0
-            start_val = -1
-            for end_val in buckets:
-                mask = (sizes > start_val) & (sizes <= end_val)
-                sizes[mask] = end_val
-                start_val = end_val
-            return sizes
+    def _set_tensor(self, item, val):
+        if self._tensor_key is None:
+            return val
+        item[self._tensor_key] = val
+        return item

-        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
+    def _get_tensor(self, item):
+        if self._tensor_key is None:
+            return item
+        return item[self._tensor_key]

-    def __getitem__(self, index):
-        item = self.dataset[index]
-        bucket_size = self._bucketed_sizes[index]
-        num_pad = bucket_size - item.size(-1)
+    def _pad(self, tensor, bucket_size, dim=-1):
+        num_pad = bucket_size - tensor.size(dim)
         return F.pad(
-            item,
+            tensor,
             (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
             value=self.pad_idx,
         )

+    def __getitem__(self, index):
+        item = self.dataset[index]
+        bucket_size = self._bucketed_sizes[index]
+        tensor = self._get_tensor(item)
+        padded = self._pad(tensor, bucket_size)
+        return self._set_tensor(item, padded)
+
     @property
     def sizes(self):
         return self._bucketed_sizes
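With the new `tensor_key` argument, the wrapper can pad one entry of a dict item instead of the item itself. A usage sketch, assuming a hypothetical dataset whose items look like `{"source": FloatTensor}` (`raw_dataset` and `raw_sizes` are placeholders, not names from this diff):

    ds = BucketPadLengthDataset(
        raw_dataset,          # items are {"source": tensor} dicts
        sizes=raw_sizes,      # numpy array of per-item lengths
        num_buckets=8,
        pad_idx=0,
        left_pad=False,
        tensor_key="source",  # pad item["source"], return the whole dict
    )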
diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py
index 57991a8802..e6818219cb 100644
--- a/fairseq/data/data_utils.py
+++ b/fairseq/data/data_utils.py
@@ -272,6 +272,7 @@ def post_process(sentence: str, symbol: str):
         sentence = (sentence + " ").replace(symbol, "").rstrip()
     return sentence

+
 def compute_mask_indices(
     shape: Tuple[int, int],
     padding_mask: Optional[torch.Tensor],
@@ -393,3 +395,25 @@ def arrange(s, e, length, keep_length):
             mask[i, mask_idc] = True

     return mask
+
+
+def get_buckets(sizes, num_buckets):
+    buckets = np.unique(
+        np.percentile(
+            sizes,
+            np.linspace(0, 100, num_buckets + 1),
+            interpolation='lower',
+        )[1:]
+    )
+    return buckets
+
+
+def get_bucketed_sizes(orig_sizes, buckets):
+    sizes = np.copy(orig_sizes)
+    assert np.min(sizes) >= 0
+    start_val = -1
+    for end_val in buckets:
+        mask = (sizes > start_val) & (sizes <= end_val)
+        sizes[mask] = end_val
+        start_val = end_val
+    return sizes
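For reference, a sketch of calling `compute_mask_indices` directly, mirroring the dataset-side call above; it returns a `(B, T)` boolean numpy array of masked positions (argument values are illustrative):

    mask = compute_mask_indices(
        (4, 49),        # (batch, frames)
        None,           # no padding mask
        0.65,           # mask_prob
        10,             # mask_length
        "static",       # mask selection strategy
        0.0,            # mask_other
        min_masks=2,
    )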
diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py
index 7ee89adce9..0060d8c189 100644
--- a/fairseq/distributed_utils.py
+++ b/fairseq/distributed_utils.py
@@ -113,7 +113,6 @@ def distributed_init(args):
         args.device_id = xm.get_local_ordinal()
         args.distributed_rank = xm.get_ordinal()
         xm.rendezvous('distributed_init')  # wait for all workers
-        xm.mark_step()

     if is_master(args):
         logging.getLogger().setLevel(logging.INFO)
@@ -182,7 +181,10 @@ def call_main(args, main, **kwargs):
         xmp.spawn(
             fn=distributed_main,
             args=(main, args, kwargs),
-            nprocs=8,  # use all 8 TPU cores
+            # tpu-comment: a single TPU VM exposes at most 8 devices, so we
+            # spawn at most 8 processes; anything beyond that is driven by
+            # xm.distributed.xla_dist.
+            nprocs=min(args.distributed_world_size, 8),
         )
     else:
         # single GPU main
diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py
index 6ca1d201e0..0bac4c25b2 100644
--- a/fairseq/logging/metrics.py
+++ b/fairseq/logging/metrics.py
@@ -289,3 +289,11 @@ def load_state_dict(state_dict):
     for name, agg_state in state_dict.items():
         _aggregators[name] = MetersDict()
         _aggregators[name].load_state_dict(agg_state)
+
+
+def xla_metrics_report():
+    try:
+        import torch_xla.debug.metrics as met
+        print(met.metrics_report())
+    except ImportError:
+        return
diff --git a/fairseq/metsumm.py b/fairseq/metsumm.py
new file mode 100644
index 0000000000..ce0cfa0594
--- /dev/null
+++ b/fairseq/metsumm.py
@@ -0,0 +1,18 @@
+# FIXME: remove this file before merging; debugging aid only
+def metsumm(stepno=''):
+    if hasattr(metsumm, 'STEPNO'):
+        metsumm.STEPNO += stepno.lower() == "before forward"
+    else:
+        metsumm.STEPNO = 0
+    try:
+        import torch_xla.debug.metrics as met
+        x = met.metrics_report().split('\n')
+        for i, line in enumerate(x):
+            if 'CompileTime' in line or 'aten::' in line:
+                key = line.split()[-1]
+                value = x[i + 1].split()[-1]
+                print('step {}-{}, key {}, value {}'.format(
+                    metsumm.STEPNO, stepno, key, value)
+                )
+    except (ImportError, RuntimeError):
+        return
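Instead of scraping the textual report as `metsumm` does, recompilations can be watched programmatically; a sketch mirroring the trainer's `_check_xla_compilation` (check `met.metric_data` against your torch_xla version):

    import torch_xla.debug.metrics as met

    def num_xla_compiles():
        compile_stats = met.metric_data("CompileTime")
        return compile_stats[0] if compile_stats is not None else 0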
diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py
index 226f035ba8..5196797df2 100644
--- a/fairseq/models/wav2vec/wav2vec2.py
+++ b/fairseq/models/wav2vec/wav2vec2.py
@@ -27,7 +27,7 @@
     TransposeLast,
 )
 from fairseq.modules.transformer_sentence_encoder import init_bert_params
-from fairseq.utils import buffered_arange
+from fairseq.utils import buffered_arange, index_put, is_xla_tensor


 @register_model("wav2vec2")
@@ -405,47 +405,66 @@ def build_model(cls, args, task=None):

         return cls(args)

-    def apply_mask(self, x, padding_mask):
+    def apply_mask(
+        self, x, padding_mask,
+        mask_indices=None, mask_channel_indices=None,
+    ):
         B, T, C = x.shape
         if self.mask_prob > 0:
-            mask_indices = compute_mask_indices(
-                (B, T),
-                padding_mask,
-                self.mask_prob,
-                self.mask_length,
-                self.mask_selection,
-                self.mask_other,
-                min_masks=2,
-                no_overlap=self.no_mask_overlap,
-                min_space=self.mask_min_space,
-            )
-            mask_indices = torch.from_numpy(mask_indices).to(x.device)
-            x[mask_indices] = self.mask_emb
+            if mask_indices is None:
+                mask_indices = compute_mask_indices(
+                    (B, T),
+                    padding_mask,
+                    self.mask_prob,
+                    self.mask_length,
+                    self.mask_selection,
+                    self.mask_other,
+                    min_masks=2,
+                    no_overlap=self.no_mask_overlap,
+                    min_space=self.mask_min_space,
+                )
+                mask_indices = torch.from_numpy(mask_indices).to(x.device)
+            x = index_put(x, mask_indices, self.mask_emb)
         else:
             mask_indices = None

         if self.mask_channel_prob > 0:
-            mask_channel_indices = compute_mask_indices(
-                (B, C),
-                None,
-                self.mask_channel_prob,
-                self.mask_channel_length,
-                self.mask_channel_selection,
-                self.mask_channel_other,
-                no_overlap=self.no_mask_channel_overlap,
-                min_space=self.mask_channel_min_space,
-            )
-            mask_channel_indices = (
-                torch.from_numpy(mask_channel_indices)
-                .to(x.device)
-                .unsqueeze(1)
-                .expand(-1, T, -1)
-            )
-            x[mask_channel_indices] = 0
+            if mask_channel_indices is None:
+                mask_channel_indices = compute_mask_indices(
+                    (B, C),
+                    None,
+                    self.mask_channel_prob,
+                    self.mask_channel_length,
+                    self.mask_channel_selection,
+                    self.mask_channel_other,
+                    no_overlap=self.no_mask_channel_overlap,
+                    min_space=self.mask_channel_min_space,
+                )
+                mask_channel_indices = (
+                    torch.from_numpy(mask_channel_indices)
+                    .to(x.device)
+                    .unsqueeze(1)
+                    .expand(-1, T, -1)
+                )
+            x = index_put(x, mask_channel_indices, 0)

         return x, mask_indices

-    def sample_negatives(self, y, num):
+    def _get_neg_idxs(self, high, size, padding_counts=None, num=None):
+        if padding_counts is None:
+            neg_idxs = torch.randint(low=0, high=high - 1, size=size)
+        else:
+            bsz, l = size
+            # num = l // self.n_negatives if num is None else num
+            assert len(padding_counts) == bsz
+            # FIXME: per-sample padding counts are not used yet
+            neg_idxs = [
+                torch.randint(low=0, high=high - 1, size=(1, l))
+                for pc in padding_counts
+            ]
+            neg_idxs = torch.stack(neg_idxs)
+        return neg_idxs
+
+    def sample_negatives(self, y, num, padding_counts=None):
         if self.n_negatives == 0 and self.cross_sample_negatives == 0:
             return y.new(0)

@@ -454,21 +473,23 @@ def sample_negatives(self, y, num):
         y = y.view(-1, fsz)  # BTC => (BxT)C

         cross_high = tsz * bsz
-        high = tsz
+        high = num
+        # FIXME: there's a problem here w/ tsz and num;
+        # this assumes y is shrunk at this point.

         with torch.no_grad():
             assert high > 1, f"{bsz,tsz,fsz}"

             if self.n_negatives > 0:
                 tszs = (
-                    buffered_arange(num)
+                    buffered_arange(tsz)
                     .unsqueeze(-1)
                     .expand(-1, self.n_negatives)
                     .flatten()
                 )

-                neg_idxs = torch.randint(
-                    low=0, high=high - 1, size=(bsz, self.n_negatives * num)
-                )
+                neg_idxs = torch.randint(
+                    low=0, high=high - 1, size=(bsz, self.n_negatives * tsz)
+                )
                 neg_idxs[neg_idxs >= tszs] += 1

             if self.cross_sample_negatives > 0:
@@ -495,9 +516,13 @@ def sample_negatives(self, y, num):

             if self.cross_sample_negatives > 0 and self.n_negatives > 0:
                 neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1)

         negs = y[neg_idxs.view(-1)]
         negs = negs.view(
-            bsz, num, self.n_negatives + self.cross_sample_negatives, fsz
+            bsz, tsz, self.n_negatives + self.cross_sample_negatives, fsz
         ).permute(
             2, 0, 1, 3
         )  # to NxBxTxC
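The `neg_idxs[neg_idxs >= tszs] += 1` line is the standard trick for sampling negatives that are never the positive itself: draw from `[0, high - 2]`, then shift every draw at or above the current timestep up by one. A tiny, self-contained illustration (names mirror the code above; sizes are illustrative):

    import torch

    tsz, n_neg = 4, 2
    tszs = torch.arange(tsz).unsqueeze(-1).expand(-1, n_neg).flatten()  # 0,0,1,1,2,2,3,3
    neg_idxs = torch.randint(low=0, high=tsz - 1, size=(1, n_neg * tsz))
    neg_idxs[neg_idxs >= tszs] += 1
    assert not (neg_idxs == tszs).any()  # the positive is never sampled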
@@ -511,14 +536,21 @@ def compute_preds(self, x, y, negatives):

         logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x)

-        logits /= self.logit_temp
+        logits = logits / self.logit_temp

-        if neg_is_pos.any():
-            logits[1:][neg_is_pos] = float("-inf")
+        if is_xla_tensor(logits) or neg_is_pos.any():
+            fillval = -float(2 ** 30)
+            if not hasattr(self, '_inftensor'):
+                self._inftensor = torch.tensor(fillval).to(x.device)
+            logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)

         return logits

-    def forward(self, source, padding_mask=None, mask=True, features_only=False):
+    def forward(
+        self, source, padding_mask=None, mask=True, features_only=False,
+        mask_indices=None, mask_channel_indices=None, padding_counts=None,
+        tszs_after_mask=None,
+    ):

         if self.feature_grad_mult > 0:
             features = self.feature_extractor(source)
@@ -562,9 +594,17 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
         features = self.project_inp(features)

         if mask:
-            x, mask_indices = self.apply_mask(features, padding_mask)
-            if mask_indices is not None:
-                y = unmasked_features[mask_indices].view(unmasked_features.size(0), -1, unmasked_features.size(-1))
+            x, mask_indices = self.apply_mask(
+                features, padding_mask,
+                mask_indices=mask_indices,
+                mask_channel_indices=mask_channel_indices,
+            )
+            if not is_xla_tensor(x) and mask_indices is not None:
+                # tpu-comment: reducing the size in a dynamic way causes
+                # too many recompilations on xla.
+                y = unmasked_features[mask_indices].view(
+                    unmasked_features.size(0), -1, unmasked_features.size(-1)
+                )
             else:
                 y = unmasked_features
         else:
@@ -587,13 +627,21 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
             y = self.project_q(y)

+            num = y.size(1) if tszs_after_mask is None else max(tszs_after_mask)
             if self.negatives_from_everywhere:
-                neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False)
-                negs, _ = self.sample_negatives(neg_cands, y.size(1))
+                neg_cands, *_ = self.quantizer(
+                    unmasked_features, produce_targets=False,
+                )
+                negs, _ = self.sample_negatives(
+                    neg_cands, num, padding_counts=padding_counts,
+                )
                 negs = self.project_q(negs)
             else:
-                negs, _ = self.sample_negatives(y, y.size(1))
+                negs, _ = self.sample_negatives(
+                    y, num,
+                    padding_counts=padding_counts,
+                )

             if self.codebook_negatives > 0:
                 cb_negs = self.quantizer.sample_from_codebook(
@@ -608,17 +656,24 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False):
             y = self.project_q(y)

+            num = y.size(1) if tszs_after_mask is None else max(tszs_after_mask)
             if self.negatives_from_everywhere:
-                negs, _ = self.sample_negatives(unmasked_features, y.size(1))
+                negs, _ = self.sample_negatives(
+                    unmasked_features, num, padding_counts=padding_counts,
+                )
                 negs = self.project_q(negs)
             else:
-                negs, _ = self.sample_negatives(y, y.size(1))
+                negs, _ = self.sample_negatives(
+                    y, num,
+                    padding_counts=padding_counts,
+                )

-        x = x[mask_indices].view(x.size(0), -1, x.size(-1))
+        if not is_xla_tensor(x):
+            # tpu-comment: reducing the size in a dynamic way causes
+            # too many recompilations on xla.
+            x = x[mask_indices].view(x.size(0), -1, x.size(-1))

         if self.target_glu:
             y = self.target_glu(y)
             negs = self.target_glu(negs)
-
         x = self.final_proj(x)
         x = self.compute_preds(x, y, negs)
@@ -811,7 +866,7 @@ def extract_features(self, x, padding_mask=None):

         x_conv = self.pos_conv(x.transpose(1, 2))
         x_conv = x_conv.transpose(1, 2)
-        x += x_conv
+        x = x + x_conv

         if not self.layer_norm_first:
             x = self.layer_norm(x)
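Why `compute_preds` fills with `-2**30` rather than `float("-inf")`: after temperature scaling, the cosine-similarity logits are bounded (|logit| <= 1 / logit_temp), so any large finite negative already behaves as minus infinity under softmax, while avoiding the inf/NaN propagation that assigning `-inf` through a boolean mask can trigger on XLA. A quick demonstration:

    import torch

    logits = torch.tensor([0.8, 0.1, -float(2 ** 30)])
    print(torch.softmax(logits, dim=0))  # third entry underflows to exactly 0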
diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py
index 01ddd2298b..35a08af166 100644
--- a/fairseq/modules/gumbel_vector_quantizer.py
+++ b/fairseq/modules/gumbel_vector_quantizer.py
@@ -160,6 +160,7 @@ def forward(self, x, produce_targets=False):
         avg_probs = torch.softmax(
             x.view(bsz * tsz, self.groups, -1).float(), dim=-1
         ).mean(dim=0)
+        avg_probs = avg_probs.detach()
         result["prob_perplexity"] = torch.exp(
             -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1)
         ).sum()
diff --git a/fairseq/options.py b/fairseq/options.py
index 171c67966d..ce6edc345c 100644
--- a/fairseq/options.py
+++ b/fairseq/options.py
@@ -238,6 +238,7 @@ def get_parser(desc, default_task="translation"):
                         help='pseudo random number generator seed')
     parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA')
     parser.add_argument('--tpu', action='store_true', help='use TPU instead of CUDA')
+    parser.add_argument('--xla-metrics-debug', action='store_true', help='print XLA debug info')
     parser.add_argument('--bf16', action='store_true', help='use bfloat16; implies --tpu')
     parser.add_argument('--fp16', action='store_true', help='use FP16')
     parser.add_argument('--memory-efficient-bf16', action='store_true',
diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py
index f33637468f..2683e70578 100644
--- a/fairseq/tasks/audio_pretraining.py
+++ b/fairseq/tasks/audio_pretraining.py
@@ -8,7 +8,9 @@
 import os
 import sys

-from fairseq.data import FileAudioDataset, Dictionary, AddTargetDataset
+from fairseq.data import (
+    FileAudioDataset, Dictionary, AddTargetDataset, BucketPadLengthDataset
+)

 from . import FairseqTask, register_task

@@ -31,6 +33,14 @@ class AudioPretrainingTask(FairseqTask):
     @staticmethod
     def add_args(parser):
         """Add task-specific arguments to the parser."""
+        parser.add_argument(
+            '--num-batch-buckets', default=0, type=int,
+            help=(
+                'if >0, then bucket source and target lengths into N '
+                'buckets and pad accordingly; this is useful on TPUs '
+                'to minimize the number of compilations'
+            ),
+        )
         parser.add_argument("data", help="path to data directory")
         parser.add_argument(
             "--sample-rate",
@@ -103,6 +113,9 @@ def load_dataset(self, split, **kwargs):
             min_length=self.args.min_sample_size,
             pad=self.args.labels is not None or self.args.enable_padding,
             normalize=self.args.normalize,
+            compute_mask_indices=self.args.tpu,
+            args=self.args,
+            num_buckets=self.args.num_batch_buckets or int(self.args.tpu),
        )

         if self.args.labels:
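Effect of the `load_dataset` defaults above: on GPUs nothing changes unless `--num-batch-buckets` is given, while on TPUs at least one bucket is always used, so every batch is padded to a common width:

    num_batch_buckets = 0                          # the default
    tpu = True
    num_buckets = num_batch_buckets or int(tpu)    # 1 bucket on TPU, 0 on GPU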
diff --git a/fairseq/trainer.py b/fairseq/trainer.py
index a91d12fdc2..37d578e332 100644
--- a/fairseq/trainer.py
+++ b/fairseq/trainer.py
@@ -347,7 +347,7 @@ def get_train_iterator(
             num_shards=self.data_parallel_world_size if shard_batch_itr else 1,
             shard_id=self.data_parallel_rank if shard_batch_itr else 0,
             num_workers=self.args.num_workers,
-            epoch=epoch
+            epoch=epoch,
         )

     def get_valid_iterator(
@@ -368,7 +368,7 @@ def get_valid_iterator(
             seed=self.args.seed,
             num_shards=self.data_parallel_world_size,
             shard_id=self.data_parallel_rank,
-            num_workers=self.args.num_workers
+            num_workers=self.args.num_workers,
         )

     def begin_epoch(self, epoch):
@@ -422,6 +422,7 @@ def maybe_no_sync():
             try:
                 with maybe_no_sync():
                     # forward and backward
+
                     loss, sample_size_i, logging_output = self.task.train_step(
                         sample=sample,
                         model=self.model,
@@ -462,8 +463,7 @@ def maybe_no_sync():
                     # before marking step can lead to OOM errors.
                     # To handle gradient accumulation use case, we explicitly
                     # mark step here for every forward pass without a backward pass
-                    import torch_xla.core.xla_model as xm
-                    xm.mark_step()
+                    self._xla_markstep_and_send_to_cpu()

             if is_dummy_batch:
                 if torch.is_tensor(sample_size):
@@ -548,16 +548,23 @@ def maybe_no_sync():

         if self.tpu:
-            # mark step on TPUs
-            import torch_xla.core.xla_model as xm
-            xm.mark_step()
-
             # only log stats every log_interval steps
             # this causes wps to be misreported when log_interval > 1
             logging_output = {}
             if self.get_num_updates() % self.args.log_interval == 0:
+                logging_outputs = self._xla_markstep_and_send_to_cpu(
+                    logging_outputs
+                )
                 logging_output = self._reduce_and_log_stats(
                     logging_outputs, sample_size, grad_norm,
                 )
+            self._xla_markstep_and_send_to_cpu()
+            # FIXME(taylan): when this is moved into a step closure,
+            # logging_outputs gets shrunk:
+            # xm.add_step_closure(
+            #     self._reduce_and_log_stats,
+            #     args=(logging_outputs, sample_size, grad_norm)
+            # )

             # log whenever there's an XLA compilation, since these
             # slow down training and may indicate opportunities for
@@ -593,9 +600,7 @@ def valid_step(self, sample, raise_oom=False):
             if self._dummy_batch == "DUMMY":
                 self._dummy_batch = sample
             if self.tpu:
-                import torch_xla.core.xla_model as xm
-                xm.rendezvous('valid_step')  # wait for all workers
-                xm.mark_step()
+                self._xla_markstep_and_send_to_cpu()

         with torch.no_grad():
             self.model.eval()
@@ -641,6 +646,8 @@ def valid_step(self, sample, raise_oom=False):
             )

         # log validation stats
+        if self.tpu:
+            logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs)
         logging_output = self._reduce_and_log_stats(logging_outputs, sample_size)

         return logging_output
@@ -778,8 +785,9 @@ def _set_seed(self):
         utils.set_torch_seed(seed)

     def _sync_stats(self):
-        # Return True if it's using multiple GPUs and DDP or multiple GPUs with
-        # BMUF and it's a bmuf sync with warmup iterations completed before.
+        # Return True if it's using multiple devices and DDP
+        # or multiple devices with BMUF and it's a bmuf sync
+        # with warmup iterations completed before.
         if self.data_parallel_world_size == 1:
             return False
         elif self.args.use_bmuf:
@@ -912,7 +920,11 @@ def is_consistent(tensor):
         )

     def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None):
-        if grad_norm is not None:
+        # tpu-comment: grad_norm is a tensor on XLA
+        if (
+            (not torch.is_tensor(grad_norm) and grad_norm is not None)
+            or (torch.is_tensor(grad_norm) and not torch.isnan(grad_norm))
+        ):
             metrics.log_speed("ups", 1., priority=100, round=2)
             metrics.log_scalar("gnorm", grad_norm, priority=400, round=3)
             if self.args.clip_norm > 0:
@@ -968,6 +980,13 @@ def _check_xla_compilation(self, message=None):
             logging.info("NOTE: XLA compilation detected; {}".format(message))
         self._num_xla_compiles = num_xla_compiles

+    def _xla_markstep_and_send_to_cpu(self, data=None):
+        import torch_xla.core.xla_model as xm
+        xm.mark_step()
+        if data is not None:
+            from fairseq.utils import xla_device_to_cpu
+            return xla_device_to_cpu(data)
+

 def _catalog_shared_params(module, memo=None, prefix=''):
     if memo is None:
diff --git a/fairseq/utils.py b/fairseq/utils.py
index f68860330c..1af2e2537e 100644
--- a/fairseq/utils.py
+++ b/fairseq/utils.py
@@ -81,6 +81,7 @@ def _move_to_cuda(tensor):

 def move_to_cpu(sample):
+
     def _move_to_cpu(tensor):
         # PyTorch has poor support for half tensors (float16) on CPU.
         # Move any such tensors to float32.
@@ -252,6 +253,9 @@ def convert_padding_direction(

 def item(tensor):
+    # tpu-comment: making this a no-op for xla devices.
+    if torch.is_tensor(tensor) and tensor.device.type == 'xla':
+        return tensor.detach()
     if hasattr(tensor, "item"):
         return tensor.item()
     if hasattr(tensor, "__getitem__"):
@@ -560,6 +564,27 @@ def get_tpu_device(args):
     return xm.xla_device()


+def is_xla_tensor(tensor):
+    return torch.is_tensor(tensor) and tensor.device.type == 'xla'
+
+
+def index_put(tensor, indices, value):
+    if is_xla_tensor(tensor):
+        for _ in range(indices.dim(), tensor.dim()):
+            indices = indices.unsqueeze(-1)
+        if indices.size(-1) < tensor.size(-1):
+            indices = indices.expand_as(tensor)
+        tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices)
+    else:
+        tensor[indices] = value
+    return tensor
+
+
+def xla_device_to_cpu(dat):
+    import torch_xla.core.xla_model as xm
+    return xm._maybe_convert_to_cpu(dat)
+
+
 class CudaEnvironment(object):
     def __init__(self):
         cur_device = torch.cuda.current_device()
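The XLA branch of `index_put` computes the same result as boolean assignment, but with fixed-shape arithmetic (`tensor * ~mask + value * mask`), which avoids the dynamic scatter that XLA handles poorly. A small equivalence check:

    import torch

    t = torch.arange(6.).reshape(2, 3)
    mask = torch.tensor([[True, False, True], [False, False, True]])
    a = t.clone()
    a[mask] = 0.5                   # eager boolean-assignment path
    b = t * ~mask + 0.5 * mask      # xla-friendly arithmetic path
    assert torch.equal(a, b)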
diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py
index 806e4bc54b..832f474678 100644
--- a/fairseq_cli/train.py
+++ b/fairseq_cli/train.py
@@ -96,7 +96,7 @@ def main(args):
         "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size)
     )
     logger.info(
-        "max tokens per GPU = {} and max sentences per GPU = {}".format(
+        "max tokens per device = {} and max sentences per device = {}".format(
            args.max_tokens, args.max_sentences
        )
    )
@@ -106,9 +106,5 @@ def main(args):
     extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer)
-    if args.tpu:
-        import torch_xla.core.xla_model as xm
-        xm.rendezvous("load_checkpoint")  # wait for all workers
-        xm.mark_step()

     # Train until the learning rate gets too small
     max_epoch = args.max_epoch or math.inf
@@ -167,7 +165,6 @@ def tpu_data_loader(args, itr):
     import torch_xla.distributed.parallel_loader as pl

     xm.rendezvous("tpu_data_loader")  # wait for all workers
-    xm.mark_step()
     device = utils.get_tpu_device(args)
     return iterators.CountingIterator(
         pl.ParallelLoader(itr, [device]).per_device_loader(device),
@@ -211,6 +208,12 @@ def train(args, trainer, task, epoch_itr):
     should_stop = False
     num_updates = trainer.get_num_updates()
     for i, samples in enumerate(progress):
+        # FIXME: remove this debugging aid together with fairseq/metsumm.py
+        if args.tpu and i % 50 == 0:
+            import torch_xla.core.xla_model as xm
+            if xm.is_master_ordinal():
+                from fairseq.metsumm import metsumm
+                metsumm(str(i))
         with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
             "train_step-%d" % i
         ):
@@ -226,7 +229,10 @@ def train(args, trainer, task, epoch_itr):

             # reset mid-epoch stats after each log interval
             # the end-of-epoch stats will still be preserved
+            # FIXME(taylan): reset this inside a step closure
             metrics.reset_meters("train_inner")
+            if args.xla_metrics_debug:
+                metrics.xla_metrics_report()

     end_of_epoch = not itr.has_next()
     valid_losses, should_stop = validate_and_save(
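Taken together, these changes are exercised via the new flags. A plausible invocation (paths and the usual wav2vec 2.0 hyperparameters elided; only `--tpu`, `--num-batch-buckets`, and `--xla-metrics-debug` come from this patch):

    fairseq-train /path/to/manifest_dir \
        --task audio_pretraining --arch wav2vec2 \
        --tpu --num-batch-buckets 3 --xla-metrics-debug \
        ...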