Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
7e6f3cf
Intermediate experiments wav2vec
ultrons Aug 18, 2020
6a03e5c
input shape temp update
Aug 20, 2020
433ca76
clean up
Aug 20, 2020
19535a2
dataset updates
ultrons Aug 20, 2020
a500587
clean up
ultrons Aug 20, 2020
b49f103
move tensor idx to matrix op inside apply_mask
kevinmtian Sep 2, 2020
73e2f3b
use tensor operators to replace tensor indexing, passed consistency t…
kevinmtian Sep 3, 2020
8802a31
Minor improvements
taylanbil Oct 6, 2020
dc60592
Fix bucketpadlendataset
taylanbil Oct 7, 2020
be3ca6a
Moved mask matrices creation to dataset prep.
taylanbil Oct 13, 2020
30f3761
Remove dynamism, apply mask correctly, add some guardrails, some clea…
taylanbil Oct 15, 2020
af6d4d6
Send device data to cpu before logging.
taylanbil Oct 19, 2020
9ce9909
Fix data bucketing for RawAudioDataset, refactor bucketing functions,…
taylanbil Oct 21, 2020
7cd1be0
Sample size computation during data prep to reduce atens, dont call …
taylanbil Oct 21, 2020
9b98663
Remove extra validation atens, clean up marking step and sending to cpu.
taylanbil Oct 22, 2020
91cdca2
Correct loss computation for w2v2 criterion + refactor index_put
taylanbil Oct 23, 2020
2c59cec
Fix bug in index_put + fix integer division
taylanbil Oct 26, 2020
5565622
Dont call float on extra logs, clean up comment.
taylanbil Oct 27, 2020
51b8ba8
Correct accuracy computation, refactor xla tensor check.
taylanbil Nov 3, 2020
25f0145
Adjust loss computation so it works w/ binary cross entropy.
taylanbil Nov 3, 2020
d844a2e
Remove sending log outputs back to cpu after allreduce.
taylanbil Nov 5, 2020
10fcd6c
Dont sample padded states when sampling negatives + correct mi in los…
taylanbil Nov 13, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions fairseq/criterions/wav2vec_criterion.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from fairseq import metrics, utils
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.utils import index_put, is_xla_tensor


@register_criterion('wav2vec')
Expand Down Expand Up @@ -41,9 +42,12 @@ def forward(self, model, sample, reduce=True, log_pred=False):
3) logging outputs to display while training
"""
net_output = model(**sample['net_input'])

logits = model.get_logits(net_output).float()
target = model.get_targets(sample, net_output)
xla = is_xla_tensor(logits)

# XXX: handle weights on xla.
weights = None
if hasattr(model, 'get_target_weights') and not self.infonce:
weights = model.get_target_weights(target, net_output)
Expand All @@ -52,12 +56,32 @@ def forward(self, model, sample, reduce=True, log_pred=False):

losses = []

reduction = "none" if ((not reduce) or xla) else "sum"
if self.infonce:
loss = F.cross_entropy(logits, target, reduction="sum" if reduce else "none",)
loss = F.cross_entropy(logits, target, reduction=reduction)
else:
loss = F.binary_cross_entropy_with_logits(logits, target.float(), weights, reduction="sum" if reduce else "none",)
loss = F.binary_cross_entropy_with_logits(
logits, target.float(), weights, reduction=reduction
)

sample_size = target.numel() if self.infonce else target.long().sum().item()
if xla:
# tpu-comment: since dynamic shapes lead to recompilations on xla,
# we don't shrink tensors using mask_indices.
# Instead, we use mask indices to adjust loss.
mi = (
sample['net_input']['mask_indices']
.transpose(0, 1) # logits are transposed in `model.get_logits`
.reshape(logits.size(0))
)
loss = (loss * mi).sum() if reduce else (loss * mi)

if 'sample_size' in sample and self.infonce:
sample_size = sample['sample_size']
elif 'mask_indices' in sample['net_input']:
sample_size = sample['net_input']['mask_indices'].sum()
else:
# XXX: if not self.infonce, is xla path working correctly?
sample_size = target.numel() if self.infonce else target.long().sum()
losses.append(loss)

if self.loss_weights is not None:
Expand All @@ -75,19 +99,22 @@ def forward(self, model, sample, reduce=True, log_pred=False):
losses.append(p)

logging_output = {
'loss': loss.item() if reduce else loss,
'loss': loss.detach(),
'ntokens': sample_size,
'nsentences': sample['id'].numel(),
'sample_size': sample_size,
}

for lk in self.log_keys:
if lk in net_output:
logging_output[lk] = float((net_output[lk]))
value = net_output[lk]
if not is_xla_tensor(value):
value = float(value)
logging_output[lk] = value

if len(losses) > 1:
for i, l in enumerate(losses):
logging_output[f'loss_{i}'] = l.item()
logging_output[f'loss_{i}'] = l.detach()

if self.infonce:
with torch.no_grad():
Expand All @@ -98,14 +125,20 @@ def forward(self, model, sample, reduce=True, log_pred=False):
assert logits.dim() > 1, logits.shape
max = logits.argmax(-1) == 0
min = logits.argmin(-1) == 0
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = max.numel()
if is_xla_tensor(logits):
max, min = max * mi, min * mi
both = max & min
corr = max.long().sum() - both.long().sum()
count = mi.sum()
else:
both = max & min
corr = max.long().sum().item() - both.long().sum().item()
count = float(max.numel())

logging_output["correct"] = corr
logging_output["count"] = count

if log_pred:
if log_pred and not is_xla_tensor(logits):
logging_output['logits'] = logits.cpu().numpy()
logging_output['target'] = target.cpu().numpy()
return loss, sample_size, logging_output
Expand All @@ -132,7 +165,7 @@ def reduce_metrics(logging_outputs) -> None:
if total > 0:
metrics.log_derived(
"accuracy",
lambda meters: round(meters["_correct"].sum / meters["_total"].sum, 5)
lambda meters: meters["_correct"].sum / meters["_total"].sum
if meters["_total"].sum > 0
else float("nan"),
)
Expand All @@ -154,4 +187,4 @@ def logging_outputs_can_be_summed() -> bool:
across workers prior to calling `reduce_metrics`. Setting this
to True will improve distributed training speed.
"""
return False
return True
121 changes: 119 additions & 2 deletions fairseq/data/audio/raw_audio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import torch
import torch.nn.functional as F

