From 7e6f3cfbb29d9447be8cd673392c2629b4bef66a Mon Sep 17 00:00:00 2001 From: Vaibhav Singh Date: Tue, 18 Aug 2020 22:09:07 +0000 Subject: [PATCH 01/19] Intermediate experiments wav2vec --- fairseq/criterions/wav2vec_criterion.py | 14 +++++++---- fairseq/data/bucket_pad_length_dataset.py | 9 +++++--- fairseq/distributed_utils.py | 6 +++-- fairseq/logging/meters.py | 3 +++ fairseq/metsumm.py | 12 ++++++++++ fairseq/models/wav2vec/wav2vec2.py | 27 ++++++++++++++++++---- fairseq/modules/gumbel_vector_quantizer.py | 1 + fairseq/modules/multihead_attention.py | 2 ++ fairseq/tasks/audio_pretraining.py | 20 +++++++++++++++- fairseq/trainer.py | 18 +++++++++++++-- fairseq/utils.py | 3 +++ 11 files changed, 98 insertions(+), 17 deletions(-) create mode 100644 fairseq/metsumm.py diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 019db62249..90f20a8488 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -40,7 +40,11 @@ def forward(self, model, sample, reduce=True, log_pred=False): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ + from fairseq.metsumm import metsumm + metsumm("Before forward") net_output = model(**sample['net_input']) + metsumm("After forward") + logits = model.get_logits(net_output).float() target = model.get_targets(sample, net_output) @@ -75,7 +79,8 @@ def forward(self, model, sample, reduce=True, log_pred=False): losses.append(p) logging_output = { - 'loss': loss.item() if reduce else loss, + #'loss': losr.item() if reduce else loss, + 'loss': loss, 'ntokens': sample_size, 'nsentences': sample['id'].numel(), 'sample_size': sample_size, @@ -87,7 +92,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): if len(losses) > 1: for i, l in enumerate(losses): - logging_output[f'loss_{i}'] = l.item() + logging_output[f'loss_{i}'] = l if self.infonce: with torch.no_grad(): @@ -99,7 +104,8 @@ def forward(self, model, sample, reduce=True, log_pred=False): max = logits.argmax(-1) == 0 min = logits.argmin(-1) == 0 both = max & min - corr = max.long().sum().item() - both.long().sum().item() + #corr = max.long().sum().item() - both.long().sum().item() + corr = max.long().sum() - both.long().sum() count = max.numel() logging_output["correct"] = corr @@ -132,7 +138,7 @@ def reduce_metrics(logging_outputs) -> None: if total > 0: metrics.log_derived( "accuracy", - lambda meters: round(meters["_correct"].sum / meters["_total"].sum, 5) + lambda meters: meters["_correct"].sum / meters["_total"].sum if meters["_total"].sum > 0 else float("nan"), ) diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py index 6f53d01188..256c202318 100644 --- a/fairseq/data/bucket_pad_length_dataset.py +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -58,13 +58,16 @@ def get_bucketed_sizes(orig_sizes, buckets): def __getitem__(self, index): item = self.dataset[index] + source = item['source'] bucket_size = self._bucketed_sizes[index] - num_pad = bucket_size - item.size(-1) - return F.pad( - item, + num_pad = bucket_size - source.size(-1) + result = F.pad( + source, (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), value=self.pad_idx, ) + item['source'] = result + return item @property def sizes(self): diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 7ee89adce9..0060d8c189 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -113,7 +113,6 @@ def distributed_init(args): args.device_id = xm.get_local_ordinal() args.distributed_rank = xm.get_ordinal() xm.rendezvous('distributed_init') # wait for all workers - xm.mark_step() if is_master(args): logging.getLogger().setLevel(logging.INFO) @@ -182,7 +181,10 @@ def call_main(args, main, **kwargs): xmp.spawn( fn=distributed_main, args=(main, args, kwargs), - nprocs=8, # use all 8 TPU cores + # tpu-comment: + # 8 devices in one TPU VM, is the max processes to be spawned. + # The rest is driven by xm.distributed.xla_dist + nprocs=min(args.distributed_world_size, 8), ) else: # single GPU main diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py index 78e6d4d224..ee903282a0 100644 --- a/fairseq/logging/meters.py +++ b/fairseq/logging/meters.py @@ -51,6 +51,9 @@ def smoothed_value(self) -> float: def safe_round(number, ndigits): + # FIXME: taylan revisit this + import pdb + pdb.set_trace() if hasattr(number, '__round__'): return round(number, ndigits) elif torch is not None and torch.is_tensor(number) and number.numel() == 1: diff --git a/fairseq/metsumm.py b/fairseq/metsumm.py new file mode 100644 index 0000000000..f30386a3d3 --- /dev/null +++ b/fairseq/metsumm.py @@ -0,0 +1,12 @@ + +def metsumm(stepno=''): + try: + import torch_xla.debug.metrics as met + x = met.metrics_report().split('\n') + for i, line in enumerate(x): + if 'CompileTime' in line or 'aten::' in line: + key = line.split()[-1] + value = x[i+1].split()[-1] + print('step {}, key {}, value {}'.format(stepno, key, value)) + except RuntimeError: + return diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 226f035ba8..50d5e7c591 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -511,10 +511,14 @@ def compute_preds(self, x, y, negatives): logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1).type_as(x) - logits /= self.logit_temp + logits = logits / self.logit_temp - if neg_is_pos.any(): - logits[1:][neg_is_pos] = float("-inf") + if logits.device.type == 'xla' or neg_is_pos.any(): + # FIXME: taylan what is neg_is_pos doing? inspect. + import pdb + pdb.set_trace() + logits = logits + -1.0 * (2**55) * neg_is_pos + #logits[1:][neg_is_pos] = float("-inf") return logits @@ -561,7 +565,13 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): curr_temp = q["temp"] features = self.project_inp(features) + from fairseq.metsumm import metsumm + metsumm("Before mask...") + if mask: + # FIXME: taylan investigate dynamicity. + import pdb + pdb.set_trace() x, mask_indices = self.apply_mask(features, padding_mask) if mask_indices is not None: y = unmasked_features[mask_indices].view(unmasked_features.size(0), -1, unmasked_features.size(-1)) @@ -571,12 +581,14 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): x = features y = unmasked_features mask_indices = None + metsumm("After mask...") x = self.encoder(x, padding_mask=padding_mask) if features_only: return {"x": x, "padding_mask": padding_mask} + metsumm("Before quantizer...") if self.quantizer: q = self.quantizer(y, produce_targets=False) y = q["x"] @@ -613,14 +625,19 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): else: negs, _ = self.sample_negatives(y, y.size(1)) + metsumm("After quantizer...") + # FIXME: taylan mask indices investigate dynamicity + import pdb + pdb.set_trace() x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + metsumm("Before Negs ...") if self.target_glu: y = self.target_glu(y) negs = self.target_glu(negs) - x = self.final_proj(x) x = self.compute_preds(x, y, negs) + metsumm("After compute-pred ...") result = {"x": x, "padding_mask": padding_mask, "features_pen": features_pen} @@ -811,7 +828,7 @@ def extract_features(self, x, padding_mask=None): x_conv = self.pos_conv(x.transpose(1, 2)) x_conv = x_conv.transpose(1, 2) - x += x_conv + x = x + x_conv if not self.layer_norm_first: x = self.layer_norm(x) diff --git a/fairseq/modules/gumbel_vector_quantizer.py b/fairseq/modules/gumbel_vector_quantizer.py index 01ddd2298b..35a08af166 100644 --- a/fairseq/modules/gumbel_vector_quantizer.py +++ b/fairseq/modules/gumbel_vector_quantizer.py @@ -160,6 +160,7 @@ def forward(self, x, produce_targets=False): avg_probs = torch.softmax( x.view(bsz * tsz, self.groups, -1).float(), dim=-1 ).mean(dim=0) + avg_probs = avg_probs.detach() result["prob_perplexity"] = torch.exp( -torch.sum(avg_probs * torch.log(avg_probs + 1e-7), dim=-1) ).sum() diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index e33dd450ee..6f04f8bb07 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -86,6 +86,8 @@ def prepare_for_onnx_export_(self): self.onnx_trace = True def prepare_for_tpu_(self, **kwargs): + print('PREPARING FOR TPU, DELETE ME WHEN U SEE DIS') + raise self.tpu = True def reset_parameters(self): diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index f33637468f..9008c9b5c9 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -8,7 +8,9 @@ import os import sys -from fairseq.data import FileAudioDataset, Dictionary, AddTargetDataset +from fairseq.data import ( + FileAudioDataset, Dictionary, AddTargetDataset, BucketPadLengthDataset +) from . import FairseqTask, register_task @@ -31,6 +33,14 @@ class AudioPretrainingTask(FairseqTask): @staticmethod def add_args(parser): """Add task-specific arguments to the parser.""" + parser.add_argument( + '--num-batch-buckets', default=0, type=int, + help=( + 'if >0, then bucket source and target lengths into N ' + 'buckets and pad accordingly; this is useful on TPUs ' + 'to minimize the number of compilations' + ), + ) parser.add_argument("data", help="path to data directory") parser.add_argument( "--sample-rate", @@ -104,6 +114,14 @@ def load_dataset(self, split, **kwargs): pad=self.args.labels is not None or self.args.enable_padding, normalize=self.args.normalize, ) + if (self.args.num_batch_buckets < 0): + self.datasets[split] = BucketPadLengthDataset( + self.datasets[split], + sizes=self.datasets[split].sizes, + num_buckets=self.args.num_batch_buckets, + pad_idx=0, + left_pad=False, + ) if self.args.labels: dict_path = os.path.join(self.args.data, f"dict.{self.args.labels}.txt") diff --git a/fairseq/trainer.py b/fairseq/trainer.py index a91d12fdc2..9f549916fa 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -21,6 +21,7 @@ from fairseq.logging import meters, metrics from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler +from fairseq.metsumm import metsumm logger = logging.getLogger(__name__) @@ -422,6 +423,7 @@ def maybe_no_sync(): try: with maybe_no_sync(): # forward and backward + metsumm("Before task.train_step") loss, sample_size_i, logging_output = self.task.train_step( sample=sample, model=self.model, @@ -430,6 +432,7 @@ def maybe_no_sync(): update_num=self.get_num_updates(), ignore_grad=is_dummy_batch, ) + metsumm("After task.train_step") del loss logging_outputs.append(logging_output) @@ -491,6 +494,7 @@ def maybe_no_sync(): gradients = xm._fetch_gradients(self.optimizer.optimizer) xm.all_reduce('sum', gradients, scale=1.0 / self.data_parallel_world_size) + metsumm("Before Autograd-profiler-record") with torch.autograd.profiler.record_function("multiply-grads"): # multiply gradients by (# GPUs / sample_size) since DDP # already normalizes by the number of GPUs. Thus we get @@ -504,6 +508,7 @@ def maybe_no_sync(): with torch.autograd.profiler.record_function("clip-grads"): # clip grads grad_norm = self.clip_grad_norm(self.args.clip_norm) + metsumm("After Autograd-profiler-record") # check that grad norms are consistent across workers if ( @@ -513,9 +518,11 @@ def maybe_no_sync(): ): self._check_grad_norms(grad_norm) + metsumm("Before Optimizer-Step") with torch.autograd.profiler.record_function("optimizer"): # take an optimization step self.optimizer.step() + metsumm("After Optimizer-Step") except FloatingPointError: # re-run the forward and backward pass with hooks attached to print # out where it fails @@ -537,27 +544,33 @@ def maybe_no_sync(): raise e # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step + metsumm("Before Additional-Optimizer-Step") if hasattr(self.model, 'perform_additional_optimizer_actions'): if hasattr(self.optimizer, 'fp32_params'): self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params) else: self.model.perform_additional_optimizer_actions(self.optimizer.optimizer) + metsumm("After Additional-Optimizer-Step") if not overflow or self.args.distributed_wrapper == 'SlowMo': self.set_num_updates(self.get_num_updates() + 1) if self.tpu: # mark step on TPUs - import torch_xla.core.xla_model as xm - xm.mark_step() # only log stats every log_interval steps # this causes wps to be misreported when log_interval > 1 logging_output = {} + metsumm("Before reduce-log-stat") if self.get_num_updates() % self.args.log_interval == 0: + metsumm("Before mark-step") + import torch_xla.core.xla_model as xm + xm.mark_step() + metsumm("After mark-step") logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm, ) + metsumm("After reduce-log-stat") # log whenever there's an XLA compilation, since these # slow down training and may indicate opportunities for @@ -824,6 +837,7 @@ def _all_gather_list_sync( suitable when logging outputs are complex types. """ if self.tpu: + # FIXME: taylan - all gather etc. raise NotImplementedError if ignore: logging_outputs = [] diff --git a/fairseq/utils.py b/fairseq/utils.py index f68860330c..cd8952760c 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -252,6 +252,9 @@ def convert_padding_direction( def item(tensor): + # tpu-comment: making this a no-op for xla devices. + if tensor.device.type == 'xla': + return tensor if hasattr(tensor, "item"): return tensor.item() if hasattr(tensor, "__getitem__"): From 6a03e5c4c27b0efd83c282a0f1488890d9ad7039 Mon Sep 17 00:00:00 2001 From: vaibhav singh Date: Thu, 20 Aug 2020 17:13:01 +0000 Subject: [PATCH 02/19] input shape temp update --- fairseq/models/wav2vec/wav2vec2.py | 1 - fairseq/tasks/audio_pretraining.py | 2 +- fairseq/trainer.py | 22 ++++++++++------------ 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 50d5e7c591..a69b36e515 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -523,7 +523,6 @@ def compute_preds(self, x, y, negatives): return logits def forward(self, source, padding_mask=None, mask=True, features_only=False): - if self.feature_grad_mult > 0: features = self.feature_extractor(source) if self.feature_grad_mult != 1.0: diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 9008c9b5c9..9b73557f16 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -114,7 +114,7 @@ def load_dataset(self, split, **kwargs): pad=self.args.labels is not None or self.args.enable_padding, normalize=self.args.normalize, ) - if (self.args.num_batch_buckets < 0): + if self.args.num_batch_buckets > 0: self.datasets[split] = BucketPadLengthDataset( self.datasets[split], sizes=self.datasets[split].sizes, diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 9f549916fa..a92385df84 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -40,6 +40,8 @@ class Trainer(object): def __init__(self, args, task, model, criterion, quantizer=None): self.args = args self.task = task + self.logging_history = [] + self.cumm_sample_size = 0 # catalog shared parameters shared_params = _catalog_shared_params(model) @@ -423,7 +425,6 @@ def maybe_no_sync(): try: with maybe_no_sync(): # forward and backward - metsumm("Before task.train_step") loss, sample_size_i, logging_output = self.task.train_step( sample=sample, model=self.model, @@ -432,7 +433,6 @@ def maybe_no_sync(): update_num=self.get_num_updates(), ignore_grad=is_dummy_batch, ) - metsumm("After task.train_step") del loss logging_outputs.append(logging_output) @@ -481,6 +481,10 @@ def maybe_no_sync(): # gather logging outputs from all replicas if self._sync_stats(): + # FIXME: taylan is this a problem for tpu? + # FIXME: taylan maybe backward first, then sync stats? + import pdb + pdb.set_trace() train_time = self._local_cumulative_training_time() logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs( logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch, @@ -494,7 +498,6 @@ def maybe_no_sync(): gradients = xm._fetch_gradients(self.optimizer.optimizer) xm.all_reduce('sum', gradients, scale=1.0 / self.data_parallel_world_size) - metsumm("Before Autograd-profiler-record") with torch.autograd.profiler.record_function("multiply-grads"): # multiply gradients by (# GPUs / sample_size) since DDP # already normalizes by the number of GPUs. Thus we get @@ -508,7 +511,6 @@ def maybe_no_sync(): with torch.autograd.profiler.record_function("clip-grads"): # clip grads grad_norm = self.clip_grad_norm(self.args.clip_norm) - metsumm("After Autograd-profiler-record") # check that grad norms are consistent across workers if ( @@ -518,11 +520,9 @@ def maybe_no_sync(): ): self._check_grad_norms(grad_norm) - metsumm("Before Optimizer-Step") with torch.autograd.profiler.record_function("optimizer"): # take an optimization step self.optimizer.step() - metsumm("After Optimizer-Step") except FloatingPointError: # re-run the forward and backward pass with hooks attached to print # out where it fails @@ -544,13 +544,11 @@ def maybe_no_sync(): raise e # Some distributed wrappers (e.g., SlowMo) need access to the optimizer after the step - metsumm("Before Additional-Optimizer-Step") if hasattr(self.model, 'perform_additional_optimizer_actions'): if hasattr(self.optimizer, 'fp32_params'): self.model.perform_additional_optimizer_actions(self.optimizer.optimizer, self.optimizer.fp32_params) else: self.model.perform_additional_optimizer_actions(self.optimizer.optimizer) - metsumm("After Additional-Optimizer-Step") if not overflow or self.args.distributed_wrapper == 'SlowMo': self.set_num_updates(self.get_num_updates() + 1) @@ -560,17 +558,17 @@ def maybe_no_sync(): # only log stats every log_interval steps # this causes wps to be misreported when log_interval > 1 + self.logging_history.extend(logging_outputs) + self.cumm_sample_size += sample_size logging_output = {} - metsumm("Before reduce-log-stat") if self.get_num_updates() % self.args.log_interval == 0: - metsumm("Before mark-step") import torch_xla.core.xla_model as xm xm.mark_step() - metsumm("After mark-step") logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm, ) - metsumm("After reduce-log-stat") + self.logging_history = [] + self.cumm_sample_size = 0 # log whenever there's an XLA compilation, since these # slow down training and may indicate opportunities for From 433ca7661b960b1f526e136acce671a02a30d35a Mon Sep 17 00:00:00 2001 From: vaibhav singh Date: Thu, 20 Aug 2020 22:17:49 +0000 Subject: [PATCH 03/19] clean up --- fairseq/trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fairseq/trainer.py b/fairseq/trainer.py index a92385df84..85c1a841ef 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -425,6 +425,9 @@ def maybe_no_sync(): try: with maybe_no_sync(): # forward and backward + # FIXME: remove + print("DEBUG_MESSAGE: Sample-Size: ", sample['net_input']['source'].size()) + loss, sample_size_i, logging_output = self.task.train_step( sample=sample, model=self.model, From 19535a257544281d2a7efe96db8fdb009bf75859 Mon Sep 17 00:00:00 2001 From: Vaibhav Singh Date: Thu, 20 Aug 2020 22:19:49 +0000 Subject: [PATCH 04/19] dataset updates --- fairseq/data/audio/raw_audio_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 675b095647..37a0c370de 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -171,6 +171,9 @@ def __init__( self.sizes.append(sz) logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") + def get_batch_shapes(self): + return eval(args.batch_shapes) + def __getitem__(self, index): import soundfile as sf From a5005879888bb2a8fad60336f348cbb8aa48079a Mon Sep 17 00:00:00 2001 From: Vaibhav Singh Date: Thu, 20 Aug 2020 23:31:43 +0000 Subject: [PATCH 05/19] clean up --- fairseq/criterions/wav2vec_criterion.py | 6 ++++++ fairseq/data/audio/raw_audio_dataset.py | 3 --- fairseq/data/bucket_pad_length_dataset.py | 3 +++ fairseq/tasks/audio_pretraining.py | 1 + fairseq/trainer.py | 3 +++ 5 files changed, 13 insertions(+), 3 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 90f20a8488..9dcaa53723 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -124,6 +124,9 @@ def reduce_metrics(logging_outputs) -> None: nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs)) sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs)) + # FIXME: taylan sample_size could be a tensor, rounding could be a problem + import pdb + pdb.set_trace() metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3) metrics.log_scalar('ntokens', ntokens) metrics.log_scalar('nsentences', nsentences) @@ -151,6 +154,9 @@ def reduce_metrics(logging_outputs) -> None: if k.startswith('loss'): metrics.log_scalar(k, val / sample_size / math.log(2), sample_size) else: + # FIXME: taylan, round=3 could be a problem + import pdb + pdb.set_trace() metrics.log_scalar(k, val, round=3) @staticmethod diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 37a0c370de..675b095647 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -171,9 +171,6 @@ def __init__( self.sizes.append(sz) logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") - def get_batch_shapes(self): - return eval(args.batch_shapes) - def __getitem__(self, index): import soundfile as sf diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py index 256c202318..f4b4fb37c1 100644 --- a/fairseq/data/bucket_pad_length_dataset.py +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -67,6 +67,9 @@ def __getitem__(self, index): value=self.pad_idx, ) item['source'] = result + # FIXME: taylan do we return item or F.pad? + import pdb + pdb.set_trace() return item @property diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 9b73557f16..626b58d1ac 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -113,6 +113,7 @@ def load_dataset(self, split, **kwargs): min_length=self.args.min_sample_size, pad=self.args.labels is not None or self.args.enable_padding, normalize=self.args.normalize, + num_batch_buckets=self.args.num_batch_buckets ) if self.args.num_batch_buckets > 0: self.datasets[split] = BucketPadLengthDataset( diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 85c1a841ef..40cb56d7ba 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -928,6 +928,9 @@ def is_consistent(tensor): def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): if grad_norm is not None: + # FIXME: taylan what to do here? torch.clamp? + import pdb + pdb.set_trace() metrics.log_speed("ups", 1., priority=100, round=2) metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) if self.args.clip_norm > 0: From b49f1033cf6aa24428cd10a720ca405762768d35 Mon Sep 17 00:00:00 2001 From: kevinmtian Date: Wed, 2 Sep 2020 17:49:10 +0000 Subject: [PATCH 06/19] move tensor idx to matrix op inside apply_mask --- fairseq/criterions/wav2vec_criterion.py | 2 +- fairseq/data/audio/raw_audio_dataset.py | 1 - fairseq/data/fairseq_dataset.py | 5 +++-- fairseq/tasks/audio_pretraining.py | 1 - fairseq/trainer.py | 4 ++-- 5 files changed, 6 insertions(+), 7 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 9dcaa53723..a21a98d847 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -104,7 +104,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): max = logits.argmax(-1) == 0 min = logits.argmin(-1) == 0 both = max & min - #corr = max.long().sum().item() - both.long().sum().item() + # corr = max.long().sum().item() - both.long().sum().item() corr = max.long().sum() - both.long().sum() count = max.numel() diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 675b095647..71594550f1 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -154,7 +154,6 @@ def __init__( pad=pad, normalize=normalize, ) - self.fnames = [] skipped = 0 diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index 5786d5c851..6e9591ee8c 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -124,8 +124,9 @@ def adjust_bsz(bsz, num_tokens): class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): - """For datasets that need to be read sequentially, usually because the data - is being streamed or otherwise can't be manipulated on a single machine. + """ + For datasets that need to be read sequentially, usually because the data is + being streamed or otherwise can't be manipulated on a single machine. """ def __iter__(self): diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 626b58d1ac..9b73557f16 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -113,7 +113,6 @@ def load_dataset(self, split, **kwargs): min_length=self.args.min_sample_size, pad=self.args.labels is not None or self.args.enable_padding, normalize=self.args.normalize, - num_batch_buckets=self.args.num_batch_buckets ) if self.args.num_batch_buckets > 0: self.datasets[split] = BucketPadLengthDataset( diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 40cb56d7ba..d744cc7ab7 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -350,7 +350,7 @@ def get_train_iterator( num_shards=self.data_parallel_world_size if shard_batch_itr else 1, shard_id=self.data_parallel_rank if shard_batch_itr else 0, num_workers=self.args.num_workers, - epoch=epoch + epoch=epoch, ) def get_valid_iterator( @@ -371,7 +371,7 @@ def get_valid_iterator( seed=self.args.seed, num_shards=self.data_parallel_world_size, shard_id=self.data_parallel_rank, - num_workers=self.args.num_workers + num_workers=self.args.num_workers, ) def begin_epoch(self, epoch): From 73e2f3b7128c9cc58f43757776f5272a219b0d2c Mon Sep 17 00:00:00 2001 From: kevinmtian Date: Thu, 3 Sep 2020 07:59:28 +0000 Subject: [PATCH 07/19] use tensor operators to replace tensor indexing, passed consistency test verification --- fairseq/data/audio/raw_audio_dataset.py | 1 + fairseq/data/data_utils.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 71594550f1..675b095647 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -154,6 +154,7 @@ def __init__( pad=pad, normalize=normalize, ) + self.fnames = [] skipped = 0 diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 57991a8802..e3cee19ff8 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -283,6 +283,7 @@ def compute_mask_indices( no_overlap: bool = False, min_space: int = 0, ) -> np.ndarray: +#) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ Computes random mask spans for a given shape @@ -393,3 +394,20 @@ def arrange(s, e, length, keep_length): mask[i, mask_idc] = True return mask + # FIXME: taylan remove this + """ + left_mask, right_mask = [], [] + for i, mask_idc in enumerate(mask_idcs): + if len(mask_idc) > min_len: + mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + for idc in np.sort(mask_idc): + l_mask = [False] * bsz + l_mask[i] = True + r_mask = [False] * all_sz + r_mask[idc] = True + left_mask.append(l_mask) + right_mask.append(r_mask) + + return mask, np.array(left_mask), np.array(right_mask) + """ From 8802a310033aab94d661a659c5c68bf64daacfc2 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Tue, 6 Oct 2020 16:55:46 +0000 Subject: [PATCH 08/19] Minor improvements --- fairseq/data/bucket_pad_length_dataset.py | 6 +----- fairseq/data/fairseq_dataset.py | 5 ++--- fairseq/modules/multihead_attention.py | 2 -- fairseq/trainer.py | 3 --- fairseq_cli/train.py | 6 +++--- 5 files changed, 6 insertions(+), 16 deletions(-) diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py index f4b4fb37c1..4bb5ec3df8 100644 --- a/fairseq/data/bucket_pad_length_dataset.py +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -61,15 +61,11 @@ def __getitem__(self, index): source = item['source'] bucket_size = self._bucketed_sizes[index] num_pad = bucket_size - source.size(-1) - result = F.pad( + item['source'] = F.pad( source, (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), value=self.pad_idx, ) - item['source'] = result - # FIXME: taylan do we return item or F.pad? - import pdb - pdb.set_trace() return item @property diff --git a/fairseq/data/fairseq_dataset.py b/fairseq/data/fairseq_dataset.py index 6e9591ee8c..5786d5c851 100644 --- a/fairseq/data/fairseq_dataset.py +++ b/fairseq/data/fairseq_dataset.py @@ -124,9 +124,8 @@ def adjust_bsz(bsz, num_tokens): class FairseqIterableDataset(torch.utils.data.IterableDataset, EpochListening): - """ - For datasets that need to be read sequentially, usually because the data is - being streamed or otherwise can't be manipulated on a single machine. + """For datasets that need to be read sequentially, usually because the data + is being streamed or otherwise can't be manipulated on a single machine. """ def __iter__(self): diff --git a/fairseq/modules/multihead_attention.py b/fairseq/modules/multihead_attention.py index 6f04f8bb07..e33dd450ee 100644 --- a/fairseq/modules/multihead_attention.py +++ b/fairseq/modules/multihead_attention.py @@ -86,8 +86,6 @@ def prepare_for_onnx_export_(self): self.onnx_trace = True def prepare_for_tpu_(self, **kwargs): - print('PREPARING FOR TPU, DELETE ME WHEN U SEE DIS') - raise self.tpu = True def reset_parameters(self): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index d744cc7ab7..5c3b69fa1f 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -21,7 +21,6 @@ from fairseq.logging import meters, metrics from fairseq.nan_detector import NanDetector from fairseq.optim import lr_scheduler -from fairseq.metsumm import metsumm logger = logging.getLogger(__name__) @@ -425,8 +424,6 @@ def maybe_no_sync(): try: with maybe_no_sync(): # forward and backward - # FIXME: remove - print("DEBUG_MESSAGE: Sample-Size: ", sample['net_input']['source'].size()) loss, sample_size_i, logging_output = self.task.train_step( sample=sample, diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 806e4bc54b..aa61d74599 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -106,9 +106,7 @@ def main(args): extra_state, epoch_itr = checkpoint_utils.load_checkpoint(args, trainer) if args.tpu: import torch_xla.core.xla_model as xm - xm.rendezvous("load_checkpoint") # wait for all workers - xm.mark_step() # Train until the learning rate gets too small max_epoch = args.max_epoch or math.inf @@ -167,7 +165,6 @@ def tpu_data_loader(args, itr): import torch_xla.distributed.parallel_loader as pl xm.rendezvous("tpu_data_loader") # wait for all workers - xm.mark_step() device = utils.get_tpu_device(args) return iterators.CountingIterator( pl.ParallelLoader(itr, [device]).per_device_loader(device), @@ -211,9 +208,12 @@ def train(args, trainer, task, epoch_itr): should_stop = False num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): + with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i ): + # FIXME: first iterate and check bszs + raise RuntimeError('first iterate and check bszs') log_output = trainer.train_step(samples) if log_output is None: # OOM, overflow, ... continue From dc6059293470abff07e90f2f708536644ce788e9 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Wed, 7 Oct 2020 20:33:27 +0000 Subject: [PATCH 09/19] Fix bucketpadlendataset --- fairseq/data/bucket_pad_length_dataset.py | 25 +++++++++++++++-------- fairseq/tasks/audio_pretraining.py | 7 +++++++ fairseq_cli/train.py | 5 +++-- 3 files changed, 27 insertions(+), 10 deletions(-) diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py index 4bb5ec3df8..423a89da2c 100644 --- a/fairseq/data/bucket_pad_length_dataset.py +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -30,6 +30,8 @@ def __init__( num_buckets, pad_idx, left_pad, + lambda_get=None, + lambda_set=None, ): super().__init__(dataset) self.pad_idx = pad_idx @@ -55,18 +57,25 @@ def get_bucketed_sizes(orig_sizes, buckets): return sizes self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) + self._get_tensor = (lambda x: x) if lambda_get is None else lambda_get + self._set_tensor = ( + (lambda item, val: val) if lambda_set is None else lambda_set + ) - def __getitem__(self, index): - item = self.dataset[index] - source = item['source'] - bucket_size = self._bucketed_sizes[index] - num_pad = bucket_size - source.size(-1) - item['source'] = F.pad( - source, + def _pad(self, tensor, bucket_size, dim=-1): + num_pad = bucket_size - tensor.size(dim) + return F.pad( + tensor, (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad), value=self.pad_idx, ) - return item + + def __getitem__(self, index): + item = self.dataset[index] + bucket_size = self._bucketed_sizes[index] + tensor = self._get_tensor(item) + padded = self._pad(tensor, bucket_size) + return self._set_tensor(item, padded) @property def sizes(self): diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index 9b73557f16..eafc4e794c 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -121,6 +121,13 @@ def load_dataset(self, split, **kwargs): num_buckets=self.args.num_batch_buckets, pad_idx=0, left_pad=False, + lambda_get=lambda item: item['source'], + lambda_set=( + lambda item, val: { + k: v if k != 'source' else val + for k, v in item.items() + } + ), ) if self.args.labels: diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index aa61d74599..c3baeb5267 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -208,12 +208,13 @@ def train(args, trainer, task, epoch_itr): should_stop = False num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): + # FIXME: delete these in the end + #print('SHAPE', i, samples[0]['net_input']['source'].shape) + #continue with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i ): - # FIXME: first iterate and check bszs - raise RuntimeError('first iterate and check bszs') log_output = trainer.train_step(samples) if log_output is None: # OOM, overflow, ... continue From be3ca6ace93d0e07987278f9d211c6f35482dd27 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Tue, 13 Oct 2020 23:26:40 +0000 Subject: [PATCH 10/19] Moved mask matrices creation to dataset prep. --- fairseq/data/audio/raw_audio_dataset.py | 77 +++++++++++++++++++++- fairseq/data/data_utils.py | 13 ++++ fairseq/models/wav2vec/wav2vec2.py | 86 ++++++++++++++----------- fairseq/tasks/audio_pretraining.py | 7 +- 4 files changed, 142 insertions(+), 41 deletions(-) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 675b095647..bf35b28bfb 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -12,7 +12,8 @@ import torch import torch.nn.functional as F -from .. import FairseqDataset +from .. import FairseqDataset, BaseWrapperDataset +from ..data_utils import compute_mask_indices logger = logging.getLogger(__name__) @@ -27,6 +28,8 @@ def __init__( min_length=0, pad=False, normalize=False, + compute_mask_indices=False, + args=None, ): super().__init__() @@ -40,6 +43,12 @@ def __init__( self.pad = pad self.shuffle = shuffle self.normalize = normalize + self.compute_mask_indices = compute_mask_indices + if self.compute_mask_indices: + self.args = args + self._features_size_map = {} + self._C = self.args.encoder_embed_dim + self._conv_feature_layers = eval(self.args.conv_feature_layers) def __getitem__(self, index): raise NotImplementedError() @@ -71,6 +80,42 @@ def crop_to_max_size(self, wav, target_size): end = size - diff + start return wav[start:end] + def _compute_mask_indices(self, dims, padding_mask): + B, T, C = dims + mask_indices, mask_channel_indices = None, None + if self.args.mask_prob > 0: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.args.mask_prob, + self.args.mask_length, + self.args.mask_selection, + self.args.mask_other, + min_masks=2, + no_overlap=self.args.no_mask_overlap, + min_space=self.args.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices) + if self.args.mask_channel_prob > 0: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.args.mask_channel_prob, + self.args.mask_channel_length, + self.args.mask_channel_selection, + self.args.mask_channel_other, + no_overlap=self.args.no_mask_channel_overlap, + min_space=self.args.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .unsqueeze(1) + .expand(-1, T, -1) + ) + + return mask_indices, mask_channel_indices + + def collater(self, samples): samples = [ s @@ -108,7 +153,29 @@ def collater(self, samples): input = {"source": collated_sources} if self.pad: input["padding_mask"] = padding_mask - return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input} + + if self.compute_mask_indices: + B = collated_sources.size(0) + T = self._get_mask_indices_dims(collated_sources.size(-1)) + mask_indices, mask_channel_indices = self._compute_mask_indices( + (B, T, self._C), padding_mask + ) + input["mask_indices"] = mask_indices + input["mask_channel_indices"] = mask_channel_indices + return { + "id": torch.LongTensor([s["id"] for s in samples]), + "net_input": input, + } + + def _get_mask_indices_dims(self, size, padding=0, dilation=1): + if size not in self._features_size_map: + L_in = size + for (_, kernel_size, stride) in self._conv_feature_layers: + L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1 + L_out = 1 + L_out // stride + L_in = L_out + self._features_size_map[size] = L_out + return self._features_size_map[size] def num_tokens(self, index): return self.size(index) @@ -144,6 +211,8 @@ def __init__( min_length=0, pad=False, normalize=False, + compute_mask_indices=False, + args=None, ): super().__init__( sample_rate=sample_rate, @@ -153,6 +222,8 @@ def __init__( min_length=min_length, pad=pad, normalize=normalize, + compute_mask_indices=compute_mask_indices, + args=args, ) self.fnames = [] @@ -179,3 +250,5 @@ def __getitem__(self, index): feats = torch.from_numpy(wav).float() feats = self.postprocess(feats, curr_sample_rate) return {"id": index, "source": feats} + + diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index e3cee19ff8..c9f66e17a4 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -272,6 +272,19 @@ def post_process(sentence: str, symbol: str): sentence = (sentence + " ").replace(symbol, "").rstrip() return sentence + +def index_put(tensor, indices, value): + if tensor.device.type == 'xla': + for _ in range(indices.dim(), tensor.dim()): + indices = indices.unsqueeze(-1) + if indices.size(-1) < tensor.size(-1): + indices = indices.expand_as(tensor) + tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) + else: + tensor[indices] = value + return tensor + + def compute_mask_indices( shape: Tuple[int, int], padding_mask: Optional[torch.Tensor], diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index a69b36e515..b5552c4b2f 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -14,7 +14,7 @@ from typing import List, Tuple from fairseq import utils -from fairseq.data.data_utils import compute_mask_indices +from fairseq.data.data_utils import compute_mask_indices, index_put from fairseq.models import BaseFairseqModel, register_model, register_model_architecture from fairseq.modules import ( Fp32GroupNorm, @@ -405,43 +405,48 @@ def build_model(cls, args, task=None): return cls(args) - def apply_mask(self, x, padding_mask): + def apply_mask( + self, x, padding_mask, + mask_indices=None, mask_channel_indices=None, + ): B, T, C = x.shape if self.mask_prob > 0: - mask_indices = compute_mask_indices( - (B, T), - padding_mask, - self.mask_prob, - self.mask_length, - self.mask_selection, - self.mask_other, - min_masks=2, - no_overlap=self.no_mask_overlap, - min_space=self.mask_min_space, - ) - mask_indices = torch.from_numpy(mask_indices).to(x.device) - x[mask_indices] = self.mask_emb + if mask_indices is None: + mask_indices = compute_mask_indices( + (B, T), + padding_mask, + self.mask_prob, + self.mask_length, + self.mask_selection, + self.mask_other, + min_masks=2, + no_overlap=self.no_mask_overlap, + min_space=self.mask_min_space, + ) + mask_indices = torch.from_numpy(mask_indices).to(x.device) + x = index_put(x, mask_indices, self.mask_emb) else: mask_indices = None if self.mask_channel_prob > 0: - mask_channel_indices = compute_mask_indices( - (B, C), - None, - self.mask_channel_prob, - self.mask_channel_length, - self.mask_channel_selection, - self.mask_channel_other, - no_overlap=self.no_mask_channel_overlap, - min_space=self.mask_channel_min_space, - ) - mask_channel_indices = ( - torch.from_numpy(mask_channel_indices) - .to(x.device) - .unsqueeze(1) - .expand(-1, T, -1) - ) - x[mask_channel_indices] = 0 + if mask_channel_indices is None: + mask_channel_indices = compute_mask_indices( + (B, C), + None, + self.mask_channel_prob, + self.mask_channel_length, + self.mask_channel_selection, + self.mask_channel_other, + no_overlap=self.no_mask_channel_overlap, + min_space=self.mask_channel_min_space, + ) + mask_channel_indices = ( + torch.from_numpy(mask_channel_indices) + .to(x.device) + .unsqueeze(1) + .expand(-1, T, -1) + ) + x = index_put(x, mask_channel_indices, 0) return x, mask_indices @@ -522,7 +527,13 @@ def compute_preds(self, x, y, negatives): return logits - def forward(self, source, padding_mask=None, mask=True, features_only=False): + def forward( + self, source, padding_mask=None, mask=True, features_only=False, + mask_indices=None, mask_channel_indices=None, + ): + import pdb + pdb.set_trace() + if self.feature_grad_mult > 0: features = self.feature_extractor(source) if self.feature_grad_mult != 1.0: @@ -568,10 +579,11 @@ def forward(self, source, padding_mask=None, mask=True, features_only=False): metsumm("Before mask...") if mask: - # FIXME: taylan investigate dynamicity. - import pdb - pdb.set_trace() - x, mask_indices = self.apply_mask(features, padding_mask) + x, mask_indices = self.apply_mask( + features, padding_mask, + mask_indices=mask_indices, + mask_channel_indices=mask_channel_indices, + ) if mask_indices is not None: y = unmasked_features[mask_indices].view(unmasked_features.size(0), -1, unmasked_features.size(-1)) else: diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index eafc4e794c..d298c931ee 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -113,12 +113,15 @@ def load_dataset(self, split, **kwargs): min_length=self.args.min_sample_size, pad=self.args.labels is not None or self.args.enable_padding, normalize=self.args.normalize, + compute_mask_indices=self.args.tpu, + args=self.args, ) - if self.args.num_batch_buckets > 0: + if self.args.num_batch_buckets > 0 or self.args.tpu: + # Always bucket for tpus. self.datasets[split] = BucketPadLengthDataset( self.datasets[split], sizes=self.datasets[split].sizes, - num_buckets=self.args.num_batch_buckets, + num_buckets=self.args.num_batch_buckets or 1, pad_idx=0, left_pad=False, lambda_get=lambda item: item['source'], From 30f37612145142ae36aa0e49153615dbf404ac06 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Thu, 15 Oct 2020 17:09:31 +0000 Subject: [PATCH 11/19] Remove dynamism, apply mask correctly, add some guardrails, some cleanups. --- fairseq/criterions/wav2vec_criterion.py | 17 +++++++----- fairseq/distributed_utils.py | 2 ++ fairseq/logging/meters.py | 3 --- fairseq/metsumm.py | 9 +++++-- fairseq/models/wav2vec/wav2vec2.py | 36 +++++++++++-------------- fairseq/trainer.py | 25 +++++++++-------- fairseq/utils.py | 4 ++- fairseq_cli/train.py | 2 +- 8 files changed, 53 insertions(+), 45 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index a21a98d847..a2ee0d465e 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -40,10 +40,9 @@ def forward(self, model, sample, reduce=True, log_pred=False): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - from fairseq.metsumm import metsumm - metsumm("Before forward") + # FIXME: taylan clean metsumm + from fairseq.metsumm import metsumm; metsumm("Before forward") net_output = model(**sample['net_input']) - metsumm("After forward") logits = model.get_logits(net_output).float() target = model.get_targets(sample, net_output) @@ -61,7 +60,11 @@ def forward(self, model, sample, reduce=True, log_pred=False): else: loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",) - sample_size = target.numel() if self.infonce else target.long().sum().item() + if 'mask_indices' in sample['net_input'] and self.infonce: + # XXX: what happens if not self.infonce? + sample_size = sample['net_input']['mask_indices'].sum() + else: + sample_size = target.numel() if self.infonce else target.long().sum() losses.append(loss) if self.loss_weights is not None: @@ -112,6 +115,8 @@ def forward(self, model, sample, reduce=True, log_pred=False): logging_output["count"] = count if log_pred: + # FIXME: taylan remove this. + raise logging_output['logits'] = logits.cpu().numpy() logging_output['target'] = target.cpu().numpy() return loss, sample_size, logging_output @@ -124,9 +129,6 @@ def reduce_metrics(logging_outputs) -> None: nsentences = utils.item(sum(log.get('nsentences', 0) for log in logging_outputs)) sample_size = utils.item(sum(log.get('sample_size', 0) for log in logging_outputs)) - # FIXME: taylan sample_size could be a tensor, rounding could be a problem - import pdb - pdb.set_trace() metrics.log_scalar('loss', loss_sum / sample_size / math.log(2), sample_size, round=3) metrics.log_scalar('ntokens', ntokens) metrics.log_scalar('nsentences', nsentences) @@ -155,6 +157,7 @@ def reduce_metrics(logging_outputs) -> None: metrics.log_scalar(k, val / sample_size / math.log(2), sample_size) else: # FIXME: taylan, round=3 could be a problem + # XXX: we dont hit this in this workload import pdb pdb.set_trace() metrics.log_scalar(k, val, round=3) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index 0060d8c189..f2150b52b2 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -252,6 +252,8 @@ def all_gather_list(data, group=None, max_size=16384): all_reduce(buffer, group=group) + # FIXME: taylan remogve + raise buffer = buffer.cpu() try: result = [] diff --git a/fairseq/logging/meters.py b/fairseq/logging/meters.py index ee903282a0..78e6d4d224 100644 --- a/fairseq/logging/meters.py +++ b/fairseq/logging/meters.py @@ -51,9 +51,6 @@ def smoothed_value(self) -> float: def safe_round(number, ndigits): - # FIXME: taylan revisit this - import pdb - pdb.set_trace() if hasattr(number, '__round__'): return round(number, ndigits) elif torch is not None and torch.is_tensor(number) and number.numel() == 1: diff --git a/fairseq/metsumm.py b/fairseq/metsumm.py index f30386a3d3..9d83e9c016 100644 --- a/fairseq/metsumm.py +++ b/fairseq/metsumm.py @@ -1,5 +1,8 @@ - def metsumm(stepno=''): + if hasattr(metsumm, 'STEPNO'): + metsumm.STEPNO += stepno.lower()=="before forward" + else: + metsumm.STEPNO = 0 try: import torch_xla.debug.metrics as met x = met.metrics_report().split('\n') @@ -7,6 +10,8 @@ def metsumm(stepno=''): if 'CompileTime' in line or 'aten::' in line: key = line.split()[-1] value = x[i+1].split()[-1] - print('step {}, key {}, value {}'.format(stepno, key, value)) + print('step {}-{}, key {}, value {}'.format( + metsumm.STEPNO, stepno, key, value) + ) except RuntimeError: return diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index b5552c4b2f..16d5b9eba8 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -519,11 +519,13 @@ def compute_preds(self, x, y, negatives): logits = logits / self.logit_temp if logits.device.type == 'xla' or neg_is_pos.any(): - # FIXME: taylan what is neg_is_pos doing? inspect. - import pdb - pdb.set_trace() - logits = logits + -1.0 * (2**55) * neg_is_pos + #pass + fillval = float(2**30) + if not hasattr(self, '_inftensor'): + self._inftensor = torch.tensor(fillval).to(x.device) + logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) #logits[1:][neg_is_pos] = float("-inf") + #logits[1:] = index_put(logits[1:], neg_is_pos, fillval) return logits @@ -531,8 +533,6 @@ def forward( self, source, padding_mask=None, mask=True, features_only=False, mask_indices=None, mask_channel_indices=None, ): - import pdb - pdb.set_trace() if self.feature_grad_mult > 0: features = self.feature_extractor(source) @@ -575,31 +575,30 @@ def forward( curr_temp = q["temp"] features = self.project_inp(features) - from fairseq.metsumm import metsumm - metsumm("Before mask...") - if mask: x, mask_indices = self.apply_mask( features, padding_mask, mask_indices=mask_indices, mask_channel_indices=mask_channel_indices, ) - if mask_indices is not None: - y = unmasked_features[mask_indices].view(unmasked_features.size(0), -1, unmasked_features.size(-1)) + if x.device.type != 'xla' and mask_indices is not None: + # tpu-comment: reducing the size in a dynamic way causes + # too many recompilations on xla. + y = unmasked_features[mask_indices].view( + unmasked_features.size(0), -1, unmasked_features.size(-1) + ) else: y = unmasked_features else: x = features y = unmasked_features mask_indices = None - metsumm("After mask...") x = self.encoder(x, padding_mask=padding_mask) if features_only: return {"x": x, "padding_mask": padding_mask} - metsumm("Before quantizer...") if self.quantizer: q = self.quantizer(y, produce_targets=False) y = q["x"] @@ -636,19 +635,16 @@ def forward( else: negs, _ = self.sample_negatives(y, y.size(1)) - metsumm("After quantizer...") - # FIXME: taylan mask indices investigate dynamicity - import pdb - pdb.set_trace() - x = x[mask_indices].view(x.size(0), -1, x.size(-1)) + if x.device.type != 'xla': + # tpu-comment: reducing the size in a dynamic way causes + # too many recompilations on xla. + x = x[mask_indices].view(x.size(0), -1, x.size(-1)) - metsumm("Before Negs ...") if self.target_glu: y = self.target_glu(y) negs = self.target_glu(negs) x = self.final_proj(x) x = self.compute_preds(x, y, negs) - metsumm("After compute-pred ...") result = {"x": x, "padding_mask": padding_mask, "features_pen": features_pen} diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 5c3b69fa1f..0242cfddd6 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -39,8 +39,6 @@ class Trainer(object): def __init__(self, args, task, model, criterion, quantizer=None): self.args = args self.task = task - self.logging_history = [] - self.cumm_sample_size = 0 # catalog shared parameters shared_params = _catalog_shared_params(model) @@ -483,6 +481,7 @@ def maybe_no_sync(): if self._sync_stats(): # FIXME: taylan is this a problem for tpu? # FIXME: taylan maybe backward first, then sync stats? + # XXX: this never gets hit in this workload import pdb pdb.set_trace() train_time = self._local_cumulative_training_time() @@ -558,8 +557,6 @@ def maybe_no_sync(): # only log stats every log_interval steps # this causes wps to be misreported when log_interval > 1 - self.logging_history.extend(logging_outputs) - self.cumm_sample_size += sample_size logging_output = {} if self.get_num_updates() % self.args.log_interval == 0: import torch_xla.core.xla_model as xm @@ -567,8 +564,12 @@ def maybe_no_sync(): logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm, ) - self.logging_history = [] - self.cumm_sample_size = 0 + xm.mark_step() + # XXX: when I put step closure, logging outputs is shrunk.. + #xm.add_step_closure( + # self._reduce_and_log_stats, + # args=(logging_outputs, sample_size, grad_norm) + #) # log whenever there's an XLA compilation, since these # slow down training and may indicate opportunities for @@ -605,7 +606,6 @@ def valid_step(self, sample, raise_oom=False): self._dummy_batch = sample if self.tpu: import torch_xla.core.xla_model as xm - xm.rendezvous('valid_step') # wait for all workers xm.mark_step() with torch.no_grad(): @@ -815,6 +815,8 @@ def _aggregate_logging_outputs( *extra_stats_to_sum, ignore=False, ): + import pdb + pdb.set_trace() if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()): return self._fast_stat_sync_sum( logging_outputs, *extra_stats_to_sum, ignore=ignore @@ -924,10 +926,11 @@ def is_consistent(tensor): ) def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): - if grad_norm is not None: - # FIXME: taylan what to do here? torch.clamp? - import pdb - pdb.set_trace() + # tpu-comment: grad_norm is a tensor in XLA + if ( + (not torch.is_tensor(grad_norm) and grad_norm is not None) + or (torch.is_tensor(grad_norm) and not torch.isnan(grad_norm)) + ): metrics.log_speed("ups", 1., priority=100, round=2) metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) if self.args.clip_norm > 0: diff --git a/fairseq/utils.py b/fairseq/utils.py index cd8952760c..163213ab78 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -81,6 +81,8 @@ def _move_to_cuda(tensor): def move_to_cpu(sample): + # FIXME: taylan remove + raise def _move_to_cpu(tensor): # PyTorch has poor support for half tensors (float16) on CPU. # Move any such tensors to float32. @@ -253,7 +255,7 @@ def convert_padding_direction( def item(tensor): # tpu-comment: making this a no-op for xla devices. - if tensor.device.type == 'xla': + if hasattr(tensor, 'device') and tensor.device.type == 'xla': return tensor if hasattr(tensor, "item"): return tensor.item() diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index c3baeb5267..6413f7e830 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -96,7 +96,7 @@ def main(args): "training on {} devices (GPUs/TPUs)".format(args.distributed_world_size) ) logger.info( - "max tokens per GPU = {} and max sentences per GPU = {}".format( + "max tokens per device = {} and max sentences per device = {}".format( args.max_tokens, args.max_sentences ) ) From af6d4d6adba4582e02d484c980c067a8ac4de602 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Mon, 19 Oct 2020 20:25:21 +0000 Subject: [PATCH 12/19] Send device data to cpu b4 logging. --- fairseq/criterions/wav2vec_criterion.py | 10 ++++------ fairseq/data/bucket_pad_length_dataset.py | 19 +++++++++++++------ fairseq/distributed_utils.py | 2 -- fairseq/logging/metrics.py | 3 ++- fairseq/options.py | 1 + fairseq/tasks/audio_pretraining.py | 8 +------- fairseq/trainer.py | 16 ++++++---------- fairseq/utils.py | 9 +++++++-- fairseq_cli/train.py | 4 ++++ 9 files changed, 38 insertions(+), 34 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index a2ee0d465e..7dde9f0b20 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -83,7 +83,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): logging_output = { #'loss': losr.item() if reduce else loss, - 'loss': loss, + 'loss': loss.detach(), 'ntokens': sample_size, 'nsentences': sample['id'].numel(), 'sample_size': sample_size, @@ -95,7 +95,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): if len(losses) > 1: for i, l in enumerate(losses): - logging_output[f'loss_{i}'] = l + logging_output[f'loss_{i}'] = l.detach() if self.infonce: with torch.no_grad(): @@ -114,9 +114,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): logging_output["correct"] = corr logging_output["count"] = count - if log_pred: - # FIXME: taylan remove this. - raise + if log_pred and logits.device.type != 'xla': logging_output['logits'] = logits.cpu().numpy() logging_output['target'] = target.cpu().numpy() return loss, sample_size, logging_output @@ -169,4 +167,4 @@ def logging_outputs_can_be_summed() -> bool: across workers prior to calling `reduce_metrics`. Setting this to True will improves distributed training speed. """ - return False + return True diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py index 423a89da2c..27fcca3d52 100644 --- a/fairseq/data/bucket_pad_length_dataset.py +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -30,8 +30,7 @@ def __init__( num_buckets, pad_idx, left_pad, - lambda_get=None, - lambda_set=None, + tensor_key=None, ): super().__init__(dataset) self.pad_idx = pad_idx @@ -57,10 +56,18 @@ def get_bucketed_sizes(orig_sizes, buckets): return sizes self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) - self._get_tensor = (lambda x: x) if lambda_get is None else lambda_get - self._set_tensor = ( - (lambda item, val: val) if lambda_set is None else lambda_set - ) + self._tensor_key = tensor_key + + def _set_tensor(self, item, val): + if self._tensor_key is None: + return val + item[self._tensor_key] = val + return item + + def _get_tensor(self, item): + if self._tensor_key is None: + return item + return item[self._tensor_key] def _pad(self, tensor, bucket_size, dim=-1): num_pad = bucket_size - tensor.size(dim) diff --git a/fairseq/distributed_utils.py b/fairseq/distributed_utils.py index f2150b52b2..0060d8c189 100644 --- a/fairseq/distributed_utils.py +++ b/fairseq/distributed_utils.py @@ -252,8 +252,6 @@ def all_gather_list(data, group=None, max_size=16384): all_reduce(buffer, group=group) - # FIXME: taylan remogve - raise buffer = buffer.cpu() try: result = [] diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index 6ca1d201e0..ba319a3c85 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -126,12 +126,13 @@ def log_scalar( priority (int): smaller values are logged earlier in the output round (Optional[int]): number of digits to round to when displaying """ + if torch.is_tensor(value) and value.device.type == 'xla': + value = value.item() for agg in get_active_aggregators(): if key not in agg: agg.add_meter(key, AverageMeter(round=round), priority) agg[key].update(value, weight) - def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20): """Log a scalar value derived from other meters. diff --git a/fairseq/options.py b/fairseq/options.py index 171c67966d..ce6edc345c 100644 --- a/fairseq/options.py +++ b/fairseq/options.py @@ -238,6 +238,7 @@ def get_parser(desc, default_task="translation"): help='pseudo random number generator seed') parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA') parser.add_argument('--tpu', action='store_true', help='use TPU instead of CUDA') + parser.add_argument('--xla-metrics-debug', action='store_true', help='Print XLA debug info') parser.add_argument('--bf16', action='store_true', help='use bfloat16; implies --tpu') parser.add_argument('--fp16', action='store_true', help='use FP16') parser.add_argument('--memory-efficient-bf16', action='store_true', diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index d298c931ee..b00acd6bad 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -124,13 +124,7 @@ def load_dataset(self, split, **kwargs): num_buckets=self.args.num_batch_buckets or 1, pad_idx=0, left_pad=False, - lambda_get=lambda item: item['source'], - lambda_set=( - lambda item, val: { - k: v if k != 'source' else val - for k, v in item.items() - } - ), + tensor_key='source', ) if self.args.labels: diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 0242cfddd6..15510f1040 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -479,11 +479,7 @@ def maybe_no_sync(): # gather logging outputs from all replicas if self._sync_stats(): - # FIXME: taylan is this a problem for tpu? - # FIXME: taylan maybe backward first, then sync stats? - # XXX: this never gets hit in this workload - import pdb - pdb.set_trace() + # FIXME: taylan this is not hit in 1 core. revisit when running 8 cores train_time = self._local_cumulative_training_time() logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs( logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch, @@ -560,7 +556,9 @@ def maybe_no_sync(): logging_output = {} if self.get_num_updates() % self.args.log_interval == 0: import torch_xla.core.xla_model as xm + from fairseq.utils import xla_device_to_cpu xm.mark_step() + logging_outputs = xla_device_to_cpu(logging_outputs) logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm, ) @@ -789,8 +787,9 @@ def _set_seed(self): utils.set_torch_seed(seed) def _sync_stats(self): - # Return True if it's using multiple GPUs and DDP or multiple GPUs with - # BMUF and it's a bmuf sync with warmup iterations completed before. + # Return True if it's using multiple devices and DDP + # or multiple devices with BMUF and it's a bmuf sync + # with warmup iterations completed before. if self.data_parallel_world_size == 1: return False elif self.args.use_bmuf: @@ -815,8 +814,6 @@ def _aggregate_logging_outputs( *extra_stats_to_sum, ignore=False, ): - import pdb - pdb.set_trace() if self.task.__class__.logging_outputs_can_be_summed(self.get_criterion()): return self._fast_stat_sync_sum( logging_outputs, *extra_stats_to_sum, ignore=ignore @@ -837,7 +834,6 @@ def _all_gather_list_sync( suitable when logging outputs are complex types. """ if self.tpu: - # FIXME: taylan - all gather etc. raise NotImplementedError if ignore: logging_outputs = [] diff --git a/fairseq/utils.py b/fairseq/utils.py index 163213ab78..a74d562365 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -255,8 +255,8 @@ def convert_padding_direction( def item(tensor): # tpu-comment: making this a no-op for xla devices. - if hasattr(tensor, 'device') and tensor.device.type == 'xla': - return tensor + if torch.is_tensor(tensor) and tensor.device.type == 'xla': + return tensor.detach() if hasattr(tensor, "item"): return tensor.item() if hasattr(tensor, "__getitem__"): @@ -565,6 +565,11 @@ def get_tpu_device(args): return xm.xla_device() +def xla_device_to_cpu(dat): + import torch_xla.core.xla_model as xm + return xm._maybe_convert_to_cpu(dat) + + class CudaEnvironment(object): def __init__(self): cur_device = torch.cuda.current_device() diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 6413f7e830..3c45c4c64e 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -211,6 +211,9 @@ def train(args, trainer, task, epoch_itr): # FIXME: delete these in the end #print('SHAPE', i, samples[0]['net_input']['source'].shape) #continue + if not i % 10: + import torch_xla.debug.metrics as met + print(met.metrics_report()) with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i @@ -227,6 +230,7 @@ def train(args, trainer, task, epoch_itr): # reset mid-epoch stats after each log interval # the end-of-epoch stats will still be preserved + # FIXME: taylan reset in closure!!!!!!!!!!!!!!!!! metrics.reset_meters("train_inner") end_of_epoch = not itr.has_next() From 9ce990959026e58c2834a4949daa8c5bb3652418 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Wed, 21 Oct 2020 18:28:00 +0000 Subject: [PATCH 13/19] Fix data bucketing for RawAudioDataset, refactor bucketing functions, fix filling w/ -inf in wav2vec2, minor cleanups --- fairseq/criterions/wav2vec_criterion.py | 2 - fairseq/data/audio/raw_audio_dataset.py | 49 ++++++++++++++++++++--- fairseq/data/bucket_pad_length_dataset.py | 20 +-------- fairseq/data/data_utils.py | 39 ++++++++++-------- fairseq/logging/metrics.py | 9 +++++ fairseq/metsumm.py | 1 + fairseq/models/wav2vec/wav2vec2.py | 4 +- fairseq/tasks/audio_pretraining.py | 11 +---- fairseq/utils.py | 3 +- fairseq_cli/train.py | 8 +--- 10 files changed, 82 insertions(+), 64 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 7dde9f0b20..4e0bdb1917 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -40,8 +40,6 @@ def forward(self, model, sample, reduce=True, log_pred=False): 2) the sample size, which is used as the denominator for the gradient 3) logging outputs to display while training """ - # FIXME: taylan clean metsumm - from fairseq.metsumm import metsumm; metsumm("Before forward") net_output = model(**sample['net_input']) logits = model.get_logits(net_output).float() diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index bf35b28bfb..0fbfda7a33 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from .. import FairseqDataset, BaseWrapperDataset -from ..data_utils import compute_mask_indices +from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes logger = logging.getLogger(__name__) @@ -115,6 +115,9 @@ def _compute_mask_indices(self, dims, padding_mask): return mask_indices, mask_channel_indices + @staticmethod + def _bucket_tensor(tensor, num_pad, value): + return F.pad(tensor, (0, num_pad), value=value) def collater(self, samples): samples = [ @@ -154,14 +157,35 @@ def collater(self, samples): if self.pad: input["padding_mask"] = padding_mask + if hasattr(self, 'num_buckets') and self.num_buckets > 0: + assert self.pad, "Cannot bucket without padding first." + bucket = max(self._bucketed_sizes[s['id']] for s in samples) + num_pad = bucket - collated_sources.size(-1) + if num_pad: + input['source'] = self._bucket_tensor( + collated_sources, num_pad, 0 + ) + input['padding_mask'] = self._bucket_tensor( + padding_mask, num_pad, True + ) + if self.compute_mask_indices: - B = collated_sources.size(0) - T = self._get_mask_indices_dims(collated_sources.size(-1)) + B = input['source'].size(0) + T = self._get_mask_indices_dims(input['source'].size(-1)) + padding_mask_reshaped = input['padding_mask'].clone() + extra = padding_mask_reshaped.size(1) % T + if extra > 0: + padding_mask_reshaped = padding_mask_reshaped[:, :-extra] + padding_mask_reshaped = padding_mask_reshaped.view( + padding_mask_reshaped.size(0), T, -1 + ) + padding_mask_reshaped = padding_mask_reshaped.all(-1) mask_indices, mask_channel_indices = self._compute_mask_indices( - (B, T, self._C), padding_mask + (B, T, self._C), padding_mask_reshaped, ) input["mask_indices"] = mask_indices input["mask_channel_indices"] = mask_channel_indices + return { "id": torch.LongTensor([s["id"] for s in samples]), "net_input": input, @@ -213,6 +237,7 @@ def __init__( normalize=False, compute_mask_indices=False, args=None, + num_buckets=0, ): super().__init__( sample_rate=sample_rate, @@ -240,8 +265,22 @@ def __init__( continue self.fnames.append(items[0]) self.sizes.append(sz) + self.set_bucket_info(num_buckets) logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples") + def set_bucket_info(self, num_buckets): + self.num_buckets = num_buckets + if self.num_buckets > 0: + self._collated_sizes = np.minimum( + np.array(self.sizes), self.max_sample_size, + ) + self.buckets = get_buckets( + self._collated_sizes, self.num_buckets, + ) + self._bucketed_sizes = get_bucketed_sizes( + self._collated_sizes, self.buckets + ) + def __getitem__(self, index): import soundfile as sf @@ -250,5 +289,3 @@ def __getitem__(self, index): feats = torch.from_numpy(wav).float() feats = self.postprocess(feats, curr_sample_rate) return {"id": index, "source": feats} - - diff --git a/fairseq/data/bucket_pad_length_dataset.py b/fairseq/data/bucket_pad_length_dataset.py index 27fcca3d52..e4ce3c40ef 100644 --- a/fairseq/data/bucket_pad_length_dataset.py +++ b/fairseq/data/bucket_pad_length_dataset.py @@ -7,6 +7,7 @@ import torch.nn.functional as F from fairseq.data import BaseWrapperDataset +from fairseq.data.data_utils import get_buckets, get_bucketed_sizes class BucketPadLengthDataset(BaseWrapperDataset): @@ -37,24 +38,7 @@ def __init__( self.left_pad = left_pad assert num_buckets > 0 - self.buckets = np.unique( - np.percentile( - sizes, - np.linspace(0, 100, num_buckets + 1), - interpolation='lower', - )[1:] - ) - - def get_bucketed_sizes(orig_sizes, buckets): - sizes = np.copy(orig_sizes) - assert np.min(sizes) >= 0 - start_val = -1 - for end_val in buckets: - mask = (sizes > start_val) & (sizes <= end_val) - sizes[mask] = end_val - start_val = end_val - return sizes - + self.buckets = get_buckets(sizes, num_buckets) self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets) self._tensor_key = tensor_key diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index c9f66e17a4..36d785eb6e 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -407,20 +407,25 @@ def arrange(s, e, length, keep_length): mask[i, mask_idc] = True return mask - # FIXME: taylan remove this - """ - left_mask, right_mask = [], [] - for i, mask_idc in enumerate(mask_idcs): - if len(mask_idc) > min_len: - mask_idc = np.random.choice(mask_idc, min_len, replace=False) - mask[i, mask_idc] = True - for idc in np.sort(mask_idc): - l_mask = [False] * bsz - l_mask[i] = True - r_mask = [False] * all_sz - r_mask[idc] = True - left_mask.append(l_mask) - right_mask.append(r_mask) - - return mask, np.array(left_mask), np.array(right_mask) - """ + + +def get_buckets(sizes, num_buckets): + buckets = np.unique( + np.percentile( + sizes, + np.linspace(0, 100, num_buckets + 1), + interpolation='lower', + )[1:] + ) + return buckets + + +def get_bucketed_sizes(orig_sizes, buckets): + sizes = np.copy(orig_sizes) + assert np.min(sizes) >= 0 + start_val = -1 + for end_val in buckets: + mask = (sizes > start_val) & (sizes <= end_val) + sizes[mask] = end_val + start_val = end_val + return sizes diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index ba319a3c85..300234bc65 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -133,6 +133,7 @@ def log_scalar( agg.add_meter(key, AverageMeter(round=round), priority) agg[key].update(value, weight) + def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20): """Log a scalar value derived from other meters. @@ -290,3 +291,11 @@ def load_state_dict(state_dict): for name, agg_state in state_dict.items(): _aggregators[name] = MetersDict() _aggregators[name].load_state_dict(agg_state) + + +def xla_metrics_report(): + try: + import torch_xla.debug.metrics as met + print(met.metrics_report()) + except ImportError: + return diff --git a/fairseq/metsumm.py b/fairseq/metsumm.py index 9d83e9c016..ce0cfa0594 100644 --- a/fairseq/metsumm.py +++ b/fairseq/metsumm.py @@ -1,3 +1,4 @@ +# FIXME: remove this file def metsumm(stepno=''): if hasattr(metsumm, 'STEPNO'): metsumm.STEPNO += stepno.lower()=="before forward" diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 16d5b9eba8..e5f6b4bd01 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -520,12 +520,10 @@ def compute_preds(self, x, y, negatives): if logits.device.type == 'xla' or neg_is_pos.any(): #pass - fillval = float(2**30) + fillval = -float(2**30) if not hasattr(self, '_inftensor'): self._inftensor = torch.tensor(fillval).to(x.device) logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor) - #logits[1:][neg_is_pos] = float("-inf") - #logits[1:] = index_put(logits[1:], neg_is_pos, fillval) return logits diff --git a/fairseq/tasks/audio_pretraining.py b/fairseq/tasks/audio_pretraining.py index b00acd6bad..2683e70578 100644 --- a/fairseq/tasks/audio_pretraining.py +++ b/fairseq/tasks/audio_pretraining.py @@ -115,17 +115,8 @@ def load_dataset(self, split, **kwargs): normalize=self.args.normalize, compute_mask_indices=self.args.tpu, args=self.args, + num_buckets=self.args.num_batch_buckets or int(self.args.tpu), ) - if self.args.num_batch_buckets > 0 or self.args.tpu: - # Always bucket for tpus. - self.datasets[split] = BucketPadLengthDataset( - self.datasets[split], - sizes=self.datasets[split].sizes, - num_buckets=self.args.num_batch_buckets or 1, - pad_idx=0, - left_pad=False, - tensor_key='source', - ) if self.args.labels: dict_path = os.path.join(self.args.data, f"dict.{self.args.labels}.txt") diff --git a/fairseq/utils.py b/fairseq/utils.py index a74d562365..94a539de0f 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -81,8 +81,7 @@ def _move_to_cuda(tensor): def move_to_cpu(sample): - # FIXME: taylan remove - raise + def _move_to_cpu(tensor): # PyTorch has poor support for half tensors (float16) on CPU. # Move any such tensors to float32. diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 3c45c4c64e..9346eabf43 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -208,12 +208,6 @@ def train(args, trainer, task, epoch_itr): should_stop = False num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): - # FIXME: delete these in the end - #print('SHAPE', i, samples[0]['net_input']['source'].shape) - #continue - if not i % 10: - import torch_xla.debug.metrics as met - print(met.metrics_report()) with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i @@ -232,6 +226,8 @@ def train(args, trainer, task, epoch_itr): # the end-of-epoch stats will still be preserved # FIXME: taylan reset in closure!!!!!!!!!!!!!!!!! metrics.reset_meters("train_inner") + if args.xla_metrics_debug: + metrics.xla_metrics_report() end_of_epoch = not itr.has_next() valid_losses, should_stop = validate_and_save( From 7cd1be04db90dfebb3c7b25d3954585479377e8a Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Wed, 21 Oct 2020 22:21:52 +0000 Subject: [PATCH 14/19] Sample size computeation during data prep to reduce atens, dont call item in log_scalar, minor cleanups --- fairseq/criterions/wav2vec_criterion.py | 10 ++++------ fairseq/data/audio/raw_audio_dataset.py | 8 ++++---- fairseq/logging/metrics.py | 2 -- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 4e0bdb1917..80400492ea 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -58,7 +58,9 @@ def forward(self, model, sample, reduce=True, log_pred=False): else: loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",) - if 'mask_indices' in sample['net_input'] and self.infonce: + if 'sample_size' in sample and self.infonce: + sample_size = sample['sample_size'] + elif 'mask_indices' in sample['net_input'] and self.infonce: # XXX: what happens if not self.infonce? sample_size = sample['net_input']['mask_indices'].sum() else: @@ -80,7 +82,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): losses.append(p) logging_output = { - #'loss': losr.item() if reduce else loss, + #'loss': loss.item() if reduce else loss, 'loss': loss.detach(), 'ntokens': sample_size, 'nsentences': sample['id'].numel(), @@ -152,10 +154,6 @@ def reduce_metrics(logging_outputs) -> None: if k.startswith('loss'): metrics.log_scalar(k, val / sample_size / math.log(2), sample_size) else: - # FIXME: taylan, round=3 could be a problem - # XXX: we dont hit this in this workload - import pdb - pdb.set_trace() metrics.log_scalar(k, val, round=3) @staticmethod diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 0fbfda7a33..3eac4b040a 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -154,6 +154,7 @@ def collater(self, samples): collated_sources[i] = self.crop_to_max_size(source, target_size) input = {"source": collated_sources} + out = {"id": torch.LongTensor([s["id"] for s in samples])} if self.pad: input["padding_mask"] = padding_mask @@ -185,11 +186,10 @@ def collater(self, samples): ) input["mask_indices"] = mask_indices input["mask_channel_indices"] = mask_channel_indices + out['sample_size'] = mask_indices.sum().item() - return { - "id": torch.LongTensor([s["id"] for s in samples]), - "net_input": input, - } + out["net_input"] = input + return out def _get_mask_indices_dims(self, size, padding=0, dilation=1): if size not in self._features_size_map: diff --git a/fairseq/logging/metrics.py b/fairseq/logging/metrics.py index 300234bc65..0bac4c25b2 100644 --- a/fairseq/logging/metrics.py +++ b/fairseq/logging/metrics.py @@ -126,8 +126,6 @@ def log_scalar( priority (int): smaller values are logged earlier in the output round (Optional[int]): number of digits to round to when displaying """ - if torch.is_tensor(value) and value.device.type == 'xla': - value = value.item() for agg in get_active_aggregators(): if key not in agg: agg.add_meter(key, AverageMeter(round=round), priority) From 9b986637f0e3055e3a5b4521cea6cdee6c82d8ef Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Thu, 22 Oct 2020 20:03:45 +0000 Subject: [PATCH 15/19] Remove extra validation atens, clean up marking step and sending to cpu. --- fairseq/data/audio/raw_audio_dataset.py | 4 ++++ fairseq/trainer.py | 26 +++++++++++++++---------- 2 files changed, 20 insertions(+), 10 deletions(-) diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index 3eac4b040a..baafe0f961 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -280,6 +280,10 @@ def set_bucket_info(self, num_buckets): self._bucketed_sizes = get_bucketed_sizes( self._collated_sizes, self.buckets ) + logger.info( + f"{len(self.buckets)} bucket(s) for the audio dataset: " + f"{self.buckets}" + ) def __getitem__(self, index): import soundfile as sf diff --git a/fairseq/trainer.py b/fairseq/trainer.py index 15510f1040..e46cb88cf7 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -463,8 +463,7 @@ def maybe_no_sync(): # before marking step can lead to OOM errors. # To handle gradient accumulation use case, we explicitly # mark step here for every forward pass without a backward pass - import torch_xla.core.xla_model as xm - xm.mark_step() + self._xla_markstep_and_send_to_cpu() if is_dummy_batch: if torch.is_tensor(sample_size): @@ -555,15 +554,14 @@ def maybe_no_sync(): # this causes wps to be misreported when log_interval > 1 logging_output = {} if self.get_num_updates() % self.args.log_interval == 0: - import torch_xla.core.xla_model as xm - from fairseq.utils import xla_device_to_cpu - xm.mark_step() - logging_outputs = xla_device_to_cpu(logging_outputs) + logging_outputs = self._xla_markstep_and_send_to_cpu( + logging_outputs + ) logging_output = self._reduce_and_log_stats( logging_outputs, sample_size, grad_norm, ) - xm.mark_step() - # XXX: when I put step closure, logging outputs is shrunk.. + self._xla_markstep_and_send_to_cpu() + # FIXME: taylan when I put step closure, logging outputs is shrunk.. #xm.add_step_closure( # self._reduce_and_log_stats, # args=(logging_outputs, sample_size, grad_norm) @@ -603,8 +601,7 @@ def valid_step(self, sample, raise_oom=False): if self._dummy_batch == "DUMMY": self._dummy_batch = sample if self.tpu: - import torch_xla.core.xla_model as xm - xm.mark_step() + self._xla_markstep_and_send_to_cpu() with torch.no_grad(): self.model.eval() @@ -650,6 +647,8 @@ def valid_step(self, sample, raise_oom=False): ) # log validation stats + if self.tpu: + logging_outputs = self._xla_markstep_and_send_to_cpu(logging_outputs) logging_output = self._reduce_and_log_stats(logging_outputs, sample_size) return logging_output @@ -982,6 +981,13 @@ def _check_xla_compilation(self, message=None): logging.info("NOTE: XLA compilation detected; {}".format(message)) self._num_xla_compiles = num_xla_compiles + def _xla_markstep_and_send_to_cpu(self, data=None): + import torch_xla.core.xla_model as xm + xm.mark_step() + if data is not None: + from fairseq.utils import xla_device_to_cpu + return xla_device_to_cpu(data) + def _catalog_shared_params(module, memo=None, prefix=''): if memo is None: From 91cdca262b302486816ff3201f5ddc134a644f25 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Fri, 23 Oct 2020 22:15:25 +0000 Subject: [PATCH 16/19] Correct loss computation for w2v2 criterion + refactor index_put --- fairseq/criterions/wav2vec_criterion.py | 20 ++++++++++++++++++-- fairseq/data/data_utils.py | 12 ------------ fairseq/models/wav2vec/wav2vec2.py | 4 ++-- fairseq/utils.py | 12 ++++++++++++ 4 files changed, 32 insertions(+), 16 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 80400492ea..761d38c2ae 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -10,6 +10,7 @@ from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion +from fairseq.utils import index_put @register_criterion('wav2vec') @@ -45,6 +46,14 @@ def forward(self, model, sample, reduce=True, log_pred=False): logits = model.get_logits(net_output).float() target = model.get_targets(sample, net_output) + if logits.device.type == 'xla': + # tpu-comment: since dynamic shapes lead to recompilations on xla, + # we don't shrink tensors using mask_indices. + # Instead, we do the following when computing loss: + mi = sample['net_input']['mask_indices'].reshape(logits.size(0)) + target = index_put(target, ~mi, -1) + + # XXX: handle weights on xla. weights = None if hasattr(model, 'get_target_weights') and not self.infonce: weights = model.get_target_weights(target, net_output) @@ -54,9 +63,16 @@ def forward(self, model, sample, reduce=True, log_pred=False): losses = [] if self.infonce: - loss = F.cross_entropy(logits, target, reduction="sum" if reduce else "none",) + loss = F.cross_entropy( + logits, target, reduction="sum" if reduce else "none", + ignore_index=-1, + ) else: - loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",) + loss = F.binary_cross_entropy_with_logits( + logits, target.float(), weights, + reduction="sum" if reduce else "none", + ignore_index=-1, + ) if 'sample_size' in sample and self.infonce: sample_size = sample['sample_size'] diff --git a/fairseq/data/data_utils.py b/fairseq/data/data_utils.py index 36d785eb6e..e6818219cb 100644 --- a/fairseq/data/data_utils.py +++ b/fairseq/data/data_utils.py @@ -273,18 +273,6 @@ def post_process(sentence: str, symbol: str): return sentence -def index_put(tensor, indices, value): - if tensor.device.type == 'xla': - for _ in range(indices.dim(), tensor.dim()): - indices = indices.unsqueeze(-1) - if indices.size(-1) < tensor.size(-1): - indices = indices.expand_as(tensor) - tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) - else: - tensor[indices] = value - return tensor - - def compute_mask_indices( shape: Tuple[int, int], padding_mask: Optional[torch.Tensor], diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index e5f6b4bd01..21239d8bb0 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -14,7 +14,7 @@ from typing import List, Tuple from fairseq import utils -from fairseq.data.data_utils import compute_mask_indices, index_put +from fairseq.data.data_utils import compute_mask_indices from fairseq.models import BaseFairseqModel, register_model, register_model_architecture from fairseq.modules import ( Fp32GroupNorm, @@ -27,7 +27,7 @@ TransposeLast, ) from fairseq.modules.transformer_sentence_encoder import init_bert_params -from fairseq.utils import buffered_arange +from fairseq.utils import buffered_arange, index_put @register_model("wav2vec2") diff --git a/fairseq/utils.py b/fairseq/utils.py index 94a539de0f..063df4a6ba 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -564,6 +564,18 @@ def get_tpu_device(args): return xm.xla_device() +def index_put(tensor, indices, value): + if tensor.device.type != 'xla': + for _ in range(indices.dim(), tensor.dim()): + indices = indices.unsqueeze(-1) + if indices.size(-1) < tensor.size(-1): + indices = indices.expand_as(tensor) + tensor = torch.mul(tensor, ~indices) + torch.mul(value, indices) + else: + tensor[indices] = value + return tensor + + def xla_device_to_cpu(dat): import torch_xla.core.xla_model as xm return xm._maybe_convert_to_cpu(dat) From 2c59cecf654e12818988c31d1f82c0fc6050b355 Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Mon, 26 Oct 2020 23:54:19 +0000 Subject: [PATCH 17/19] Fix bug in index_put + fix integer division --- fairseq/criterions/wav2vec_criterion.py | 2 +- fairseq/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 761d38c2ae..29f54a0fb0 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -125,7 +125,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): both = max & min # corr = max.long().sum().item() - both.long().sum().item() corr = max.long().sum() - both.long().sum() - count = max.numel() + count = float(max.numel()) logging_output["correct"] = corr logging_output["count"] = count diff --git a/fairseq/utils.py b/fairseq/utils.py index 063df4a6ba..7e8db9f647 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -565,7 +565,7 @@ def get_tpu_device(args): def index_put(tensor, indices, value): - if tensor.device.type != 'xla': + if tensor.device.type == 'xla': for _ in range(indices.dim(), tensor.dim()): indices = indices.unsqueeze(-1) if indices.size(-1) < tensor.size(-1): From 55656228259db242723d040e0fc7ec986d08565d Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Tue, 27 Oct 2020 22:32:27 +0000 Subject: [PATCH 18/19] Dont call float on extra logs, clean up comment. --- fairseq/criterions/wav2vec_criterion.py | 5 ++++- fairseq/trainer.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index 29f54a0fb0..c316147fe6 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -107,7 +107,10 @@ def forward(self, model, sample, reduce=True, log_pred=False): for lk in self.log_keys: if lk in net_output: - logging_output[lk] = float((net_output[lk])) + value = net_output[lk] + if not torch.is_tensor(value) or value.device.type != 'xla': + value = float(value) + logging_output[lk] = value if len(losses) > 1: for i, l in enumerate(losses): diff --git a/fairseq/trainer.py b/fairseq/trainer.py index e46cb88cf7..37d578e332 100644 --- a/fairseq/trainer.py +++ b/fairseq/trainer.py @@ -478,7 +478,6 @@ def maybe_no_sync(): # gather logging outputs from all replicas if self._sync_stats(): - # FIXME: taylan this is not hit in 1 core. revisit when running 8 cores train_time = self._local_cumulative_training_time() logging_outputs, (sample_size, ooms, total_train_time) = self._aggregate_logging_outputs( logging_outputs, sample_size, ooms, train_time, ignore=is_dummy_batch, From 45bde79fe9d87e1a628fcfc32e36e579a233711f Mon Sep 17 00:00:00 2001 From: Taylan Bilal Date: Mon, 2 Nov 2020 18:25:39 +0000 Subject: [PATCH 19/19] Conv debug attempts. --- fairseq/criterions/wav2vec_criterion.py | 24 ++++---- fairseq/data/audio/raw_audio_dataset.py | 79 +++++++++++++++++++++++++ fairseq/models/wav2vec/wav2vec2.py | 69 +++++++++++++++------ fairseq/utils.py | 6 +- fairseq_cli/train.py | 5 ++ 5 files changed, 154 insertions(+), 29 deletions(-) diff --git a/fairseq/criterions/wav2vec_criterion.py b/fairseq/criterions/wav2vec_criterion.py index c316147fe6..364c716300 100644 --- a/fairseq/criterions/wav2vec_criterion.py +++ b/fairseq/criterions/wav2vec_criterion.py @@ -10,7 +10,7 @@ from fairseq import metrics, utils from fairseq.criterions import FairseqCriterion, register_criterion -from fairseq.utils import index_put +from fairseq.utils import index_put, is_xla_tensor @register_criterion('wav2vec') @@ -46,7 +46,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): logits = model.get_logits(net_output).float() target = model.get_targets(sample, net_output) - if logits.device.type == 'xla': + if is_xla_tensor(logits): # tpu-comment: since dynamic shapes lead to recompilations on xla, # we don't shrink tensors using mask_indices. # Instead, we do the following when computing loss: @@ -77,9 +77,9 @@ def forward(self, model, sample, reduce=True, log_pred=False): if 'sample_size' in sample and self.infonce: sample_size = sample['sample_size'] elif 'mask_indices' in sample['net_input'] and self.infonce: - # XXX: what happens if not self.infonce? sample_size = sample['net_input']['mask_indices'].sum() else: + # XXX: if not self.infonce, is xla path working correctly? sample_size = target.numel() if self.infonce else target.long().sum() losses.append(loss) @@ -98,7 +98,6 @@ def forward(self, model, sample, reduce=True, log_pred=False): losses.append(p) logging_output = { - #'loss': loss.item() if reduce else loss, 'loss': loss.detach(), 'ntokens': sample_size, 'nsentences': sample['id'].numel(), @@ -108,7 +107,7 @@ def forward(self, model, sample, reduce=True, log_pred=False): for lk in self.log_keys: if lk in net_output: value = net_output[lk] - if not torch.is_tensor(value) or value.device.type != 'xla': + if not is_xla_tensor(value): value = float(value) logging_output[lk] = value @@ -125,15 +124,20 @@ def forward(self, model, sample, reduce=True, log_pred=False): assert logits.dim() > 1, logits.shape max = logits.argmax(-1) == 0 min = logits.argmin(-1) == 0 - both = max & min - # corr = max.long().sum().item() - both.long().sum().item() - corr = max.long().sum() - both.long().sum() - count = float(max.numel()) + if is_xla_tensor(logits): + max, min = max * mi, min * mi + both = max & min + corr = max.long().sum() - both.long().sum() + count = mi.sum() + else: + both = max & min + corr = max.long().sum().item() - both.long().sum().item() + count = float(max.numel()) logging_output["correct"] = corr logging_output["count"] = count - if log_pred and logits.device.type != 'xla': + if log_pred and not is_xla_tensor(logits): logging_output['logits'] = logits.cpu().numpy() logging_output['target'] = target.cpu().numpy() return loss, sample_size, logging_output diff --git a/fairseq/data/audio/raw_audio_dataset.py b/fairseq/data/audio/raw_audio_dataset.py index baafe0f961..4d9b60c372 100644 --- a/fairseq/data/audio/raw_audio_dataset.py +++ b/fairseq/data/audio/raw_audio_dataset.py @@ -185,12 +185,91 @@ def collater(self, samples): (B, T, self._C), padding_mask_reshaped, ) input["mask_indices"] = mask_indices + import pdb + pdb.set_trace() + # FIXME: implement + #input["neg_idxs"] = self._get_neg_idxs(mask_indices) + input["tszs_after_mask"] = mask_indices.sum(-1).tolist() input["mask_channel_indices"] = mask_channel_indices out['sample_size'] = mask_indices.sum().item() out["net_input"] = input return out + def _get_neg_idxs(self, mask_indices): + return + + def _sample_negatives(self, mask_indices): + """ + Sampling negatives during model's forward is problematic on XLA. + That's why we do it here during data prep when run on XLA. + """ + self.n_negatives = self.args.num_negatives + self.cross_sample_negatives = self.args.cross_sample_negatives + if self.n_negatives == 0 and self.cross_sample_negatives == 0: + # FIXME: handle this + return y.new(0) + + (bsz, tsz), fsz = mask_indices.shape, self.args.final_dim + high = mask_indices.sum(-1).max().item() + cross_high = high * bsz + + with torch.no_grad(): + assert high > 1, f"{bsz,tsz,fsz}" + + if self.n_negatives > 0: + tszs = ( + torch.arange(tsz) + .unsqueeze(-1) + .expand(-1, self.n_negatives) + .flatten() + ) + + ts = torch.arange(tsz) + + neg_idxs = torch.randint( + low=0, high=high-1, size=(bsz, tsz, self.n_negatives) + ) + neg_idxs = torch.stack([ + ts[mask_indices[j]][ni] for j, ni in enumerate(neg_idxs) + ]).reshape(bsz, -1) + neg_idxs[neg_idxs >= tszs] += 1 + import pdb + pdb.set_trace() + + if self.cross_sample_negatives > 0: + raise NotImplementedError('Implement for XLA.') + tszs = ( + torch.arange(num) + .unsqueeze(-1) + .expand(-1, self.cross_sample_negatives) + .flatten() + ) + + cross_neg_idxs = torch.randint( + low=0, + high=cross_high - 1, + size=(bsz, self.cross_sample_negatives * num), + ) + cross_neg_idxs[cross_neg_idxs >= tszs] += 1 + + if self.n_negatives > 0: + for i in range(1, bsz): + neg_idxs[i] += i * high + else: + neg_idxs = cross_neg_idxs + + if self.cross_sample_negatives > 0 and self.n_negatives > 0: + neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + + negs = y[neg_idxs.view(-1)] + negs = negs.view( + bsz, tsz, self.n_negatives + self.cross_sample_negatives, fsz + ).permute( + 2, 0, 1, 3 + ) # to NxBxTxC + return negs, neg_idxs + def _get_mask_indices_dims(self, size, padding=0, dilation=1): if size not in self._features_size_map: L_in = size diff --git a/fairseq/models/wav2vec/wav2vec2.py b/fairseq/models/wav2vec/wav2vec2.py index 21239d8bb0..5196797df2 100644 --- a/fairseq/models/wav2vec/wav2vec2.py +++ b/fairseq/models/wav2vec/wav2vec2.py @@ -27,7 +27,7 @@ TransposeLast, ) from fairseq.modules.transformer_sentence_encoder import init_bert_params -from fairseq.utils import buffered_arange, index_put +from fairseq.utils import buffered_arange, index_put, is_xla_tensor @register_model("wav2vec2") @@ -450,7 +450,21 @@ def apply_mask( return x, mask_indices - def sample_negatives(self, y, num): + def _get_neg_idxs(self, high, size, padding_counts=None, num=None): + if padding_counts is None: + neg_idxs = torch.randint(low=0, high=high-1, size=size) + else: + bsz, l = size + #num = l // self.n_negatives if num is None else num + assert len(padding_counts) == bsz + neg_idxs = [ + torch.randint(low=0, high=high-1, size=(1, l)) + for pc in padding_counts + ] + neg_idxs = torch.stack(neg_idxs) + return neg_idxs + + def sample_negatives(self, y, num, padding_counts=None): if self.n_negatives == 0 and self.cross_sample_negatives == 0: return y.new(0) @@ -459,21 +473,23 @@ def sample_negatives(self, y, num): y = y.view(-1, fsz) # BTC => (BxT)C cross_high = tsz * bsz - high = tsz + high = num + # FIXME: there's a problem here w/ tsz and num + # this assumes y is shrunk at this point. with torch.no_grad(): assert high > 1, f"{bsz,tsz,fsz}" if self.n_negatives > 0: tszs = ( - buffered_arange(num) + buffered_arange(tsz) .unsqueeze(-1) .expand(-1, self.n_negatives) .flatten() ) - neg_idxs = torch.randint( - low=0, high=high - 1, size=(bsz, self.n_negatives * num) - ) + neg_idxs = torch.randint(low=0, high=high-1, size=(bsz, self.n_negatives * tsz)) + import pdb + pdb.set_trace() neg_idxs[neg_idxs >= tszs] += 1 if self.cross_sample_negatives > 0: @@ -500,9 +516,13 @@ def sample_negatives(self, y, num): if self.cross_sample_negatives > 0 and self.n_negatives > 0: neg_idxs = torch.cat([neg_idxs, cross_neg_idxs], dim=1) + import pdb + pdb.set_trace() negs = y[neg_idxs.view(-1)] + import pdb + pdb.set_trace() negs = negs.view( - bsz, num, self.n_negatives + self.cross_sample_negatives, fsz + bsz, tsz, self.n_negatives + self.cross_sample_negatives, fsz ).permute( 2, 0, 1, 3 ) # to NxBxTxC @@ -518,8 +538,7 @@ def compute_preds(self, x, y, negatives): logits = logits / self.logit_temp - if logits.device.type == 'xla' or neg_is_pos.any(): - #pass + if is_xla_tensor(logits) or neg_is_pos.any(): fillval = -float(2**30) if not hasattr(self, '_inftensor'): self._inftensor = torch.tensor(fillval).to(x.device) @@ -529,7 +548,8 @@ def compute_preds(self, x, y, negatives): def forward( self, source, padding_mask=None, mask=True, features_only=False, - mask_indices=None, mask_channel_indices=None, + mask_indices=None, mask_channel_indices=None, padding_counts=None, + tszs_after_mask=None, ): if self.feature_grad_mult > 0: @@ -579,7 +599,7 @@ def forward( mask_indices=mask_indices, mask_channel_indices=mask_channel_indices, ) - if x.device.type != 'xla' and mask_indices is not None: + if not is_xla_tensor(x) and mask_indices is not None: # tpu-comment: reducing the size in a dynamic way causes # too many recompilations on xla. y = unmasked_features[mask_indices].view( @@ -607,13 +627,21 @@ def forward( y = self.project_q(y) + num = y.size(1) if tszs_after_mask is None else max(tszs_after_mask) if self.negatives_from_everywhere: - neg_cands, *_ = self.quantizer(unmasked_features, produce_targets=False) - negs, _ = self.sample_negatives(neg_cands, y.size(1)) + neg_cands, *_ = self.quantizer( + unmasked_features, produce_targets=False, + ) + negs, _ = self.sample_negatives( + neg_cands, num, padding_counts=padding_counts, + ) negs = self.project_q(negs) else: - negs, _ = self.sample_negatives(y, y.size(1)) + negs, _ = self.sample_negatives( + y, num, + padding_counts=padding_counts, + ) if self.codebook_negatives > 0: cb_negs = self.quantizer.sample_from_codebook( @@ -628,12 +656,17 @@ def forward( y = self.project_q(y) if self.negatives_from_everywhere: - negs, _ = self.sample_negatives(unmasked_features, y.size(1)) + negs, _ = self.sample_negatives( + unmasked_features, num, padding_counts=padding_counts, + ) negs = self.project_q(negs) else: - negs, _ = self.sample_negatives(y, y.size(1)) + negs, _ = self.sample_negatives( + y, num, + padding_counts=padding_counts, + ) - if x.device.type != 'xla': + if not is_xla_tensor(x): # tpu-comment: reducing the size in a dynamic way causes # too many recompilations on xla. x = x[mask_indices].view(x.size(0), -1, x.size(-1)) diff --git a/fairseq/utils.py b/fairseq/utils.py index 7e8db9f647..1af2e2537e 100644 --- a/fairseq/utils.py +++ b/fairseq/utils.py @@ -564,8 +564,12 @@ def get_tpu_device(args): return xm.xla_device() +def is_xla_tensor(tensor): + return torch.is_tensor(tensor) and tensor.device.type == 'xla' + + def index_put(tensor, indices, value): - if tensor.device.type == 'xla': + if is_xla_tensor(tensor): for _ in range(indices.dim(), tensor.dim()): indices = indices.unsqueeze(-1) if indices.size(-1) < tensor.size(-1): diff --git a/fairseq_cli/train.py b/fairseq_cli/train.py index 9346eabf43..832f474678 100644 --- a/fairseq_cli/train.py +++ b/fairseq_cli/train.py @@ -209,6 +209,11 @@ def train(args, trainer, task, epoch_itr): num_updates = trainer.get_num_updates() for i, samples in enumerate(progress): + from fairseq.metsumm import metsumm as m + if not i % 50: + import torch_xla.core.xla_model as xm + if xm.is_master_ordinal(): + m(str(i)) with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function( "train_step-%d" % i ):