from .. import FairseqDataset
from .. import FairseqDataset, BaseWrapperDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes

logger = logging.getLogger(__name__)

Expand All @@ -27,6 +28,8 @@ def __init__(
min_length=0,
pad=False,
normalize=False,
compute_mask_indices=False,
args=None,
):
super().__init__()

Expand All @@ -40,6 +43,12 @@ def __init__(
self.pad = pad
self.shuffle = shuffle
self.normalize = normalize
self.compute_mask_indices = compute_mask_indices
if self.compute_mask_indices:
self.args = args
self._features_size_map = {}
self._C = self.args.encoder_embed_dim
self._conv_feature_layers = eval(self.args.conv_feature_layers)

def __getitem__(self, index):
raise NotImplementedError()
Expand Down Expand Up @@ -71,6 +80,45 @@ def crop_to_max_size(self, wav, target_size):
end = size - diff + start
return wav[start:end]

def _compute_mask_indices(self, dims, padding_mask):
    """Precompute time- and channel-mask index tensors for one collated batch.

    Args:
        dims: (B, T, C) tuple — batch size, number of feature frames, and
            encoder embedding dim (`self._C` at the call site in `collater`).
        padding_mask: mask over (B, T) marking padded frames, or None.

    Returns:
        Tuple (mask_indices, mask_channel_indices); either entry is None when
        the corresponding mask probability is not positive.
    """
    B, T, C = dims
    mask_indices, mask_channel_indices = None, None
    if self.args.mask_prob > 0:
        mask_indices = compute_mask_indices(
            (B, T),
            padding_mask,
            self.args.mask_prob,
            self.args.mask_length,
            self.args.mask_selection,
            self.args.mask_other,
            min_masks=2,
            no_overlap=self.args.no_mask_overlap,
            min_space=self.args.mask_min_space,
        )
        # compute_mask_indices returns a numpy array; convert so the result
        # can be collated with the other batch tensors.
        mask_indices = torch.from_numpy(mask_indices)
    if self.args.mask_channel_prob > 0:
        mask_channel_indices = compute_mask_indices(
            (B, C),
            None,  # channel masking does not depend on temporal padding
            self.args.mask_channel_prob,
            self.args.mask_channel_length,
            self.args.mask_channel_selection,
            self.args.mask_channel_other,
            no_overlap=self.args.no_mask_channel_overlap,
            min_space=self.args.mask_channel_min_space,
        )
        # Broadcast the per-channel mask across all T frames: (B, C) -> (B, T, C).
        mask_channel_indices = (
            torch.from_numpy(mask_channel_indices)
            .unsqueeze(1)
            .expand(-1, T, -1)
        )

    return mask_indices, mask_channel_indices

@staticmethod
def _bucket_tensor(tensor, num_pad, value):
return F.pad(tensor, (0, num_pad), value=value)

def collater(self, samples):
samples = [
s
Expand Down Expand Up @@ -106,9 +154,55 @@ def collater(self, samples):
collated_sources[i] = self.crop_to_max_size(source, target_size)

input = {"source": collated_sources}
out = {"id": torch.LongTensor([s["id"] for s in samples])}
if self.pad:
input["padding_mask"] = padding_mask
return {"id": torch.LongTensor([s["id"] for s in samples]), "net_input": input}

if hasattr(self, 'num_buckets') and self.num_buckets > 0:
assert self.pad, "Cannot bucket without padding first."
bucket = max(self._bucketed_sizes[s['id']] for s in samples)
num_pad = bucket - collated_sources.size(-1)
if num_pad:
input['source'] = self._bucket_tensor(
collated_sources, num_pad, 0
)
input['padding_mask'] = self._bucket_tensor(
padding_mask, num_pad, True
)

if self.compute_mask_indices:
B = input['source'].size(0)
T = self._get_mask_indices_dims(input['source'].size(-1))
padding_mask_reshaped = input['padding_mask'].clone()
extra = padding_mask_reshaped.size(1) % T
if extra > 0:
padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
padding_mask_reshaped = padding_mask_reshaped.view(
padding_mask_reshaped.size(0), T, -1
)
padding_mask_reshaped = padding_mask_reshaped.all(-1)
input['padding_count'] = (
padding_mask_reshaped.sum(-1).max().item()
)
mask_indices, mask_channel_indices = self._compute_mask_indices(
(B, T, self._C), padding_mask_reshaped,
)
input["mask_indices"] = mask_indices
input["mask_channel_indices"] = mask_channel_indices
out['sample_size'] = mask_indices.sum().item()

out["net_input"] = input
return out

def _get_mask_indices_dims(self, size, padding=0, dilation=1):
if size not in self._features_size_map:
L_in = size
for (_, kernel_size, stride) in self._conv_feature_layers:
L_out = L_in + 2*padding - dilation*(kernel_size-1) - 1
L_out = 1 + L_out // stride
L_in = L_out
self._features_size_map[size] = L_out
return self._features_size_map[size]

def num_tokens(self, index):
    """Token count used for batching: delegates to the sample's size."""
    return self.size(index)
Expand Down Expand Up @@ -144,6 +238,9 @@ def __init__(
min_length=0,
pad=False,
normalize=False,
compute_mask_indices=False,
args=None,
num_buckets=0,
):
super().__init__(
sample_rate=sample_rate,
Expand All @@ -153,6 +250,8 @@ def __init__(
min_length=min_length,
pad=pad,
normalize=normalize,
compute_mask_indices=compute_mask_indices,
args=args,
)

self.fnames = []
Expand All @@ -169,8 +268,26 @@ def __init__(
continue
self.fnames.append(items[0])
self.sizes.append(sz)
self.set_bucket_info(num_buckets)
logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

def set_bucket_info(self, num_buckets):
    """Configure length bucketing for the dataset.

    Caps every sample size at `self.max_sample_size`, derives the bucket
    boundaries with `get_buckets`, and records each sample's bucketed
    (rounded-up) size. A no-op beyond storing `num_buckets` when
    `num_buckets <= 0`.
    """
    self.num_buckets = num_buckets
    if num_buckets <= 0:
        return
    capped_sizes = np.minimum(np.array(self.sizes), self.max_sample_size)
    self._collated_sizes = capped_sizes
    self.buckets = get_buckets(capped_sizes, num_buckets)
    self._bucketed_sizes = get_bucketed_sizes(capped_sizes, self.buckets)
    logger.info(
        f"{len(self.buckets)} bucket(s) for the audio dataset: "
        f"{self.buckets}"
    )

def __getitem__(self, index):
import soundfile as sf

Expand Down
46 changes: 24 additions & 22 deletions fairseq/data/bucket_pad_length_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import torch.nn.functional as F

from fairseq.data import BaseWrapperDataset
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes


class BucketPadLengthDataset(BaseWrapperDataset):
Expand All @@ -30,42 +31,43 @@ def __init__(
num_buckets,
pad_idx,
left_pad,
tensor_key=None,
):
super().__init__(dataset)
self.pad_idx = pad_idx
self.left_pad = left_pad

assert num_buckets > 0
self.buckets = np.unique(
np.percentile(
sizes,
np.linspace(0, 100, num_buckets + 1),
interpolation='lower',
)[1:]
)
self.buckets = get_buckets(sizes, num_buckets)
self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
self._tensor_key = tensor_key

def get_bucketed_sizes(orig_sizes, buckets):
sizes = np.copy(orig_sizes)
assert np.min(sizes) >= 0
start_val = -1
for end_val in buckets:
mask = (sizes > start_val) & (sizes <= end_val)
sizes[mask] = end_val
start_val = end_val
return sizes
def _set_tensor(self, item, val):
if self._tensor_key is None:
return val
item[self._tensor_key] = val
return item

self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
def _get_tensor(self, item):
if self._tensor_key is None:
return item
return item[self._tensor_key]

def __getitem__(self, index):
item = self.dataset[index]
bucket_size = self._bucketed_sizes[index]
num_pad = bucket_size - item.size(-1)
def _pad(self, tensor, bucket_size, dim=-1):
num_pad = bucket_size - tensor.size(dim)
return F.pad(
item,
tensor,
(num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
value=self.pad_idx,
)

def __getitem__(self, index):
item = self.dataset[index]
bucket_size = self._bucketed_sizes[index]
tensor = self._get_tensor(item)
padded = self._pad(tensor, bucket_size)
return self._set_tensor(item, padded)

@property
def sizes(self):
    # Report the bucketed (post-padding) lengths so batching logic groups
    # items by their padded size rather than the raw size.
    return self._bucketed_sizes
Expand Down
Loading