From f7b3dfa839efb361b767a10e177064b607563362 Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Sat, 13 Dec 2025 14:39:10 -0500 Subject: [PATCH 01/21] add code to debug Transformer, add S4 architecture and NaN-based Early Stopping --- src/clm/commands/sample_molecules_RNN.py | 138 ++- src/clm/commands/train_models_RNN.py | 99 +- src/clm/loggers.py | 13 + src/clm/models.py | 859 +++++++++++++++++- src/clm/module_library/README.md | 1 + src/clm/module_library/__init__.py | 0 src/clm/module_library/cauchy.py | 18 + src/clm/module_library/dplr.py | 129 +++ src/clm/module_library/ff.py | 65 ++ src/clm/module_library/hippo.py | 277 ++++++ src/clm/module_library/kernel.py | 718 +++++++++++++++ src/clm/module_library/krylov.py | 209 +++++ src/clm/module_library/pool.py | 62 ++ src/clm/module_library/residual.py | 23 + src/clm/module_library/s4.py | 290 ++++++ src/clm/module_library/sequence_model.py | 204 +++++ src/clm/module_library/sequence_module.py | 131 +++ .../module_library/sequence_residual_block.py | 148 +++ src/clm/module_library/toeplitz.py | 156 ++++ src/clm/module_library/util_modules.py | 318 +++++++ src/clm/src/__init__.py | 0 src/clm/src/callbacks/norms.py | 39 + src/clm/src/callbacks/params.py | 37 + src/clm/src/callbacks/progressive_resizing.py | 118 +++ src/clm/src/callbacks/timer.py | 100 ++ src/clm/src/callbacks/wandb.py | 277 ++++++ src/clm/src/dataloaders/README.md | 40 + src/clm/src/dataloaders/__init__.py | 2 + src/clm/src/dataloaders/base.py | 276 ++++++ src/clm/src/dataloaders/basic.py | 271 ++++++ .../src/dataloaders/datasets/detokenizer.py | 53 ++ .../src/dataloaders/datasets/lm_dataset.py | 32 + src/clm/src/dataloaders/et.py | 626 +++++++++++++ .../src/dataloaders/fault_tolerant_sampler.py | 123 +++ .../src/dataloaders/language_modeling_hf.py | 311 +++++++ src/clm/src/dataloaders/lm.py | 507 +++++++++++ src/clm/src/dataloaders/lra.py | 689 ++++++++++++++ src/clm/src/dataloaders/synthetics.py | 335 +++++++ .../dataloaders/utils/cifar_augmentations.py | 138 +++ src/clm/src/dataloaders/utils/timm_mixup.py | 22 + src/clm/src/dataloaders/utils/vocabulary.py | 237 +++++ src/clm/src/dataloaders/vision.py | 279 ++++++ src/clm/src/models/__init__.py | 0 src/clm/src/models/baselines/vit_all.py | 433 +++++++++ src/clm/src/models/nn/__init__.py | 1 + src/clm/src/models/nn/adaptive_softmax.py | 404 ++++++++ src/clm/src/models/nn/components.py | 389 ++++++++ src/clm/src/models/nn/dxt.py | 196 ++++ src/clm/src/models/nn/gate.py | 128 +++ src/clm/src/models/nn/residual.py | 108 +++ src/clm/src/models/nn/utils.py | 125 +++ src/clm/src/models/sequence/__init__.py | 3 + src/clm/src/models/sequence/base.py | 131 +++ src/clm/src/models/sequence/block.py | 129 +++ src/clm/src/models/sequence/block_fft.py | 177 ++++ src/clm/src/models/sequence/ff.py | 50 + src/clm/src/models/sequence/h3.py | 206 +++++ src/clm/src/models/sequence/h3_conv.py | 150 +++ src/clm/src/models/sequence/hyena.py | 359 ++++++++ .../src/models/sequence/hyena_components.py | 255 ++++++ src/clm/src/models/sequence/long_conv.py | 170 ++++ .../src/models/sequence/long_conv_kernel.py | 82 ++ src/clm/src/models/sequence/long_conv_lm.py | 397 ++++++++ src/clm/src/models/sequence/mha.py | 122 +++ src/clm/src/models/sequence/model.py | 134 +++ src/clm/src/models/sequence/pool.py | 459 ++++++++++ src/clm/src/models/sequence/simple_lm.py | 469 ++++++++++ src/clm/src/models/sequence/ssm/dplr.py | 107 +++ src/clm/src/models/sequence/ssm/hippo.py | 259 ++++++ src/clm/src/models/sequence/ssm/s4_simple.py | 262 ++++++ 
src/clm/src/models/sequence/ssm/s4d.py | 404 ++++++++ src/clm/src/models/sequence/ssm/ss_kernel.py | 180 ++++ .../src/models/sequence/ssm/ss_kernel_diag.py | 331 +++++++ .../models/sequence/ssm/ss_kernel_shift.py | 83 ++ src/clm/src/ops/fftconv.py | 103 +++ src/clm/src/ops/krylov.py | 198 ++++ src/clm/src/ops/toeplitz.py | 157 ++++ src/clm/src/ops/unroll.py | 421 +++++++++ src/clm/src/ops/vandermonde.py | 167 ++++ src/clm/src/retnet/__init__.py | 0 src/clm/src/retnet/complex/retention.py | 177 ++++ src/clm/src/retnet/complex/retnet.py | 118 +++ src/clm/src/retnet/complex/test_retention.py | 119 +++ src/clm/src/retnet/complex/test_retnet.py | 102 +++ src/clm/src/retnet/complex/util.py | 71 ++ src/clm/src/retnet/example.py | 17 + src/clm/src/retnet/retention.py | 204 +++++ src/clm/src/retnet/retnet.py | 76 ++ src/clm/src/retnet/tests.py | 154 ++++ src/clm/src/retnet/xpos_relative_position.py | 94 ++ src/clm/src/tasks/decoders.py | 319 +++++++ src/clm/src/tasks/encoders.py | 358 ++++++++ src/clm/src/tasks/metrics.py | 225 +++++ src/clm/src/tasks/tasks.py | 371 ++++++++ src/clm/src/tasks/torchmetrics.py | 120 +++ src/clm/src/utils/__init__.py | 1 + src/clm/src/utils/config.py | 124 +++ src/clm/src/utils/distributed.py | 144 +++ src/clm/src/utils/optim/lamb.py | 251 +++++ src/clm/src/utils/optim/schedulers.py | 87 ++ src/clm/src/utils/optim_groups.py | 144 +++ src/clm/src/utils/permutations.py | 180 ++++ src/clm/src/utils/registry.py | 53 ++ src/clm/src/utils/train.py | 156 ++++ .../config-spectraverse-allv1-s4_cv.yaml | 196 ++++ ...fig-spectraverse-allv1-transformer_cv.yaml | 196 ++++ 106 files changed, 20120 insertions(+), 56 deletions(-) create mode 100644 src/clm/module_library/README.md create mode 100644 src/clm/module_library/__init__.py create mode 100644 src/clm/module_library/cauchy.py create mode 100644 src/clm/module_library/dplr.py create mode 100644 src/clm/module_library/ff.py create mode 100644 src/clm/module_library/hippo.py create mode 100644 src/clm/module_library/kernel.py create mode 100644 src/clm/module_library/krylov.py create mode 100644 src/clm/module_library/pool.py create mode 100644 src/clm/module_library/residual.py create mode 100644 src/clm/module_library/s4.py create mode 100644 src/clm/module_library/sequence_model.py create mode 100644 src/clm/module_library/sequence_module.py create mode 100644 src/clm/module_library/sequence_residual_block.py create mode 100644 src/clm/module_library/toeplitz.py create mode 100644 src/clm/module_library/util_modules.py create mode 100644 src/clm/src/__init__.py create mode 100644 src/clm/src/callbacks/norms.py create mode 100644 src/clm/src/callbacks/params.py create mode 100644 src/clm/src/callbacks/progressive_resizing.py create mode 100644 src/clm/src/callbacks/timer.py create mode 100644 src/clm/src/callbacks/wandb.py create mode 100644 src/clm/src/dataloaders/README.md create mode 100644 src/clm/src/dataloaders/__init__.py create mode 100644 src/clm/src/dataloaders/base.py create mode 100644 src/clm/src/dataloaders/basic.py create mode 100644 src/clm/src/dataloaders/datasets/detokenizer.py create mode 100644 src/clm/src/dataloaders/datasets/lm_dataset.py create mode 100644 src/clm/src/dataloaders/et.py create mode 100644 src/clm/src/dataloaders/fault_tolerant_sampler.py create mode 100644 src/clm/src/dataloaders/language_modeling_hf.py create mode 100644 src/clm/src/dataloaders/lm.py create mode 100644 src/clm/src/dataloaders/lra.py create mode 100644 src/clm/src/dataloaders/synthetics.py create mode 100644 
src/clm/src/dataloaders/utils/cifar_augmentations.py create mode 100644 src/clm/src/dataloaders/utils/timm_mixup.py create mode 100644 src/clm/src/dataloaders/utils/vocabulary.py create mode 100644 src/clm/src/dataloaders/vision.py create mode 100644 src/clm/src/models/__init__.py create mode 100644 src/clm/src/models/baselines/vit_all.py create mode 100644 src/clm/src/models/nn/__init__.py create mode 100644 src/clm/src/models/nn/adaptive_softmax.py create mode 100644 src/clm/src/models/nn/components.py create mode 100644 src/clm/src/models/nn/dxt.py create mode 100644 src/clm/src/models/nn/gate.py create mode 100644 src/clm/src/models/nn/residual.py create mode 100644 src/clm/src/models/nn/utils.py create mode 100644 src/clm/src/models/sequence/__init__.py create mode 100644 src/clm/src/models/sequence/base.py create mode 100644 src/clm/src/models/sequence/block.py create mode 100644 src/clm/src/models/sequence/block_fft.py create mode 100644 src/clm/src/models/sequence/ff.py create mode 100644 src/clm/src/models/sequence/h3.py create mode 100644 src/clm/src/models/sequence/h3_conv.py create mode 100644 src/clm/src/models/sequence/hyena.py create mode 100644 src/clm/src/models/sequence/hyena_components.py create mode 100644 src/clm/src/models/sequence/long_conv.py create mode 100644 src/clm/src/models/sequence/long_conv_kernel.py create mode 100644 src/clm/src/models/sequence/long_conv_lm.py create mode 100644 src/clm/src/models/sequence/mha.py create mode 100644 src/clm/src/models/sequence/model.py create mode 100644 src/clm/src/models/sequence/pool.py create mode 100644 src/clm/src/models/sequence/simple_lm.py create mode 100644 src/clm/src/models/sequence/ssm/dplr.py create mode 100644 src/clm/src/models/sequence/ssm/hippo.py create mode 100644 src/clm/src/models/sequence/ssm/s4_simple.py create mode 100644 src/clm/src/models/sequence/ssm/s4d.py create mode 100644 src/clm/src/models/sequence/ssm/ss_kernel.py create mode 100644 src/clm/src/models/sequence/ssm/ss_kernel_diag.py create mode 100644 src/clm/src/models/sequence/ssm/ss_kernel_shift.py create mode 100644 src/clm/src/ops/fftconv.py create mode 100644 src/clm/src/ops/krylov.py create mode 100644 src/clm/src/ops/toeplitz.py create mode 100644 src/clm/src/ops/unroll.py create mode 100644 src/clm/src/ops/vandermonde.py create mode 100644 src/clm/src/retnet/__init__.py create mode 100644 src/clm/src/retnet/complex/retention.py create mode 100644 src/clm/src/retnet/complex/retnet.py create mode 100644 src/clm/src/retnet/complex/test_retention.py create mode 100644 src/clm/src/retnet/complex/test_retnet.py create mode 100644 src/clm/src/retnet/complex/util.py create mode 100644 src/clm/src/retnet/example.py create mode 100644 src/clm/src/retnet/retention.py create mode 100644 src/clm/src/retnet/retnet.py create mode 100644 src/clm/src/retnet/tests.py create mode 100644 src/clm/src/retnet/xpos_relative_position.py create mode 100644 src/clm/src/tasks/decoders.py create mode 100644 src/clm/src/tasks/encoders.py create mode 100644 src/clm/src/tasks/metrics.py create mode 100644 src/clm/src/tasks/tasks.py create mode 100644 src/clm/src/tasks/torchmetrics.py create mode 100644 src/clm/src/utils/__init__.py create mode 100644 src/clm/src/utils/config.py create mode 100644 src/clm/src/utils/distributed.py create mode 100644 src/clm/src/utils/optim/lamb.py create mode 100644 src/clm/src/utils/optim/schedulers.py create mode 100644 src/clm/src/utils/optim_groups.py create mode 100644 src/clm/src/utils/permutations.py create mode 100644 
src/clm/src/utils/registry.py create mode 100644 src/clm/src/utils/train.py create mode 100644 workflow/config/config-spectraverse-allv1-s4_cv.yaml create mode 100644 workflow/config/config-spectraverse-allv1-transformer_cv.yaml diff --git a/src/clm/commands/sample_molecules_RNN.py b/src/clm/commands/sample_molecules_RNN.py index 86d6cf05..19b2b831 100644 --- a/src/clm/commands/sample_molecules_RNN.py +++ b/src/clm/commands/sample_molecules_RNN.py @@ -6,7 +6,7 @@ from tqdm import tqdm from clm.datasets import Vocabulary, SelfiesVocabulary -from clm.models import RNN, ConditionalRNN +from clm.models import RNN, ConditionalRNN, Transformer, StructuredStateSpaceSequenceModel#, H3Model, H3ConvModel, HyenaModel from clm.functions import load_dataset, write_to_csv_file logger = logging.getLogger(__name__) @@ -122,7 +122,8 @@ def sample_molecules_RNN( vocab = Vocabulary(vocab_file=vocab_file) heldout_dataset = None - if conditional: + + if rnn_type == "S4": assert ( heldout_file is not None ), "heldout_file must be provided for conditional RNN Model" @@ -131,29 +132,124 @@ def sample_molecules_RNN( input_file=heldout_file, vocab_file=vocab_file, ) - model = ConditionalRNN( - vocab, - rnn_type=rnn_type, + model = StructuredStateSpaceSequenceModel( + vocabulary=vocab, # heldout_dataset.vocabulary + model_dim=embedding_size, + state_dim=64, n_layers=n_layers, - embedding_size=embedding_size, - hidden_size=hidden_size, + n_ssm=1, dropout=dropout, - num_descriptors=heldout_dataset.n_descriptors, - conditional_emb=conditional_emb, - conditional_emb_l=conditional_emb_l, - conditional_dec=conditional_dec, - conditional_dec_l=conditional_dec_l, - conditional_h=conditional_h, ) - else: - model = RNN( - vocab, - rnn_type=rnn_type, - n_layers=n_layers, + # elif rnn_type == "H3": + # assert ( + # heldout_file is not None + # ), "heldout_file must be provided for conditional RNN Model" + # heldout_dataset = load_dataset( + # representation=representation, + # input_file=heldout_file, + # vocab_file=vocab_file, + # ) + # model = H3Model( + # vocabulary=vocab, + # n_layers=n_layers, + # d_model=embedding_size, + # d_state=64, + # head_dim=1, + # dropout=dropout, + # max_len=250, + # use_fast_fftconv=False, + # ) + # elif rnn_type == "H3Conv": + # assert ( + # heldout_file is not None + # ), "heldout_file must be provided for conditional RNN Model" + # heldout_dataset = load_dataset( + # representation=representation, + # input_file=heldout_file, + # vocab_file=vocab_file, + # ) + # model = H3ConvModel( + # vocabulary=vocab, + # n_layers=n_layers, + # d_model=embedding_size, + # head_dim=1, + # dropout=dropout, + # max_len=250, + # use_fast_fftconv=False, + # ) + # elif rnn_type == "Hyena": + # assert ( + # heldout_file is not None + # ), "heldout_file must be provided for conditional RNN Model" + # heldout_dataset = load_dataset( + # representation=representation, + # input_file=heldout_file, + # vocab_file=vocab_file, + # ) + # model = HyenaModel( + # vocabulary=vocab, + # n_layers=n_layers, + # d_model=embedding_size, + # order=2, + # filter_order=64, + # num_heads=1, + # dropout=dropout, + # max_len=250, + # inner_factor=1, + # ) + + elif rnn_type == "Transformer": + assert ( + heldout_file is not None + ), "heldout_file must be provided for conditional RNN Model" + heldout_dataset = load_dataset( + representation=representation, + input_file=heldout_file, + vocab_file=vocab_file, + ) + model = Transformer( + vocabulary=vocab, + n_blocks=n_layers, + n_heads=4, embedding_size=embedding_size, - 
hidden_size=hidden_size, dropout=dropout, + exp_factor=4, + bias=True, ) + + else: + if conditional: + assert ( + heldout_file is not None + ), "heldout_file must be provided for conditional RNN Model" + heldout_dataset = load_dataset( + representation=representation, + input_file=heldout_file, + vocab_file=vocab_file, + ) + model = ConditionalRNN( + vocab, + rnn_type=rnn_type, + n_layers=n_layers, + embedding_size=embedding_size, + hidden_size=hidden_size, + dropout=dropout, + num_descriptors=heldout_dataset.n_descriptors, + conditional_emb=conditional_emb, + conditional_emb_l=conditional_emb_l, + conditional_dec=conditional_dec, + conditional_dec_l=conditional_dec_l, + conditional_h=conditional_h, + ) + else: + model = RNN( + vocab, + rnn_type=rnn_type, + n_layers=n_layers, + embedding_size=embedding_size, + hidden_size=hidden_size, + dropout=dropout, + ) logging.info(vocab.dictionary) if torch.cuda.is_available(): @@ -171,8 +267,10 @@ def sample_molecules_RNN( n_sequences = min(batch_size, sample_mols - i) descriptors = None if heldout_dataset is not None: + # Use modulo to cycle through heldout_dataset + descriptor_indices = [(i + j) % len(heldout_dataset) for j in range(n_sequences)] descriptors = torch.stack( - [heldout_dataset[_i][1] for _i in range(i, i + n_sequences)] + [heldout_dataset[idx][1] for idx in descriptor_indices] ) descriptors = descriptors.to(model.device) sampled_smiles, losses = model.sample( diff --git a/src/clm/commands/train_models_RNN.py b/src/clm/commands/train_models_RNN.py index 4abfddf1..3ad3cdd2 100644 --- a/src/clm/commands/train_models_RNN.py +++ b/src/clm/commands/train_models_RNN.py @@ -6,10 +6,13 @@ from torch.utils.data import DataLoader from tqdm import tqdm from rdkit import rdBase -from clm.models import RNN, ConditionalRNN +from clm.models import RNN, ConditionalRNN, Transformer, StructuredStateSpaceSequenceModel#, H3Model, H3ConvModel, HyenaModel from clm.loggers import EarlyStopping, track_loss, print_update from clm.functions import write_smiles, load_dataset +import warnings +warnings.filterwarnings("ignore", category=FutureWarning) + # suppress Chem.MolFromSmiles error output rdBase.DisableLog("rdApp.error") logger = logging.getLogger(__name__) @@ -179,30 +182,88 @@ def train_models_RNN( dataset = load_dataset(representation, input_file, vocab_file) - if conditional: - model = ConditionalRNN( - dataset.vocabulary, - rnn_type=rnn_type, + if rnn_type == "S4": + model = StructuredStateSpaceSequenceModel( + vocabulary=dataset.vocabulary, + model_dim=embedding_size, + state_dim=64, n_layers=n_layers, - embedding_size=embedding_size, - hidden_size=hidden_size, + n_ssm=1, dropout=dropout, - num_descriptors=dataset.n_descriptors, - conditional_emb=conditional_emb, - conditional_emb_l=conditional_emb_l, - conditional_dec=conditional_dec, - conditional_dec_l=conditional_dec_l, - conditional_h=conditional_h, ) - else: - model = RNN( - dataset.vocabulary, - rnn_type=rnn_type, - n_layers=n_layers, + + # elif rnn_type == "H3": + # model = H3Model( + # vocabulary=dataset.vocabulary, + # n_layers=n_layers, + # d_model=embedding_size, + # d_state=64, + # head_dim=1, + # dropout=dropout, + # max_len=250, + # use_fast_fftconv=False, + # ) + + # elif rnn_type == "H3Conv": + # model = H3ConvModel( + # vocabulary=dataset.vocabulary, + # n_layers=n_layers, + # d_model=embedding_size, + # head_dim=1, + # dropout=dropout, + # max_len=250, + # use_fast_fftconv=False, + # ) + + # elif rnn_type == "Hyena": + # model = HyenaModel( + # vocabulary=dataset.vocabulary, + # 
n_layers=n_layers, + # d_model=embedding_size, + # order=2, + # filter_order=64, + # num_heads=1, + # dropout=dropout, + # max_len=250, + # inner_factor=1, + # ) + + elif rnn_type == "Transformer": + model = Transformer( + vocabulary=dataset.vocabulary, + n_blocks=n_layers, + n_heads=4, embedding_size=embedding_size, - hidden_size=hidden_size, dropout=dropout, + exp_factor=4, + bias=True, ) + + else: + if conditional: + model = ConditionalRNN( + dataset.vocabulary, + rnn_type=rnn_type, + n_layers=n_layers, + embedding_size=embedding_size, + hidden_size=hidden_size, + dropout=dropout, + num_descriptors=dataset.n_descriptors, + conditional_emb=conditional_emb, + conditional_emb_l=conditional_emb_l, + conditional_dec=conditional_dec, + conditional_dec_l=conditional_dec_l, + conditional_h=conditional_h, + ) + else: + model = RNN( + dataset.vocabulary, + rnn_type=rnn_type, + n_layers=n_layers, + embedding_size=embedding_size, + hidden_size=hidden_size, + dropout=dropout, + ) logger.info(dataset.vocabulary.dictionary) diff --git a/src/clm/loggers.py b/src/clm/loggers.py index b9c7d496..785c7f13 100644 --- a/src/clm/loggers.py +++ b/src/clm/loggers.py @@ -1,6 +1,7 @@ import os import pandas as pd import torch +import math from rdkit import Chem from tqdm import tqdm from clm.models import ConditionalRNN @@ -34,9 +35,21 @@ def __init__(self, patience=100): self.best_loss = None self.step_at_best = 0 self.stop = False + self.nan_counter = 0 print("instantiated early stopping with patience=" + str(self.patience)) def __call__(self, val_loss, model, output_file, step_idx): + # Check for NaN/Inf + if math.isnan(val_loss) or math.isinf(val_loss): + self.nan_counter += 1 + print(f"NaN/Inf loss detected at step {step_idx} ({self.nan_counter}/3)") + if self.nan_counter >= 3: + self.stop = True + print("Stopping training after 3 consecutive NaN/Inf losses.") + if self.best_loss is not None: + print(f"Best model (loss={self.best_loss:.4f}) already saved.") + return + # do nothing if early stopping is disabled if self.patience > 0: if self.best_loss is None: diff --git a/src/clm/models.py b/src/clm/models.py index 7e1f62ea..c1c93464 100644 --- a/src/clm/models.py +++ b/src/clm/models.py @@ -4,6 +4,820 @@ import torch.nn.functional as F from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +from einops import rearrange +# from clm.src.models.sequence.h3 import H3 +# from clm.src.models.sequence.h3_conv import H3Conv +# from clm.src.models.sequence.hyena_components import HyenaOperator + +from .module_library.sequence_model import SequenceModel + + +# class H3Model(nn.Module): +# def __init__( +# self, +# vocabulary, +# n_layers=4, +# d_model=256, +# d_state=64, +# head_dim=1, +# dropout=0.1, +# max_len=250, +# use_fast_fftconv=False, +# ): +# super(H3Model, self).__init__() + +# if H3 is None: +# raise ImportError( +# "H3 modules not found. Make sure src.models.sequence.h3 is available." 
+# ) + +# # detect device +# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# # vocabulary +# self.vocabulary = vocabulary +# self.vocabulary_size = len(self.vocabulary) +# self.padding_idx = self.vocabulary.dictionary[""] +# padding_t = torch.tensor(self.padding_idx).to(self.device) + +# # hyperparams +# self.n_layers = n_layers +# self.d_model = d_model +# self.d_state = d_state +# self.head_dim = head_dim +# self.dropout = dropout +# self.max_len = max_len +# self.use_fast_fftconv = use_fast_fftconv + +# # model components +# self.embedding = nn.Embedding( +# self.vocabulary_size, self.d_model, padding_idx=padding_t +# ) + +# # H3 layers +# self.layers = nn.ModuleList([ +# H3( +# d_model=self.d_model, +# d_state=self.d_state, +# l_max=self.max_len, +# head_dim=self.head_dim, +# use_fast_fftconv=self.use_fast_fftconv, +# dropout=self.dropout, +# layer_idx=i, +# ) +# for i in range(self.n_layers) +# ]) + +# # dropout and output +# self.norm = nn.LayerNorm(self.d_model) +# self.dropout_layer = nn.Dropout(dropout) +# self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) + +# # loss function (ignoring padding) +# self.loss_fn = nn.CrossEntropyLoss( +# ignore_index=self.padding_idx, reduction="none" +# ) + +# # move to GPU +# if torch.cuda.is_available(): +# self.cuda() + +# def forward(self, x, inference_params=None): + +# batch_size, seq_len = x.size() + +# # Embed the input +# x = self.embedding(x) # (batch_size, seq_len, d_model) + +# # Pass through H3 layers +# for layer in self.layers: +# x = layer(x, inference_params=inference_params) +# if self.dropout > 0: +# x = self.dropout_layer(x) + +# # Normalize and project to vocabulary +# x = self.norm(x) +# logits = self.output_projection(x) # (batch_size, seq_len, vocab_size) + +# return logits + +# def loss(self, batch): +# if len(batch) == 3: +# padded, lengths, _ = batch +# else: +# padded, lengths = batch + +# padded = padded.to(self.device) + +# # Handle different input formats +# if padded.dim() == 2: +# if padded.shape[0] > padded.shape[1]: +# padded = padded.transpose(0, 1) + +# # Forward pass +# logits = self(padded) + +# # Calculate loss +# targets = padded[:, 1:] +# logits = logits[:, :-1, :] + +# loss = 0.0 +# actual_len = min(logits.shape[1], targets.shape[1]) + +# for char_idx in range(actual_len): +# loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + +# return loss.mean() + +# def sample( +# self, +# *, +# n_sequences, +# max_len=None, +# return_smiles=True, +# return_losses=False, +# descriptors=None, +# ): +# if max_len is None: +# max_len = self.max_len + +# self.eval() + +# # Get start/stop tokens +# start_token = self.vocabulary.dictionary["SOS"] +# stop_token = self.vocabulary.dictionary["EOS"] +# pad_token = self.vocabulary.dictionary[""] + +# # Create inference params +# class InferenceParams: +# def __init__(self, max_seqlen, batch_size): +# self.max_seqlen = max_seqlen +# self.max_batch_size = batch_size +# self.sequence_len_offset = 0 +# self.key_value_memory_dict = {} + +# inference_params = InferenceParams(max_len, n_sequences) + +# # Initialize with start tokens - keep only current token for recurrent stepping +# current_token = torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + +# finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) +# log_probs = torch.zeros(n_sequences, device=self.device) +# sequences = [] + +# with 
torch.no_grad(): +# for step in range(max_len): +# # Process only the current token in recurrent mode +# logits = self(current_token, inference_params=inference_params) +# logits = logits[:, -1, :] # Get last (and only) position + +# logits = torch.clamp(logits, min=-1e4, max=1e4) +# prob = F.softmax(logits, dim=-1) + +# if torch.isnan(prob).any() or torch.isinf(prob).any(): +# break + +# outputs = torch.multinomial(prob, num_samples=1) +# sequences.append(outputs) + +# log_prob = F.log_softmax(logits, dim=-1) +# losses = loss_fn(log_prob, outputs.squeeze(1)) +# losses[finished] = 0 +# log_probs += losses + +# # Update current token for next step (don't accumulate) +# current_token = outputs +# inference_params.sequence_len_offset += 1 + +# finished = finished | (outputs.squeeze(1) == stop_token) +# if finished.all(): +# break + +# seqs = torch.cat(sequences, 1) if sequences else torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# if return_smiles: +# outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] +# else: +# outputs = sequences + +# if return_losses: +# return outputs, log_probs.detach().cpu().numpy() +# else: +# return outputs + + +# class H3ConvModel(nn.Module): +# def __init__( +# self, +# vocabulary, +# n_layers=4, +# d_model=256, +# head_dim=1, +# dropout=0.1, +# max_len=250, +# use_fast_fftconv=False, +# ): +# super(H3ConvModel, self).__init__() + +# if H3Conv is None: +# raise ImportError( +# "H3Conv modules not found. Make sure src.models.sequence.h3_conv is available." +# ) + +# # detect device +# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# # vocabulary +# self.vocabulary = vocabulary +# self.vocabulary_size = len(self.vocabulary) +# self.padding_idx = self.vocabulary.dictionary[""] +# padding_t = torch.tensor(self.padding_idx).to(self.device) + +# # hyperparams +# self.n_layers = n_layers +# self.d_model = d_model +# self.head_dim = head_dim +# self.dropout = dropout +# self.max_len = max_len +# self.use_fast_fftconv = use_fast_fftconv + +# # model components +# self.embedding = nn.Embedding( +# self.vocabulary_size, self.d_model, padding_idx=padding_t +# ) + +# # H3Conv layers +# self.layers = nn.ModuleList([ +# H3Conv( +# d_model=self.d_model, +# l_max=self.max_len, +# head_dim=self.head_dim, +# use_fast_fftconv=self.use_fast_fftconv, +# dropout=self.dropout, +# layer_idx=i, +# ) +# for i in range(self.n_layers) +# ]) + +# # dropout and output +# self.norm = nn.LayerNorm(self.d_model) +# self.dropout_layer = nn.Dropout(dropout) +# self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) + +# # loss function (ignoring padding) +# self.loss_fn = nn.CrossEntropyLoss( +# ignore_index=self.padding_idx, reduction="none" +# ) + +# # move to GPU +# if torch.cuda.is_available(): +# self.cuda() + +# def forward(self, x, inference_params=None): +# batch_size, seq_len = x.size() + +# # Embed the input +# x = self.embedding(x) # (batch_size, seq_len, d_model) + +# # Pass through H3Conv layers +# for layer in self.layers: +# x = layer(x, inference_params=inference_params) +# if self.dropout > 0: +# x = self.dropout_layer(x) + +# # Normalize and project to vocabulary +# x = self.norm(x) +# logits = self.output_projection(x) + +# return logits + +# def loss(self, batch): +# if len(batch) == 3: +# padded, lengths, _ = batch +# else: +# padded, lengths = batch + +# padded = padded.to(self.device) + +# # Handle different input formats +# if padded.dim() == 2: +# if padded.shape[0] > 
padded.shape[1]: +# padded = padded.transpose(0, 1) + +# # Forward pass +# logits = self(padded) + +# # Calculate loss +# targets = padded[:, 1:] +# logits = logits[:, :-1, :] + +# loss = 0.0 +# actual_len = min(logits.shape[1], targets.shape[1]) + +# for char_idx in range(actual_len): +# loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + +# return loss.mean() + +# def sample( +# self, +# *, +# n_sequences, +# max_len=None, +# return_smiles=True, +# return_losses=False, +# descriptors=None, +# ): +# if max_len is None: +# max_len = self.max_len + +# self.eval() + +# start_token = self.vocabulary.dictionary["SOS"] +# stop_token = self.vocabulary.dictionary["EOS"] +# pad_token = self.vocabulary.dictionary[""] + +# # H3Conv doesn't use stateful inference, process full sequence each time +# inputs = torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + +# finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) +# log_probs = torch.zeros(n_sequences, device=self.device) +# sequences = [] + +# with torch.no_grad(): +# for step in range(max_len): +# logits = self(inputs) +# logits = logits[:, -1, :] + +# logits = torch.clamp(logits, min=-1e4, max=1e4) +# prob = F.softmax(logits, dim=-1) + +# if torch.isnan(prob).any() or torch.isinf(prob).any(): +# break + +# outputs = torch.multinomial(prob, num_samples=1) +# sequences.append(outputs) + +# log_prob = F.log_softmax(logits, dim=-1) +# losses = loss_fn(log_prob, outputs.squeeze(1)) +# losses[finished] = 0 +# log_probs += losses + +# inputs = torch.cat([inputs, outputs], dim=1) + +# finished = finished | (outputs.squeeze(1) == stop_token) +# if finished.all(): +# break + +# seqs = torch.cat(sequences, 1) if sequences else torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# if return_smiles: +# outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] +# else: +# outputs = sequences + +# if return_losses: +# return outputs, log_probs.detach().cpu().numpy() +# else: +# return outputs + + +# class HyenaModel(nn.Module): +# def __init__( +# self, +# vocabulary, +# n_layers=4, +# d_model=256, +# order=2, +# filter_order=64, +# num_heads=1, +# dropout=0.1, +# max_len=250, +# inner_factor=1, +# ): +# super(HyenaModel, self).__init__() + +# # detect device +# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +# # vocabulary +# self.vocabulary = vocabulary +# self.vocabulary_size = len(self.vocabulary) +# self.padding_idx = self.vocabulary.dictionary[""] +# padding_t = torch.tensor(self.padding_idx).to(self.device) + +# # hyperparams +# self.n_layers = n_layers +# self.d_model = d_model +# self.order = order +# self.filter_order = filter_order +# self.num_heads = num_heads +# self.dropout = dropout +# self.max_len = max_len +# self.inner_factor = inner_factor + +# # model components +# self.embedding = nn.Embedding( +# self.vocabulary_size, self.d_model, padding_idx=padding_t +# ) + +# # Hyena layers +# self.layers = nn.ModuleList([ +# HyenaOperator( +# d_model=self.d_model, +# l_max=self.max_len, +# order=self.order, +# filter_order=self.filter_order, +# num_heads=self.num_heads, +# inner_factor=self.inner_factor, +# dropout=self.dropout, +# ) +# for i in range(self.n_layers) +# ]) + +# # dropout and output +# self.norm = nn.LayerNorm(self.d_model) +# self.dropout_layer = nn.Dropout(dropout) +# self.output_projection = 
nn.Linear(self.d_model, self.vocabulary_size) + +# # loss function (ignoring padding) +# self.loss_fn = nn.CrossEntropyLoss( +# ignore_index=self.padding_idx, reduction="none" +# ) + +# # move to GPU +# if torch.cuda.is_available(): +# self.cuda() + +# def forward(self, x): +# batch_size, seq_len = x.size() + +# # Embed the input +# x = self.embedding(x) # (batch_size, seq_len, d_model) + +# # Pass through Hyena layers +# for layer in self.layers: +# residual = x +# x = layer(x) +# x = x + residual # Residual connection +# if self.dropout > 0: +# x = self.dropout_layer(x) + +# # Normalize and project to vocabulary +# x = self.norm(x) +# logits = self.output_projection(x) + +# return logits + +# def loss(self, batch): +# if len(batch) == 3: +# padded, lengths, _ = batch +# else: +# padded, lengths = batch + +# padded = padded.to(self.device) + +# # Handle different input formats +# if padded.dim() == 2: +# if padded.shape[0] > padded.shape[1]: +# padded = padded.transpose(0, 1) + +# # Forward pass +# logits = self(padded) + +# # Calculate loss +# targets = padded[:, 1:] +# logits = logits[:, :-1, :] + +# loss = 0.0 +# actual_len = min(logits.shape[1], targets.shape[1]) + +# for char_idx in range(actual_len): +# loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + +# return loss.mean() + +# def sample( +# self, +# *, +# n_sequences, +# max_len=None, +# return_smiles=True, +# return_losses=False, +# descriptors=None, +# ): +# if max_len is None: +# max_len = self.max_len + +# self.eval() + +# start_token = self.vocabulary.dictionary["SOS"] +# stop_token = self.vocabulary.dictionary["EOS"] +# pad_token = self.vocabulary.dictionary[""] + +# # Initialize with start tokens +# inputs = torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + +# finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) +# log_probs = torch.zeros(n_sequences, device=self.device) +# sequences = [] + +# with torch.no_grad(): +# for step in range(max_len): +# # Hyena processes full sequence each time (stateless) +# logits = self(inputs) +# logits = logits[:, -1, :] + +# logits = torch.clamp(logits, min=-1e4, max=1e4) +# prob = F.softmax(logits, dim=-1) + +# if torch.isnan(prob).any() or torch.isinf(prob).any(): +# break + +# outputs = torch.multinomial(prob, num_samples=1) +# sequences.append(outputs) + +# log_prob = F.log_softmax(logits, dim=-1) +# losses = loss_fn(log_prob, outputs.squeeze(1)) +# losses[finished] = 0 +# log_probs += losses + +# inputs = torch.cat([inputs, outputs], dim=1) + +# finished = finished | (outputs.squeeze(1) == stop_token) +# if finished.all(): +# break + +# seqs = torch.cat(sequences, 1) if sequences else torch.full( +# (n_sequences, 1), start_token, dtype=torch.long, device=self.device +# ) + +# if return_smiles: +# outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] +# else: +# outputs = sequences + +# if return_losses: +# return outputs, log_probs.detach().cpu().numpy() +# else: +# return outputs + + +class StructuredStateSpaceSequenceModel(nn.Module): + def __init__( + self, + vocabulary, + model_dim=256, + state_dim=64, + n_layers=4, + n_ssm=1, + dropout=0.25, + max_len=250, + ): + super(StructuredStateSpaceSequenceModel, self).__init__() + + # detect device + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # vocabulary + self.vocabulary = vocabulary + self.vocabulary_size = len(self.vocabulary) + 
self.padding_idx = self.vocabulary.dictionary[""] + padding_t = torch.tensor(self.padding_idx).to(self.device) + + # hyperparams + self.model_dim = model_dim + self.state_dim = state_dim + self.n_layers = n_layers + self.n_ssm = n_ssm + self.dropout = dropout + self.max_len = max_len + + # S4 layer configuration + self.layer_config = [ + { + "_name_": "s4", + "d_state": self.state_dim, + "n_ssm": self.n_ssm, + }, + { + "_name_": "s4", + "d_state": self.state_dim, + "n_ssm": self.n_ssm, + }, + {"_name_": "ff"}, + ] + self.pool_config = {"_name_": "pool", "stride": 1, "expand": None} + + # model components + self.embedding = nn.Embedding( + self.vocabulary_size, self.model_dim, padding_idx=padding_t + ) + + # Import SequenceModel from your module library + from .module_library.sequence_model import SequenceModel + + self.model = SequenceModel( + d_model=self.model_dim, + n_layers=self.n_layers, + transposed=False, # Changed to False - expect (batch, length, dim) + dropout=self.dropout, + layer=self.layer_config, + pool=self.pool_config, + ) + + self.output_embedding = nn.Linear(self.model_dim, self.vocabulary_size) + self.recurrent_state = None + + # loss function (ignoring padding) + self.loss_fn = nn.CrossEntropyLoss( + ignore_index=self.padding_idx, reduction="none" + ) + + # move to GPU + if torch.cuda.is_available(): + self.cuda() + + def forward(self, x): + batch_size, seq_len = x.size() + + # Embed the input + x = self.embedding(x) # (batch_size, seq_len, model_dim) + + # Pass through S4 model (without state in training mode) + x, _ = self.model(x, state=None) + + # Project to vocabulary + logits = self.output_embedding(x) # (batch_size, seq_len, vocab_size) + + return logits + + def reset_state(self, batch_size, device=None): + if device is None: + device = self.device + self.recurrent_state = self.model.default_state(batch_size, device=device) + + def recurrent_step(self, x_t): + if x_t.dim() == 1: + x_t = x_t.unsqueeze(1) + + x_t = self.embedding(x_t).squeeze(1) # (batch_size, model_dim) + x_t, state = self.model.step(x_t, state=self.recurrent_state) + self.recurrent_state = state + x_t = self.output_embedding(x_t) # (batch_size, vocab_size) + + return x_t + + def loss(self, batch): + if len(batch) == 3: + padded, lengths, _ = batch + else: + padded, lengths = batch + + padded = padded.to(self.device) + + # Handle different input formats + # RNN format is typically (seq_len, batch_size) + # S4/Transformer format is typically (batch_size, seq_len) + if padded.dim() == 2: + if padded.shape[0] > padded.shape[1]: + # Likely (seq_len, batch_size), transpose to (batch_size, seq_len) + padded = padded.transpose(0, 1) + + batch_size = padded.shape[0] + seq_len = padded.shape[1] + + # Don't use recurrent state during training - use full convolution mode + self.recurrent_state = None + + # Forward pass + logits = self(padded) # (batch_size, seq_len, vocab_size) + + # Calculate loss + # Shift targets: predict next token + targets = padded[:, 1:] # (batch_size, seq_len-1) + logits = logits[:, :-1, :] # (batch_size, seq_len-1, vocab_size) + + # Reshape for loss calculation + loss = 0.0 + actual_len = min(logits.shape[1], targets.shape[1]) + + for char_idx in range(actual_len): + loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) + + return loss.mean() + + def sample( + self, + *, + n_sequences, + max_len=None, + return_smiles=True, + return_losses=False, + descriptors=None, + ): + if max_len is None: + max_len = self.max_len + + # IMPORTANT: Set model to eval mode before 
sampling + self.eval() + + # Setup for recurrent mode + for module in self.model.modules(): + if hasattr(module, "setup_step"): + module.setup_step() + + # Reset state + self.reset_state(n_sequences, device=self.device) + + # Get start/stop tokens + start_token = self.vocabulary.dictionary["SOS"] + stop_token = self.vocabulary.dictionary["EOS"] + pad_token = self.vocabulary.dictionary[""] + + # Create start token tensor + inputs = ( + torch.empty(n_sequences) + .fill_(start_token) + .long() + .to(self.device) + ) + + # Setup loss function + loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) + + # Sample sequences + finished = torch.zeros(n_sequences).byte().to(self.device) + log_probs = torch.zeros(n_sequences).to(self.device) + sequences = [] + + with torch.no_grad(): # Also add no_grad for efficiency + for step in range(max_len): + # Get logits for current input + logits = self.recurrent_step(inputs) + + # Clamp logits to prevent inf/nan + logits = torch.clamp(logits, min=-1e4, max=1e4) + + # Sample from distribution + prob = F.softmax(logits, dim=-1) + + # Check for invalid values + if torch.isnan(prob).any() or torch.isinf(prob).any(): + break + + outputs = torch.multinomial(prob, num_samples=1).squeeze(1) + + sequences.append(outputs.view(-1, 1)) + + # Calculate NLL + log_prob = F.log_softmax(logits, dim=-1) + losses = loss_fn(log_prob, outputs) + + # Zero losses if we are finished sampling + losses[finished.bool()] = 0 + log_probs += losses + + # Update inputs for next step + inputs = outputs + + # Track whether sampling is done for all molecules + finished = torch.ge(finished + (outputs == stop_token), 1) + if torch.prod(finished) == 1: + break + + # Concatenate sequences and decode + seqs = torch.cat(sequences, 1) if sequences else torch.empty( + n_sequences, 1, dtype=torch.long + ).fill_(start_token).to(self.device) + + if return_smiles: + outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] + else: + outputs = sequences + + # Optionally return losses + if return_losses: + return outputs, log_probs.detach().cpu().numpy() + else: + return outputs + class RNN(nn.Module): def __init__( @@ -413,29 +1227,33 @@ def forward(self, x): return logits def loss(self, batch): - # extract the elements of a single minibatch - padded, lengths = batch - # tranpose padded to batch_size * seq_len - padded = padded.transpose(0, 1) - # move to the gpu + if len(batch) == 3: + padded, lengths, _ = batch + else: + padded, lengths = batch + padded = padded.to(self.device) - - # pass through the entire transformer model - decoded = self(padded) - # -> decoded: batch_size x max_len x vocab_size - - # finally, calculate loss + + # Get actual sequence length from batch + actual_seq_len = padded.shape[1] + + decoded = self(padded) # batch_size x seq_len x vocab_size + loss = 0.0 - max_len = max(lengths) - targets = padded[:, 1:] - for char_idx in range(max_len - 1): - loss += self.loss_fn(decoded[:, char_idx], targets[:, char_idx]) - + targets = padded[:, 1:] # batch_size x (seq_len-1) + + # Loop only up to actual decoded sequence length minus 1 + for char_idx in range(min(actual_seq_len - 1, decoded.shape[1], targets.shape[1])): + loss += self.loss_fn(decoded[:, char_idx, :], targets[:, char_idx]) + return loss.mean() def sample( self, *, n_sequences, return_smiles=True, return_losses=False, descriptors=None ): + # Reset recurrent state before sampling + self.reset_state(n_sequences, device=self.device) + # get start/stop tokens start_token = self.vocabulary.dictionary["SOS"] 
stop_token = self.vocabulary.dictionary["EOS"] @@ -459,7 +1277,14 @@ def sample( sequences = [] for step in range(self.max_len): logits = self(inputs)[:, -1, :] + # Clamp logits to prevent inf/nan + logits = torch.clamp(logits, min=-1e4, max=1e4) prob = F.softmax(logits, dim=-1) + + # Check for invalid values and skip if found + if torch.isnan(prob).any() or torch.isinf(prob).any(): + break + outputs = torch.multinomial(prob, num_samples=1) # append to growing sequence inputs = torch.cat((inputs, outputs), dim=1) @@ -476,7 +1301,7 @@ def sample( break # concatenate sequences and decode - seqs = torch.cat(sequences, 1) + seqs = torch.cat(sequences, 1) if sequences else torch.empty(n_sequences, 1, dtype=torch.long).fill_(start_token).to(self.device) if return_smiles: outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] else: diff --git a/src/clm/module_library/README.md b/src/clm/module_library/README.md new file mode 100644 index 00000000..42b03fa3 --- /dev/null +++ b/src/clm/module_library/README.md @@ -0,0 +1 @@ +These modules are heavily borrowed from the [original codebase for S4](https://github.com/HazyResearch/state-spaces) and empower the S4 model. Visit the original repository for more information. \ No newline at end of file diff --git a/src/clm/module_library/__init__.py b/src/clm/module_library/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/clm/module_library/cauchy.py b/src/clm/module_library/cauchy.py new file mode 100644 index 00000000..e774a534 --- /dev/null +++ b/src/clm/module_library/cauchy.py @@ -0,0 +1,18 @@ +import torch + + +_conj = lambda x: torch.cat([x, x.conj()], dim=-1) + + +def cauchy_naive(v, z, w, conj=True): + """ + v: (..., N) + z: (..., L) + w: (..., N) + returns: (..., L) \sum v/(z-w) + """ + if conj: + v = _conj(v) + w = _conj(w) + cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1)) # (... N L) + return torch.sum(cauchy_matrix, dim=-2) diff --git a/src/clm/module_library/dplr.py b/src/clm/module_library/dplr.py new file mode 100644 index 00000000..03ee00a6 --- /dev/null +++ b/src/clm/module_library/dplr.py @@ -0,0 +1,129 @@ +import math +import torch +from einops import repeat +from . 
import hippo + + +def dplr( + scaling="linear", + N=64, + rank=1, + H=1, + dtype=torch.float, + real_scale=1.0, + imag_scale=1.0, + random_real=False, + random_imag=False, + normalize=False, + diagonal=True, + random_B=False, +): + assert dtype == torch.float or dtype == torch.double + dtype = torch.cfloat if dtype == torch.float else torch.cdouble + + pi = torch.tensor(math.pi) + if random_real: + real_part = torch.rand(H, N // 2) + else: + real_part = 0.5 * torch.ones(H, N // 2) + if random_imag: + imag_part = N // 2 * torch.rand(H, N // 2) + else: + imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H) + + real_part = real_scale * real_part + if scaling == "random": + imag_part = torch.randn(H, N // 2) + elif scaling == "real": + imag_part = 0 * imag_part + real_part = 1 + repeat(torch.arange(N // 2), "n -> h n", h=H) + elif scaling in ["linear", "lin"]: + imag_part = pi * imag_part + elif scaling in [ + "inverse", + "inv", + ]: # Based on asymptotics of the default HiPPO matrix + imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1) + elif scaling in ["inverse2", "inv2"]: + imag_part = 1 / pi * N * (N / (1 + imag_part) - 1) + elif scaling in ["quadratic", "quad"]: + imag_part = 1 / pi * (1 + 2 * imag_part) ** 2 + elif scaling in ["legs", "hippo"]: + w, _, _, _ = hippo.nplr("legsd", N) + imag_part = w.imag + + else: + raise NotImplementedError + imag_part = imag_scale * imag_part + w = -real_part + 1j * imag_part + + # Initialize B + if random_B: + B = torch.randn(H, N // 2, dtype=dtype) + else: + B = torch.ones(H, N // 2, dtype=dtype) + + if normalize: + norm = ( + -B / w + ) # (H, N) # Result if you integrate the kernel with constant 1 function + zeta = 2 * torch.sum( + torch.abs(norm) ** 2, dim=-1, keepdim=True + ) # Variance with a random C vector + B = B / zeta**0.5 + + P = torch.randn(rank, H, N // 2, dtype=dtype) + if diagonal: + P = P * 0.0 + V = torch.eye(N, dtype=dtype)[:, : N // 2] # Only used in testing + V = repeat(V, "n m -> h n m", h=H) + + return w, P, B, V + + +def ssm(measure, N, R, H, **ssm_args): + """Dispatcher to create single SSM initialization + + N: state size + R: rank (for DPLR parameterization) + H: number of independent SSM copies + """ + + if measure == "dplr": + w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args) + elif measure.startswith("diag"): + args = measure.split("-") + assert args[0] == "diag" and len(args) > 1 + scaling = args[1] + w, P, B, V = dplr(scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args) + else: + w, P, B, V = hippo.nplr(measure, N, R, **ssm_args) + w = repeat(w, "n -> s n", s=H) + P = repeat(P, "r n -> r s n", s=H) + B = repeat(B, "n -> s n", s=H) + V = repeat(V, "n m -> s n m", s=H) + return w, P, B, V + + +combinations = { + "hippo": ["legs", "fourier"], + "diag": ["diag-inv", "diag-lin"], + "all": ["legs", "fourier", "diag-inv", "diag-lin"], +} + + +def combination(measures, N, R, S, **ssm_args): + if isinstance(measures, str): + measures = combinations[measures] if measures in combinations else [measures] + + assert ( + S % len(measures) == 0 + ), f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures" + w, P, B, V = zip( + *[ssm(measure, N, R, S // len(measures), **ssm_args) for measure in measures] + ) + w = torch.cat(w, dim=0) # (S N) + P = torch.cat(P, dim=1) # (R S N) + B = torch.cat(B, dim=0) # (S N) + V = torch.cat(V, dim=0) # (S N N) + return w, P, B, V diff --git a/src/clm/module_library/ff.py b/src/clm/module_library/ff.py new file mode 100644 index 00000000..16341a0b 
--- /dev/null +++ b/src/clm/module_library/ff.py @@ -0,0 +1,65 @@ +from functools import partial +from torch import nn +from .sequence_module import SequenceModule +from .util_modules import LinearActivation, DropoutNd + + +class FF(SequenceModule): + def __init__( + self, + d_input, + # expand=2, # changed the default value from 2 to 4 + expand=4, # changed the default value from 2 to 4 + d_output=None, + transposed=False, + activation="gelu", + initializer=None, + dropout=0.0, + tie_dropout=False, + ): + super().__init__() + self.d_output = d_input if d_output is None else d_output + self.transposed = transposed + d_inner = expand * d_input + + linear1 = LinearActivation( + d_input, + d_inner, + transposed=transposed, + activation=activation, + initializer=initializer, + activate=True, + ) + dropout_cls = ( + partial(DropoutNd, transposed=self.transposed) + if tie_dropout + else nn.Dropout + ) + # dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout + drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() + + linear2 = LinearActivation( + d_inner, + self.d_output, + transposed=transposed, + activation=None, + initializer=initializer, + activate=False, + ) + + self.ff = nn.Sequential( + linear1, + drop, + linear2, + ) + + def forward(self, x, *args, **kwargs): + return self.ff(x), None + + def step(self, x, state, **kwargs): + # x: [batch, d_input] + if self.transposed: + # expects: [batch, d_input, seq_len] + return self.ff(x.unsqueeze(-1)).squeeze(-1), state + else: + return self.ff(x), state diff --git a/src/clm/module_library/hippo.py b/src/clm/module_library/hippo.py new file mode 100644 index 00000000..9bd1daba --- /dev/null +++ b/src/clm/module_library/hippo.py @@ -0,0 +1,277 @@ +import torch +import numpy as np +from scipy import special as ss +from einops import rearrange +from opt_einsum import contract + + +def embed_c2r(A): + A = rearrange(A, "... m n -> ... 
m () n ()") + A = np.pad(A, ((0, 0), (0, 1), (0, 0), (0, 1))) + np.pad( + A, ((0, 0), (1, 0), (0, 0), (1, 0)) + ) + return rearrange(A, "m x n y -> (m x) (n y)") + + +# TODO take in 'torch' option to return torch instead of numpy, and converts the shape of B from (N, 1) to (N) +def transition(measure, N, **measure_args): + """A, B transition matrices for different measures + + measure: the type of measure + legt - Legendre (translated) + legs - Legendre (scaled) + glagt - generalized Laguerre (translated) + lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization + """ + # Laguerre (translated) + if measure == "lagt": + b = measure_args.get("beta", 1.0) + A = np.eye(N) / 2 - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + # Generalized Laguerre + # alpha 0, beta small is most stable (limits to the 'lagt' measure) + # alpha 0, beta 1 has transition matrix A = [lower triangular 1] + elif measure == "glagt": + alpha = measure_args.get("alpha", 0.0) + beta = measure_args.get("beta", 0.01) + A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) + B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] + + L = np.exp( + 0.5 * (ss.gammaln(np.arange(N) + alpha + 1) - ss.gammaln(np.arange(N) + 1)) + ) + A = (1.0 / L[:, None]) * A * L[None, :] + B = ( + (1.0 / L[:, None]) + * B + * np.exp(-0.5 * ss.gammaln(1 - alpha)) + * beta ** ((1 - alpha) / 2) + ) + # Legendre (translated) + elif measure == "legt": + Q = np.arange(N, dtype=np.float64) + R = (2 * Q + 1) ** 0.5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :] + B = R[:, None] + A = -A + + # Halve again for timescale correctness + A *= 0.5 + B *= 0.5 + # LMU: equivalent to LegT up to normalization + elif measure == "lmu": + Q = np.arange(N, dtype=np.float64) + R = (2 * Q + 1)[:, None] # / theta + j, i = np.meshgrid(Q, Q) + A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R + B = (-1.0) ** Q[:, None] * R + # Legendre (scaled) + elif measure == "legs": + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = ( + B.copy() + ) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) + elif measure == "legsd": + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = ( + B.copy() + ) # Otherwise "UserWarning: given NumPY array is not writeable..." 
after torch.as_tensor(B) + A += 0.5 * B * B[None, :, 0] + B = B / 2.0 + elif measure in ["fourier_diag", "foud"]: + freqs = np.arange(N // 2) + d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1] + A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1)) + A = A - 0.5 * np.eye(N) + B = np.zeros(N) + B[0::2] = 2**0.5 + B[0] = 1 + B = B[:, None] + elif measure in ["fourier", "fout"]: + freqs = np.arange(N // 2) + d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:] + A = np.pi * (-np.diag(d, 1) + np.diag(d, -1)) + B = np.zeros(N) + B[0::2] = 2**0.5 + B[0] = 1 + + # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case + A = A - B[:, None] * B[None, :] + B = B[:, None] + elif measure == "fourier_decay": + freqs = np.arange(N // 2) + d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:] + A = np.pi * (-np.diag(d, 1) + np.diag(d, -1)) + B = np.zeros(N) + B[0::2] = 2**0.5 + B[0] = 1 + + # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case + A = A - 0.5 * B[:, None] * B[None, :] + B = 0.5 * B[:, None] + elif measure == "fourier2": # Double everything: orthonormal on [0, 1] + freqs = 2 * np.arange(N // 2) + d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:] + A = np.pi * (-np.diag(d, 1) + np.diag(d, -1)) + B = np.zeros(N) + B[0::2] = 2**0.5 + B[0] = 1 + + # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case + A = A - B[:, None] * B[None, :] * 2 + B = B[:, None] * 2 + elif measure == "random": + A = np.random.randn(N, N) / N + B = np.random.randn(N, 1) + elif measure == "diagonal": + A = -np.diag(np.exp(np.random.randn(N))) + B = np.random.randn(N, 1) + else: + raise NotImplementedError + + return A, B + + +def rank_correction(measure, N, rank=1, dtype=torch.float): + """Return low-rank matrix L such that A + L is normal""" + + if measure == "legs": + assert rank >= 1 + P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) + elif measure == "legt": + assert rank >= 2 + P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype)) # (N) + P0 = P.clone() + P0[0::2] = 0.0 + P1 = P.clone() + P1[1::2] = 0.0 + P = torch.stack([P0, P1], dim=0) # (2 N) + P *= 2 ** ( + -0.5 + ) # Halve the rank correct just like the original matrix was halved + elif measure == "lagt": + assert rank >= 1 + P = 0.5**0.5 * torch.ones(1, N, dtype=dtype) + elif measure in ["fourier", "fout"]: + P = torch.zeros(N) + P[0::2] = 2**0.5 + P[0] = 1 + P = P.unsqueeze(0) + elif measure == "fourier_decay": + P = torch.zeros(N) + P[0::2] = 2**0.5 + P[0] = 1 + P = P.unsqueeze(0) + P = P / 2**0.5 + elif measure == "fourier2": + P = torch.zeros(N) + P[0::2] = 2**0.5 + P[0] = 1 + P = 2**0.5 * P.unsqueeze(0) + elif measure in ["fourier_diag", "foud", "legsd"]: + P = torch.zeros(1, N, dtype=dtype) + else: + raise NotImplementedError + + d = P.size(0) + if rank > d: + P = torch.cat([P, torch.zeros(rank - d, N, dtype=dtype)], dim=0) # (rank N) + return P + + +def initial_C(measure, N, dtype=torch.float): + """Return C that captures the other endpoint in the HiPPO approximation""" + + if measure == "legt": + C = (torch.arange(N, dtype=dtype) * 2 + 1) ** 0.5 * (-1) ** torch.arange(N) + elif measure == "fourier": + C = torch.zeros(N) + C[0::2] = 2**0.5 + C[0] = 1 + else: + C = torch.zeros(N, dtype=dtype) # (N) + + return C + + +def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True): + """Return w, p, q, V, B such that + (w - p q^*, B) is unitarily equivalent to the 
original HiPPO A, B by the matrix V + i.e. A = V[w - p q^*]V^*, B = V B + """ + assert dtype == torch.float or dtype == torch.double + cdtype = torch.cfloat if dtype == torch.float else torch.cdouble + + A, B = transition(measure, N) + A = torch.as_tensor(A, dtype=dtype) # (N, N) + B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) + + P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N) + AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3) + + # We require AP to be nearly skew-symmetric + _A = AP + AP.transpose(-1, -2) + if ( + err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N + ) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5): + print("WARNING: HiPPO matrix not skew symmetric", err) + + # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately + # Imaginary part can use eigh instead of eig + w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True) + + # Diagonalize in double precision + if diagonalize_precision: + AP = AP.to(torch.double) + # w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) + w_im, V = torch.linalg.eigh(AP * -1j) # (..., N) (..., N, N) + if diagonalize_precision: + w_im, V = w_im.to(cdtype), V.to(cdtype) + w = w_re + 1j * w_im + # Check: V w V^{-1} = A + # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) + + # Only keep half of each conjugate pair + _, idx = torch.sort(w.imag) + w_sorted = w[idx] + V_sorted = V[:, idx] + + # There is an edge case when eigenvalues can be 0, which requires some machinery to handle + # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case) + V = V_sorted[:, : N // 2] + w = w_sorted[: N // 2] + assert w[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A" + if w[-1].abs() < 1e-4: + V[:, -1] = 0.0 + V[0, -1] = 2**-0.5 + V[1, -1] = 2**-0.5 * 1j + + _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2) + if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5: + print( + "Warning: Diagonalization of A matrix not numerically precise - error", err + ) + # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) + + V_inv = V.conj().transpose(-1, -2) + + # C = initial_C(measure, N, dtype=dtype) + B = contract("ij, j -> i", V_inv, B.to(V)) # V^* B + # C = contract('ij, j -> i', V_inv, C.to(V)) # V^* C + P = contract("ij, ...j -> ...i", V_inv, P.to(V)) # V^* P + + # return w, P, B, C, V + return w, P, B, V diff --git a/src/clm/module_library/kernel.py b/src/clm/module_library/kernel.py new file mode 100644 index 00000000..6f9425cf --- /dev/null +++ b/src/clm/module_library/kernel.py @@ -0,0 +1,718 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import rearrange, repeat +from opt_einsum import contract, contract_expression + +from . 
import dplr +from .krylov import krylov, power +from .cauchy import cauchy_naive + + +_conj = lambda x: torch.cat([x, x.conj()], dim=-1) +_c2r = torch.view_as_real +_r2c = torch.view_as_complex + +if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10): + _resolve_conj = lambda x: x.conj().resolve_conj() +else: + _resolve_conj = lambda x: x.conj() + + +class OptimModule(nn.Module): + """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters""" + + def register(self, name, tensor, lr=None): + """Register a tensor with a configurable learning rate and 0 weight decay""" + + if lr == 0.0: + self.register_buffer(name, tensor) + else: + self.register_parameter(name, nn.Parameter(tensor)) + + optim = {"weight_decay": 0.0} + if lr is not None: + optim["lr"] = lr + setattr(getattr(self, name), "_optim", optim) + + +class SSKernelNPLR(OptimModule): + """ + Stores a representation of and computes the SSKernel function K_L(dt, A, B, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR) + """ + + @torch.no_grad() + def _setup_C(self, L): + """Construct C~ from C + + Two modes are supported: go directly to length L if self.L is 1, or length is doubled + """ + + if self.L.item() == 0: + double_length = False + elif L > self.L.item(): # 2*int(self.L) == L: + double_length = True + L = self.L.item() # Convenience for the math below + else: + return + + C = _r2c(self.C) + dA, _ = self._setup_state() + dA_L = power(L, dA) + # Multiply C by I - dA_L + C_ = _conj(C) + prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_) + if double_length: + prod = -prod # Multiply by I + dA_L instead + C_ = C_ - prod + C_ = C_[..., : self.N] # Take conjugate pairs again + self.C.copy_(_c2r(C_)) + + self.L = 2 * self.L if double_length else self.L + L # Preserve type/device + + def _omega(self, L, dtype, device, cache=True): + """Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform + This should be called everytime the internal length self.L changes""" + + # Use cached if available + if cache and hasattr(self, "omega") and self.omega.size(-1) == L // 2 + 1: + return self.omega, self.z + + omega = torch.tensor( + np.exp(-2j * np.pi / (L)), dtype=dtype, device=device + ) # \omega_{2L} + omega = omega ** torch.arange(0, L // 2 + 1, device=device) + z = 2 * (1 - omega) / (1 + omega) + + # Cache if necessary + if cache: + self.omega = omega + self.z = z + return omega, z + + def __init__( + self, + w, + P, + B, + C, + log_dt, + L=None, # starting/maximum length of kernel + lr=None, + verbose=False, + keops=False, + real_type="exp", # ['none' | 'exp' | 'relu' | sigmoid'] + real_tolerance=1e-3, + bandlimit=None, + ): + """ + L: Maximum length; this module computes an SSM kernel of length L + A is represented by diag(w) - PP^* + w: (S, N) diagonal part + P: (R, S, N) low-rank part + + B: (S, N) + C: (C, H, N) + dt: (H) timescale per feature + lr: [dict | float | None] hook to set lr of special parameters (A, B, dt) + + Dimensions: + N (or d_state): state size + H (or d_model): total SSM copies + S (or n_ssm): number of trainable copies of (A, B, dt); must divide H + R (or rank): rank of low-rank part + C (or channels): system is 1-dim to C-dim + + The forward pass of this Module returns a tensor of shape (C, H, L) + + Note: tensor shape N here denotes half the true state size, because of conjugate symmetry + """ + + super().__init__() + self.verbose = verbose + self.keops = keops + 
self.bandlimit = bandlimit + self.real_type = real_type + self.real_tolerance = real_tolerance + + # Rank of low-rank correction + self.rank = P.shape[-3] + assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = w.size(-1) + + # Check different SSM inits + assert w.size(-2) == P.size(-2) == B.size(-2) # n_ssm + assert self.H % w.size(0) == 0 + self.n_ssm = w.size(0) + self.repeat = self.H // w.size( + 0 + ) # Each trainable SSM needs to be duplicated this many times + + # Broadcast everything to correct shapes + C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N) + B = B.unsqueeze(0) # (1, 1, N) + + # Register parameters + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + if lr is None or isinstance(lr, float): + lr_dict = {} + else: + lr_dict, lr = lr, None + self.register("log_dt", log_dt, lr_dict.get("dt", lr)) + self.register("B", _c2r(B), lr_dict.get("B", lr)) + self.register("P", _c2r(P), lr_dict.get("A", lr)) + self.register("inv_w_real", self._w_init(w.real), lr_dict.get("A", lr)) + self.register("w_imag", w.imag, lr_dict.get("A", lr)) + + self.l_max = L + self.register_buffer("L", torch.tensor(0)) # Internal length + + def _w_init(self, w_real): + w_real = torch.clamp(w_real, max=-self.real_tolerance) + if self.real_type == "none": + return -w_real + elif self.real_type == "exp": + return torch.log(-w_real) # Some of the HiPPO methods have real part 0 + elif self.real_type == "relu": + return -w_real + elif self.real_type == "sigmoid": + return torch.logit(-w_real) + elif self.real_type == "softplus": + return torch.log(torch.exp(-w_real) - 1) + else: + raise NotImplementedError + + def _w(self): + # Get the internal w (diagonal) parameter + if self.real_type == "none": + w_real = -self.inv_w_real + elif self.real_type == "exp": + w_real = -torch.exp(self.inv_w_real) + elif self.real_type == "relu": + w_real = -F.relu(self.inv_w_real) + elif self.real_type == "sigmoid": + w_real = -F.sigmoid(self.inv_w_real) + elif self.real_type == "softplus": + w_real = -F.softplus(self.inv_w_real) + else: + raise NotImplementedError + w = w_real + 1j * self.w_imag + return w + + def forward(self, state=None, rate=1.0, L=None): + """ + state: (B, H, N) initial state + rate: sampling rate factor + L: target length + + returns: + (C, H, L) convolution kernel (generally C=1) + (B, H, L) output from initial state + """ + + # Initialize C~ if necessary (done in forward pass so it's on the correct device) + if self.L.item() == 0 and self.l_max is not None and self.l_max > 0: + self._setup_C(self.l_max) + + # Handle sampling rate logic + # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate + if L is None: + L = round(self.L.item() / rate) + + # Increase the internal length if needed + continuous_L = round(rate * L) + while continuous_L > self.L.item(): + self._setup_C(continuous_L) + discrete_L = round(self.L.item() / rate) + + dt = torch.exp(self.log_dt) * rate + B = _r2c(self.B) + C = _r2c(self.C) + P = _r2c(self.P) + Q = P.conj() + w = self._w() # (S, N) where S=n_ssm + + # Address bandlimiting + if self.bandlimit is not None: + freqs = w.imag.abs() / (2 * math.pi) # (H, N) + freqs = dt[:, None] / rate * freqs # (H, N) + mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0) + C = C * mask + + # Get FFT nodes of right length + omega, z = self._omega( + discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0) + ) + + # Broadcast 
parameters to same hidden features H + B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat) + P = repeat(P, "r t n -> r (v t) n", v=self.repeat) + Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat) + w = repeat(w, "t n -> (v t) n", v=self.repeat) + + # Augment B + if state is not None: + # Have to "unbilinear" the state to put it into the same "type" as B + # Compute 1/dt * (I + dt/2 A) @ state + + # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way + s = _conj(state) if state.size(-1) == self.N else state # (B H N) + sA = s * _conj(w) - contract( # (B H N) + "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P) + ) + s = s / dt.unsqueeze(-1) + sA / 2 + s = s[..., : self.N] + + B = torch.cat([s, B], dim=-3) # (B+1, H, N) + + # Incorporate dt into A + w = w * dt.unsqueeze(-1) # (H N) + + # Stack B and p, C and q for convenient batching + B = torch.cat([B, P], dim=-3) # (B+1+R, H, N) + C = torch.cat([C, Q], dim=-3) # (C+R, H, N) + + # Incorporate B and C batch dimensions + v = B.unsqueeze(-3) * C.unsqueeze(-4) # (B+1+R, C+R, H, N) + + # Calculate resolvent at omega + # if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops: + # r = cauchy_mult(v, z, w, symmetric=True) + # elif has_pykeops: + # r = cauchy_conj(v, z, w) + # else: + r = cauchy_naive(v, z, w) + r = r * dt[None, None, :, None] # (B+1+R, C+R, H, L) + + # Low-rank Woodbury correction + if self.rank == 1: + k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / ( + 1 + r[-1:, -1:, :, :] + ) + elif self.rank == 2: + r00 = r[: -self.rank, : -self.rank, :, :] + r01 = r[: -self.rank, -self.rank :, :, :] + r10 = r[-self.rank :, : -self.rank, :, :] + r11 = r[-self.rank :, -self.rank :, :, :] + det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[ + :1, 1:, :, : + ] * r11[1:, :1, :, :] + s = ( + r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :] + + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :] + - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :] + - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :] + ) + s = s / det + k_f = r00 - s + else: + r00 = r[: -self.rank, : -self.rank, :, :] + r01 = r[: -self.rank, -self.rank :, :, :] + r10 = r[-self.rank :, : -self.rank, :, :] + r11 = r[-self.rank :, -self.rank :, :, :] + r11 = rearrange(r11, "a b h n -> h n a b") + r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11) + r11 = rearrange(r11, "h n a b -> a b h n") + k_f = r00 - torch.einsum( + "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10 + ) + + # Final correction for the bilinear transform + k_f = k_f * 2 / (1 + omega) + + # Move from frequency to coefficients + k = torch.fft.irfft(k_f, n=discrete_L) # (B+1, C, H, L) + + # # Truncate to target length + k = k[..., :L] + + if state is not None: + k_state = k[:-1, :, :, :] # (B, C, H, L) + else: + k_state = None + k_B = k[-1, :, :, :] # (C H L) + + return k_B, k_state + + @torch.no_grad() + def double_length(self): + self._setup_C(2 * self.L) + + @torch.no_grad() + def _check(self): + """Check if A, B, C parameters and vanilla SSKernel construction can be recovered""" + + # assert self.L > 0, "Set up module first" + + K = self.forward(L=self.l_max)[0] + + self._setup_step() + K_ = krylov(self.l_max, self.dA, self.dB, self.dC) + + diff = K - K_ + + @torch.no_grad() + def _setup_linear(self): + """Create parameters that allow fast linear stepping of state""" + w = self._w() + B = _r2c(self.B) # (H N) + P = _r2c(self.P) + Q = P.conj() + + # 
Repeat w shape properly + B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat) + P = repeat(P, "r t n -> r (v t) n", v=self.repeat) + Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat) + w = repeat(w, "t n -> (v t) n", v=self.repeat) + + # Prepare Linear stepping + dt = torch.exp(self.log_dt) + D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N) + R = ( + torch.eye(self.rank, dtype=w.dtype, device=w.device) + + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real + ) # (H R R) + Q_D = rearrange(Q * D, "r h n -> h r n") + try: + R = torch.linalg.solve(R, Q_D) # (H R N) + except: + R = torch.tensor( + np.linalg.solve( + R.to(Q_D).contiguous().detach().cpu(), + Q_D.contiguous().detach().cpu(), + ) + ).to(Q_D) + R = rearrange(R, "h r n -> r h n") + + self.step_params = { + "D": D, # (H N) + "R": R, # (R H N) + "P": P, # (R H N) + "Q": Q, # (R H N) + "B": B, # (1 H N) + "E": 2.0 / dt.unsqueeze(-1) + w, # (H N) + } + + def _step_state_linear(self, u=None, state=None): + """ + Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization. + + Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster + + u: (H) input + state: (H, N/2) state with conjugate pairs + Optionally, the state can have last dimension N + Returns: same shape as state + """ + C = _r2c(self.C) # View used for dtype/device + + if u is None: # Special case used to find dA + u = torch.zeros(self.H, dtype=C.dtype, device=C.device) + if state is None: # Special case used to find dB + state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device) + + step_params = self.step_params.copy() + if ( + state.size(-1) == self.N + ): # Only store half of the conjugate pairs; should be true by default + # There should be a slightly faster way using conjugate symmetry + contract_fn = lambda p, x, y: contract( + "r h n, r h m, ... h m -> ... h n", _conj(p), _conj(x), _conj(y) + )[ + ..., : self.N + ] # inner outer product + else: + assert state.size(-1) == 2 * self.N + step_params = {k: _conj(v) for k, v in step_params.items()} + # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping + contract_fn = lambda p, x, y: contract( + "r h n, r h m, ... h m -> ... 
h n", p, x, y + ) # inner outer product + D = step_params["D"] # (H N) + E = step_params["E"] # (H N) + R = step_params["R"] # (R H N) + P = step_params["P"] # (R H N) + Q = step_params["Q"] # (R H N) + B = step_params["B"] # (1 H N) + + new_state = E * state - contract_fn(P, Q, state) # (B H N) + new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N) + new_state = D * (new_state - contract_fn(P, R, new_state)) + + return new_state + + def _setup_state(self): + """Construct dA and dB for discretized state equation""" + + # Construct dA and dB by using the stepping + self._setup_linear() + C = _r2c(self.C) # Just returns a view that we use for finding dtype/device + + state = torch.eye(2 * self.N, dtype=C.dtype, device=C.device).unsqueeze( + -2 + ) # (N 1 N) + dA = self._step_state_linear(state=state) + dA = rearrange(dA, "n h m -> h m n") + + u = C.new_ones(self.H) + dB = self._step_state_linear(u=u) + dB = _conj(dB) + dB = rearrange(dB, "1 h n -> h n") # (H N) + return dA, dB + + def _step_state(self, u, state): + """Must be called after self.default_state() is used to construct an initial state!""" + next_state = self.state_contraction(self.dA, state) + self.input_contraction( + self.dB, u + ) + return next_state + + def _setup_step(self, mode="dense"): + """Set up dA, dB, dC discretized parameters for stepping""" + self.dA, self.dB = self._setup_state() + + # Calculate original C + C = _conj(_r2c(self.C)) # (H C N) + if self.L.item() == 0: + dC = C + else: + # self.C represents C_tilde + dA_L = power(self.L.item(), self.dA) + I = torch.eye(self.dA.size(-1)).to(dA_L) + + dC = torch.linalg.solve( + I - dA_L.transpose(-1, -2), + C.unsqueeze(-1), + ).squeeze(-1) + self.dC = dC + + # Do special preprocessing for different step modes + + self._step_mode = mode + if mode == "linear": + # Linear case: special step function for the state, we need to handle output + # use conjugate symmetry by default, which affects the output projection + self.dC = 2 * self.dC[:, :, : self.N] + elif mode == "diagonal": + # Eigendecomposition of the A matrix + L, V = torch.linalg.eig(self.dA) + V_inv = torch.linalg.inv(V) + # Change the parameterization to diagonalize + self.dA = L + self.dB = contract("h n m, h m -> h n", V_inv, self.dB) + self.dC = contract("h n m, c h n -> c h m", V, self.dC) + + elif mode == "dense": + pass + else: + raise NotImplementedError( + "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}" + ) + + def default_state(self, *batch_shape): + C = _r2c(self.C) + N = C.size(-1) + H = C.size(-2) + + # Cache the tensor contractions we will later do, for efficiency + # These are put in this function because they depend on the batch size + step_mode = getattr( + self, "_step_mode", "dense" + ) # Used in default_state, which is called without _setup_step() in forward_state() + if step_mode != "linear": + N *= 2 + + if step_mode == "diagonal": + self.state_contraction = contract_expression( + "h n, ... h n -> ... h n", + (H, N), + batch_shape + (H, N), + ) + else: + # Dense (quadratic) case: expand all terms + self.state_contraction = contract_expression( + "h m n, ... h n -> ... h m", + (H, N, N), + batch_shape + (H, N), + ) + + self.input_contraction = contract_expression( + "h n, ... h -> ... h n", + (H, N), # self.dB.shape + batch_shape + (H,), + ) + + self.output_contraction = contract_expression( + "c h n, ... h n -> ... 
c h", + (C.shape[0], H, N), # self.dC.shape + batch_shape + (H, N), + ) + + state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device) + return state + + def step(self, u, state): + """Must have called self._setup_step() and created state with self.default_state() before calling this""" + + if self._step_mode == "linear": + new_state = self._step_state_linear(u, state) + else: + new_state = self._step_state(u, state) + y = self.output_contraction(self.dC, new_state) + return y.real, new_state + + +class SSKernel(nn.Module): + """Wrapper around SSKernel parameterizations. + + The SSKernel is expected to support the interface + forward() + default_state() + _setup_step() + step() + """ + + def __init__( + self, + H, + N=64, + L=None, + measure="legs", + rank=1, + channels=1, + dt_min=0.001, + dt_max=0.1, + deterministic=False, + lr=None, + mode="nplr", + n_ssm=None, + verbose=False, + measure_args={}, + **kernel_args, + ): + """State Space Kernel which computes the convolution kernel $\\bar{K}$ + + H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config. + N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much. + L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known. + measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin) + rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt" + channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead + dt_min, dt_max: min and max values for the step size dt (\Delta) + mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing + n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H + lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. 
+ """ + super().__init__() + self.N = N + self.H = H + dtype, cdtype = torch.float, torch.cfloat + self.channels = channels + self.n_ssm = n_ssm if n_ssm is not None else H + self.mode = mode + self.verbose = verbose + self.kernel_args = kernel_args + + # Generate dt + if deterministic: + log_dt = torch.exp(torch.linspace(math.log(dt_min), math.log(dt_max), H)) + else: + log_dt = torch.rand(self.H, dtype=dtype) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) + + # Compute the preprocessed representation + w, P, B, V = dplr.combination(measure, self.N, rank, self.n_ssm, **measure_args) + + # Broadcast C to have H channels + if deterministic: + C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype) + C[:, :, :1] = 1.0 + C = contract("hmn, chn -> chm", V.conj().transpose(-1, -2), C) # V^* C + C = ( + repeat(C, "c t n -> c (v t) n", v=self.n_ssm // C.size(-2)) + .clone() + .contiguous() + ) + else: + C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) + + # Broadcast other parameters to have n_ssm copies + assert ( + self.n_ssm % B.size(-2) == 0 + and self.n_ssm % P.size(-2) == 0 + and self.n_ssm % w.size(-2) == 0 + ) + # Broadcast tensors to n_ssm copies + # These will be the parameters, so make sure tensors are materialized and contiguous + B = repeat(B, "t n -> (v t) n", v=self.n_ssm // B.size(-2)).clone().contiguous() + P = ( + repeat(P, "r t n -> r (v t) n", v=self.n_ssm // P.size(-2)) + .clone() + .contiguous() + ) + w = repeat(w, "t n -> (v t) n", v=self.n_ssm // w.size(-2)).clone().contiguous() + + self.kernel = SSKernelNPLR( + w, + P, + B, + C, + log_dt, + L=L, + lr=lr, + verbose=verbose, + **kernel_args, + ) + + def forward(self, state=None, L=None, rate=None): + return self.kernel(state=state, L=L, rate=rate) + + @torch.no_grad() + def forward_state(self, u, state): + """Forward the state through a sequence, i.e. 
computes the state after passing chunk through SSM + + state: (B, H, N) + u: (B, H, L) + + Returns: (B, H, N) + """ + + if hasattr(self.kernel, "forward_state"): + return self.kernel.forward_state(u, state) + + dA, dB = self.kernel._setup_state() # Construct dA, dB matrices + # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N) + + conj = state.size(-1) != dA.size(-1) + if conj: + state = _conj(state) + + v = contract( + "h n, b h l -> b h n l", dB, u.flip(-1) + ) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2) + AL, v = power(u.size(-1), dA, v) + next_state = contract("h m n, b h n -> b h m", AL, state) + next_state = next_state + v + + if conj: + next_state = next_state[..., : next_state.size(-1) // 2] + return next_state + + def _setup_step(self, **kwargs): + # This method is intended to be private so that setting up an S4 module with + # ``` + # if hasattr(module, 'setup_step'): module.setup_step() + # ``` + # will not trigger this method multiple times + self.kernel._setup_step(**kwargs) + + def step(self, u, state, **kwargs): + y, state = self.kernel.step(u, state, **kwargs) + return y, state + + def default_state(self, *args, **kwargs): + return self.kernel.default_state(*args, **kwargs) diff --git a/src/clm/module_library/krylov.py b/src/clm/module_library/krylov.py new file mode 100644 index 00000000..9d2b1e4b --- /dev/null +++ b/src/clm/module_library/krylov.py @@ -0,0 +1,209 @@ +import torch +import torch.nn.functional as F +from einops import rearrange + +from .toeplitz import causal_convolution + + +def krylov_sequential(L, A, b, c=None): + """Constant matrix A + + A : (..., N, N) + b : (..., N) + c : (..., N) + + Returns + if c: + x : (..., L) + x[i, l] = c[i] @ A^l @ b[i] + + else: + x : (..., N, L) + x[i, l] = A^l @ b[i] + """ + + # Check which of dim b and c is smaller to save memory + if c is not None and c.numel() < b.numel(): + return krylov_sequential(L, A.transpose(-1, -2), c, b) + + b_ = b + x = [] + for _ in range(L): + if c is not None: + x_ = torch.sum( + c * b_, dim=-1 + ) # (...) # could be faster with matmul or einsum? + else: + x_ = b_ + x.append(x_) + b_ = (A @ b_.unsqueeze(-1)).squeeze(-1) + + x = torch.stack(x, dim=-1) + return x + + +def krylov(L, A, b, c=None, return_power=False): + """ + Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick. + + If return_power=True, return A^{L-1} as well + """ + # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises + + x = b.unsqueeze(-1) # (..., N, 1) + A_ = A + + AL = None + if return_power: + AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device) + _L = L - 1 + + done = L == 1 + # loop invariant: _L represents how many indices left to compute + while not done: + if return_power: + if _L % 2 == 1: + AL = A_ @ AL + _L //= 2 + + # Save memory on last iteration + l = x.shape[-1] + if L - l <= l: + done = True + _x = x[..., : L - l] + else: + _x = x + + _x = A_ @ _x + x = torch.cat( + [x, _x], dim=-1 + ) # there might be a more efficient way of ordering axes + if not done: + A_ = A_ @ A_ + + assert x.shape[-1] == L + + if c is not None: + x = torch.einsum("...nl, ...n -> ...l", x, c) + x = x.contiguous() # WOW!! 
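+    # At this point x holds the Krylov matrix [b, Ab, ..., A^{L-1} b] of shape (..., N, L),
+    # or its contraction with c of shape (..., L) when c was given; the doubling loop above
+    # needs only O(log L) matrix products instead of L sequential ones.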
+ if return_power: + return x, AL + else: + return x + + +@torch.no_grad() +def power(L, A, v=None): + """Compute A^L and the scan sum_i A^i v_i + + A: (..., N, N) + v: (..., N, L) + """ + + I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) + + powers = [A] + l = 1 + while True: + if L % 2 == 1: + I = powers[-1] @ I + L //= 2 + if L == 0: + break + l *= 2 + if v is None: + powers = [powers[-1] @ powers[-1]] + else: + powers.append(powers[-1] @ powers[-1]) + + if v is None: + return I + + # Invariants: + # powers[-1] := A^l + # l := largest po2 at most L + + # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A + # We do this reverse divide-and-conquer for efficiency reasons: + # 1) it involves fewer padding steps for non-po2 L + # 2) it involves more contiguous arrays + + # Take care of edge case for non-po2 arrays + # Note that this initial step is a no-op for the case of power of 2 (l == L) + k = v.size(-1) - l + v_ = powers.pop() @ v[..., l:] + v = v[..., :l] + v[..., :k] = v[..., :k] + v_ + + # Handle reduction for power of 2 + while v.size(-1) > 1: + v = rearrange(v, "... (z l) -> ... z l", z=2) + v = v[..., 0, :] + powers.pop() @ v[..., 1, :] + return I, v.squeeze(-1) + + +def krylov_toeplitz(L, A, b, c=None): + """Specializes to lower triangular Toeplitz matrix A represented by its diagonals + + A : (..., N) + b : (..., N) + c : (..., N) + + Returns + x : (..., N, L) + x[i, l] = A^l @ b[i] + """ + x = b.unsqueeze(0) # (1, ..., N) + A_ = A + while x.shape[0] < L: + xx = causal_convolution(A_, x) + x = torch.cat( + [x, xx], dim=0 + ) # there might be a more efficient way of ordering axes + A_ = causal_convolution(A_, A_) + x = x[:L, ...] # (L, ..., N) + if c is not None: + x = torch.einsum("l...n, ...n -> ...l", x, c) + else: + x = rearrange(x, "l ... n -> ... n l") + x = x.contiguous() + return x + + +def krylov_toeplitz_(L, A, b, c=None): + """Padded version of krylov_toeplitz that saves some fft's + + TODO currently not faster than original version, not sure why + """ + N = A.shape[-1] + + x = b.unsqueeze(0) # (1, ..., N) + x = F.pad(x, (0, N)) + A = F.pad(A, (0, N)) + done = L == 1 + while not done: + l = x.shape[0] + # Save memory on last iteration + if L - l <= l: + done = True + _x = x[: L - l] + else: + _x = x + Af = torch.fft.rfft(A, n=2 * N, dim=-1) + xf = torch.fft.rfft(_x, n=2 * N, dim=-1) + xf_ = Af * xf + x_ = torch.fft.irfft(xf_, n=2 * N, dim=-1) + x_[..., N:] = 0 + x = torch.cat( + [x, x_], dim=0 + ) # there might be a more efficient way of ordering axes + if not done: + A = torch.fft.irfft(Af * Af, n=2 * N, dim=-1) + A[..., N:] = 0 + x = x[:L, ..., :N] # (L, ..., N) + if c is not None: + x = torch.einsum("l...n, ...n -> ...l", x, c) + else: + x = rearrange(x, "l ... n -> ... 
n l") + x = x.contiguous() + return x diff --git a/src/clm/module_library/pool.py b/src/clm/module_library/pool.py new file mode 100644 index 00000000..f6f2bdf4 --- /dev/null +++ b/src/clm/module_library/pool.py @@ -0,0 +1,62 @@ +import torch.nn.functional as F +from einops import rearrange, reduce + +from .sequence_module import SequenceModule +from .util_modules import LinearActivation + + +class DownAvgPool(SequenceModule): + def __init__(self, d_input, stride=1, expand=None, transposed=True): + super().__init__() + self.d_input = d_input + self.stride = stride + self.expand = expand + self.transposed = transposed + + if self.expand is not None: + self.linear = LinearActivation( + d_input, + d_input * expand, + transposed=transposed, + ) + + def forward(self, x): + if not self.transposed: + x = rearrange(x, "b ... d -> b d ...") + + if self.stride > 1: + # einops appears slower than F + if x.ndim == 3: + x = F.avg_pool1d(x, self.stride, self.stride) + elif x.ndim == 4: + x = F.avg_pool2d(x, self.stride, self.stride) + else: + # Reduction string e.g. "b d (l1 2) (l2 2) -> b d l1 l2" + reduce_str = ( + "b d " + + " ".join([f"(l{i} {self.stride})" for i in range(x.ndim - 2)]) + + " -> b d " + + " ".join([f"l{i}" for i in range(x.ndim - 2)]) + ) + x = reduce(x, reduce_str, "mean") + + # if self.expand > 1: + # x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand) + + if not self.transposed: + x = rearrange(x, "b d ... -> b ... d") + if self.expand is not None: + x = self.linear(x) + return x, None + + def step(self, x, state, **kwargs): + if self.stride > 1 or self.expand > 1: + raise NotImplementedError + return x, state + + @property + def d_output(self): + if self.expand is None: + return self.d_input + else: + return self.d_input * self.expand diff --git a/src/clm/module_library/residual.py b/src/clm/module_library/residual.py new file mode 100644 index 00000000..50513e25 --- /dev/null +++ b/src/clm/module_library/residual.py @@ -0,0 +1,23 @@ +from torch import nn + + +class Residual(nn.Module): + """Residual connection with constant affine weights. 
Can simulate standard residual, no residual, and "constant gates".""" + + def __init__(self, i_layer, d_input, d_model, alpha=1.0, beta=1.0): + # print("ConstantResidual extra kwargs", kwargs) + super().__init__() + assert (d_input == d_model) or alpha == 0.0 + self.i_layer = i_layer + self.d_input = d_input + self.d_model = d_model + self.alpha = alpha + self.beta = beta + + @property + def d_output(self): + return self.d_model + + def forward(self, x, y, transposed): # TODO documentation of transposed + y = self.beta * y if self.beta != 1.0 else y + return self.alpha * x + y if self.alpha else y diff --git a/src/clm/module_library/s4.py b/src/clm/module_library/s4.py new file mode 100644 index 00000000..cb462e26 --- /dev/null +++ b/src/clm/module_library/s4.py @@ -0,0 +1,290 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import opt_einsum as oe +from einops import rearrange + +optimized = True + +if optimized: + contract = oe.contract +else: + contract = torch.einsum + +from .kernel import SSKernel +from .util_modules import LinearActivation, Activation, DropoutNd + + +class S4(nn.Module): + def __init__( + self, + d_model, + d_state=64, + l_max=None, + channels=1, + bidirectional=False, + # Arguments for position-wise feedforward components + activation="gelu", + postact="glu", + initializer=None, + weight_norm=False, + hyper_act=None, + dropout=0.0, + tie_dropout=False, + bottleneck=None, + gate=None, + transposed=True, + verbose=False, + shift=False, + linear=False, + # SSM Kernel arguments + **kernel_args, + ): + """ + d_state: the dimension of the state, also denoted by N + l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel + channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models + bidirectional: if True, convolution kernel will be two-sided + + Position-wise feedforward components: + -------------------- + activation: activation in between SS and FF + postact: activation after FF + initializer: initializer on FF + weight_norm: weight normalization on FF + hyper_act: use a "hypernetwork" multiplication (experimental) + dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d + + Other arguments: + -------------------- + transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension] + gate: add gated activation (GSS) + bottleneck: reduce SSM dimension (GSS) + shift: experimental option, shouldn't affect results + linear: Remove pointwise components so that the entire module is a linear SSM + + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. 
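+
+        A minimal usage sketch (sizes are illustrative, not prescriptive, and assume
+        the default transposed=True axis ordering):
+            layer = S4(d_model=128, d_state=64, l_max=None)
+            x = torch.randn(8, 128, 1000)  # (B, H, L) because transposed=True
+            y, _ = layer(x)                # y: (8, 128, 1000); the second output is the optional state (None here)
+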
Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + + super().__init__() + + self.d_model = d_model + self.H = d_model + self.N = d_state + self.L = l_max + self.bidirectional = bidirectional + self.channels = channels + self.transposed = transposed + self.shift = shift + self.linear = linear + + self.gate = gate + self.bottleneck = bottleneck + + if bottleneck is not None: + self.H = self.H // bottleneck + self.input_linear = LinearActivation( + self.d_model, + self.H, + transposed=self.transposed, + initializer=initializer, + activation=activation, + activate=True, + weight_norm=weight_norm, + ) + + if gate is not None: + self.input_gate = LinearActivation( + self.d_model, + self.d_model * gate, + transposed=self.transposed, + initializer=initializer, + activation=activation, + activate=True, + weight_norm=weight_norm, + ) + self.output_gate = LinearActivation( + self.d_model * gate, + self.d_model, + transposed=self.transposed, + initializer=initializer, + activation=None, + activate=False, + weight_norm=weight_norm, + ) + + # optional multiplicative modulation GLU-style + # https://arxiv.org/abs/2002.05202 + self.hyper = hyper_act is not None + if self.hyper: + channels *= 2 + self.hyper_activation = Activation(hyper_act) + + self.D = nn.Parameter(torch.randn(channels, self.H)) + + if self.bidirectional: + channels *= 2 + + # SSM Kernel + self.kernel = SSKernel( + self.H, + N=self.N, + L=self.L, + channels=channels, + verbose=verbose, + **kernel_args, + ) + + # Pointwise + if not self.linear: + self.activation = Activation(activation) + # dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout # Broken in torch==1.11 + dropout_fn = DropoutNd if tie_dropout else nn.Dropout + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + # position-wise output transform to mix features + if not self.linear: + self.output_linear = LinearActivation( + self.H * self.channels, + self.d_model * (1 if self.gate is None else self.gate), + transposed=self.transposed, + initializer=initializer, + activation=postact, + activate=True, + weight_norm=weight_norm, + ) + + def forward( + self, u, state=None, rate=1.0, lengths=None, **kwargs + ): # absorbs return_output and transformer src mask + """ + u: (B H L) if self.transposed else (B L H) + state: (H N) never needed unless you know what you're doing + + Returns: same shape as u + """ + if not self.transposed: + u = u.transpose(-1, -2) + + L = u.size(-1) + # Mask out padding tokens + # TODO handle option for mask - instead of lengths, which assumes suffix padding + if isinstance(lengths, int): + if lengths != L: + lengths = torch.tensor(lengths, dtype=torch.long, device=u.device) + else: + lengths = None + if lengths is not None: + assert ( + isinstance(lengths, torch.Tensor) + and lengths.ndim == 1 + and lengths.size(0) in [1, u.size(0)] + ) + mask = torch.where( + torch.arange(L, device=lengths.device) < lengths[:, None, None], + 1.0, + 0.0, + ) + u = u * mask + + if self.gate is not None: + v = self.input_gate(u) + if self.bottleneck is not None: + u = self.input_linear(u) + + # Compute SS Kernel + L_kernel = L if self.L is None else min(L, round(self.L / rate)) + k, k_state = self.kernel( + L=L_kernel, rate=rate, state=state + ) # (C H L) (B C H L) + # Convolution + if self.bidirectional: + k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2) + k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0)) 
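+            # k0 (right-padded) and the reversed k1 (left-padded) are summed into one
+            # two-sided filter; the FFT convolution below then applies k0 to past
+            # positions and k1 to future positions of the input.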
+ if self.shift: + # Try flip and pad to correct for potential off-by-one + k_f = torch.fft.rfft(F.pad(k.flip(-1), (L, 0)), n=2 * L) # (C H L) + u_f = torch.fft.rfft(F.pad(u.flip(-1), (L, 0)), n=2 * L) # (B H L) + y_f = contract( + "bhl,chl->bchl", u_f, k_f + ) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) + y = torch.fft.irfft(y_f, n=L_kernel + L)[..., L:].flip(-1) # (B C H L) + else: + k_f = torch.fft.rfft(k, n=L_kernel + L) # (C H L) + u_f = torch.fft.rfft(u, n=L_kernel + L) # (B H L) + y_f = contract("bhl,chl->bchl", u_f, k_f) + y = torch.fft.irfft(y_f, n=L_kernel + L)[..., :L] # (B C H L) + + # Compute D term in state space equation - essentially a skip connection + y = y + contract("bhl,ch->bchl", u, self.D) + + # Compute state update + if state is not None: + assert ( + not self.bidirectional + ), "Bidirectional not supported with state forwarding" + y = y + k_state # + next_state = self.kernel.forward_state(u, state) + else: + next_state = None + + # Optional hyper-network multiplication + if self.hyper: + y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2) + y = self.hyper_activation(yh) * y + + # Reshape to flatten channels + y = rearrange(y, "... c h l -> ... (c h) l") + + if not self.linear: + y = self.dropout(self.activation(y)) + + if not self.transposed: + y = y.transpose(-1, -2) + + if not self.linear: + y = self.output_linear(y) + + if self.gate is not None: + y = self.output_gate(y * v) + return y, next_state + + def setup_step(self, **kwargs): + self.kernel._setup_step(**kwargs) + + def step(self, u, state): + """Step one time step as a recurrent model. Intended to be used during validation. + + u: (B H) + state: (B H N) + Returns: output (B H), state (B H N) + """ + assert not self.training + # u = u.squeeze(1) # (B H) + y, next_state = self.kernel.step(u, state) # (B C H) + y = y + u.unsqueeze(-2) * self.D + y = rearrange(y, "b c h -> b (c h)") + y = self.activation(y) + if self.transposed: + y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) + else: + y = self.output_linear(y) + return y, next_state + + def default_state(self, *batch_shape, device=None): + # kernel is not a SequenceModule so it doesn't need to adhere to same interface + # the kernel will know the device of its own parameters + return self.kernel.default_state(*batch_shape) + + @property + def d_state(self): + return self.H * self.N + + @property + def d_output(self): + return self.d_model + + @property + def state_to_tensor(self): + return lambda state: rearrange("... h n -> ... 
(h n)", state) diff --git a/src/clm/module_library/sequence_model.py b/src/clm/module_library/sequence_model.py new file mode 100644 index 00000000..03949f98 --- /dev/null +++ b/src/clm/module_library/sequence_model.py @@ -0,0 +1,204 @@ +from functools import partial +from typing import Sequence, Mapping +import torch +import torch.nn as nn +from einops import rearrange + +from .sequence_residual_block import SequenceResidualBlock +from .sequence_module import SequenceModule +from .util_modules import Normalization, DropoutNd + + +def is_list(x): + return isinstance(x, Sequence) and not isinstance(x, str) + + +def is_dict(x): + return isinstance(x, Mapping) + + +def to_dict(x, recursive=True): + """Convert Sequence or Mapping object to dict + + lists get converted to {0: x[0], 1: x[1], ...} + """ + if is_list(x): + x = {i: v for i, v in enumerate(x)} + if is_dict(x): + if recursive: + return {k: to_dict(v, recursive=recursive) for k, v in x.items()} + else: + return dict(x) + else: + return x + + +def to_list(x, recursive=False): + """Convert an object to list. + + If Sequence (e.g. list, tuple, Listconfig): just return it + + Special case: If non-recursive and not a list, wrap in list + """ + if is_list(x): + if recursive: + return [to_list(_x) for _x in x] + else: + return list(x) + else: + if recursive: + return x + else: + return [x] + + +class SequenceModel(SequenceModule): + def __init__( + self, + d_model, # Resize input (useful for deep models with residuals) + n_layers=1, # Number of layers + transposed=False, # Transpose inputs so each layer receives (batch, dim, length) + dropout=0.0, # Dropout parameter applied on every residual and every layer + tie_dropout=False, # Tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d + prenorm=True, # Pre-norm vs. post-norm + n_repeat=1, # Each layer is repeated n times per stage before applying pooling + layer=None, # Layer config, must be specified + # residual=None, # Residual config + residual="R", # Residual config # changed the default value from None to "R" + # norm=None, # Normalization config (e.g. layer vs batch) + norm="layer", # Normalization config (e.g. 
layer vs batch) # changed the default value from None to "layer" + pool=None, # Config for pooling layer per stage + # track_norms=True, # Log norms of each layer output; changed the default value from True to False + track_norms=False, # Log norms of each layer output; changed the default value from True to False + dropinp=0.0, # Input dropout + ): + super().__init__() + # Save arguments needed for forward pass + self.d_model = d_model + self.transposed = transposed + self.track_norms = track_norms + + # Input dropout (not really used) + dropout_fn = ( + partial(DropoutNd, transposed=self.transposed) + if tie_dropout + else nn.Dropout + ) + self.drop = dropout_fn(dropinp) if dropinp > 0.0 else nn.Identity() + + layer = to_list(layer, recursive=False) + + # Some special arguments are passed into each layer + for _layer in layer: + # If layers don't specify dropout, add it + if _layer.get("dropout", None) is None: + _layer["dropout"] = dropout + # Ensure all layers are shaped the same way + _layer["transposed"] = transposed + + # Duplicate layers + layers = layer * n_layers * n_repeat + + # Instantiate layers + _layers = [] + d = d_model + for l, layer in enumerate(layers): + # Pool at the end of every n_repeat blocks + pool_cfg = pool if (l + 1) % n_repeat == 0 else None + block = SequenceResidualBlock( + d, + l + 1, + prenorm=prenorm, + dropout=dropout, + tie_dropout=tie_dropout, + transposed=transposed, + layer_config=layer, + residual=residual, + norm=norm, + pool=pool_cfg, + ) + _layers.append(block) + d = block.d_output + + self.d_output = d + self.layers = nn.ModuleList(_layers) + if prenorm: + if norm is None: + self.norm = None + elif isinstance(norm, str): + self.norm = Normalization( + self.d_output, transposed=self.transposed, _name_=norm + ) + else: + self.norm = Normalization( + self.d_output, transposed=self.transposed, **norm + ) + else: + self.norm = nn.Identity() + + def forward(self, inputs, *args, state=None, **kwargs): + """Inputs assumed to be (batch, sequence, dim)""" + if self.transposed: + inputs = rearrange(inputs, "b ... d -> b d ...") + inputs = self.drop(inputs) + + # Track norms + if self.track_norms: + output_norms = [torch.mean(inputs.detach() ** 2)] + + # Apply layers + outputs = inputs + prev_states = [None] * len(self.layers) if state is None else state + next_states = [] + for layer, prev_state in zip(self.layers, prev_states): + outputs, state = layer(outputs, *args, state=prev_state, **kwargs) + next_states.append(state) + if self.track_norms: + output_norms.append(torch.mean(outputs.detach() ** 2)) + if self.norm is not None: + outputs = self.norm(outputs) + + if self.transposed: + outputs = rearrange(outputs, "b d ... -> b ... 
d") + + if self.track_norms: + metrics = to_dict(output_norms, recursive=False) + self.metrics = {f"norm/{i}": v for i, v in metrics.items()} + + return outputs, next_states + + @property + def d_state(self): + d_states = [layer.d_state for layer in self.layers] + return sum([d for d in d_states if d is not None]) + + @property + def state_to_tensor(self): + # Slightly hacky way to implement this in a curried manner (so that the function can be extracted from an instance) + # Somewhat more sound may be to turn this into a @staticmethod and grab subclasses using hydra.utils.get_class + def fn(state): + x = [ + _layer.state_to_tensor(_state) + for (_layer, _state) in zip(self.layers, state) + ] + x = [_x for _x in x if _x is not None] + return torch.cat(x, dim=-1) + + return fn + + def default_state(self, *batch_shape, device=None): + return [ + layer.default_state(*batch_shape, device=device) for layer in self.layers + ] + + def step(self, x, state, **kwargs): + prev_states = [None] * len(self.layers) if state is None else state + next_states = [] + layer_idx = 0 + for layer, prev_state in zip(self.layers, prev_states): + x, state = layer.step(x, state=prev_state, **kwargs) + next_states.append(state) + layer_idx += 1 + + x = self.norm(x) + return x, next_states diff --git a/src/clm/module_library/sequence_module.py b/src/clm/module_library/sequence_module.py new file mode 100644 index 00000000..4f8a4ffa --- /dev/null +++ b/src/clm/module_library/sequence_module.py @@ -0,0 +1,131 @@ +from torch import nn +import functools + +class SequenceModule(nn.Module): + """Abstract sequence model class. All models must adhere to this interface + + A SequenceModule is generally a model that transforms an input of shape + (n_batch, l_sequence, d_model) to (n_batch, l_sequence, d_output) + + REQUIRED methods and attributes + forward, d_model, d_output: controls standard forward pass, a sequence-to-sequence transformation + __init__ should also satisfy the following interface; see SequenceIdentity for an example + def __init__(self, d_model, transposed=False, **kwargs) + + OPTIONAL methods + default_state, step: allows stepping the model recurrently with a hidden state + state_to_tensor, d_state: allows decoding from hidden state + """ + + @property + def d_model(self): + """Model dimension (generally same as input dimension). + + This attribute is required for all SequenceModule instantiations. + It is used by the rest of the pipeline (e.g. model backbone, encoder) to track the internal shapes of the full model. + """ + if getattr(self, "_d_model", None) is None: + raise NotImplementedError("SequenceModule instantiation must set d_model") + return self._d_model + + @d_model.setter + def d_model(self, d): + self._d_model = d + + @property + def d_output(self): + """Output dimension of model. + + This attribute is required for all SequenceModule instantiations. + It is used by the rest of the pipeline (e.g. model backbone, decoder) to track the internal shapes of the full model. + """ + if getattr(self, "_d_output", None) is None: + raise NotImplementedError("SequenceModule instantiation must specify d_output for decoder") + return self._d_output + + @d_output.setter + def d_output(self, d): + self._d_output = d + + def forward(self, x, state=None, **kwargs): + """Forward pass of sequence model, a sequence-to-sequence transformation with an optional state. 
+ + Generally, this should map a tensor of shape (batch, length, self.d_model) to (batch, length, self.d_output) + + Additionally, it returns a "state" which can be any additional information + For example, RNN and SSM layers may return their hidden state, + while some types of transformer layers (e.g. Transformer-XL) may want to pass a state as well + """ + return x, None + + @property + def state_to_tensor(self): + """Returns a function mapping a state to a single tensor. + + This method should be implemented if one wants to use the hidden state instead of the output sequence for final prediction. + Currently only used with the StateDecoder. + """ + return lambda _: None + + @property + def d_state(self): + """ Returns dimension of output of self.state_to_tensor """ + return None + + + def default_state(self, *batch_shape, device=None): + """Create initial state for a batch of inputs.""" + + return None + + def step(self, x, state=None, **kwargs): + """Step the model recurrently for one step of the input sequence. + + For example, this should correspond to unrolling an RNN for one step. + If the forward pass has signature (B, L, H1) -> (B, L, H2), + this method should generally have signature (B, H1) -> (B, H2) with an optional recurrent state. + """ + raise NotImplementedError + +def TransposedModule(module): + """Wrap a SequenceModule class to accept transposed parameter, handle state, absorb kwargs""" + # https://stackoverflow.com/a/65470430/1980685 + @functools.wraps(module, updated=()) + class TransposedModule(module): + def __init__(self, *args, transposed=False, **kwargs): + super().__init__(*args, **kwargs) + self.transposed = transposed + + def forward(self, x, state=None, **kwargs): + if self.transposed: x = x.transpose(-1, -2) + x, next_state = super().forward(x, state) # Don't use kwarg because nn.LSTM + next_state = None if state is None else next_state + if self.transposed: x = x.transpose(-1,-2) + return x, next_state + # https://stackoverflow.com/questions/5352781/how-to-set-class-names-dynamically + # TransposedModule.__name__ = module.__name__ # functools wraps is better solution + return TransposedModule + +@TransposedModule +class SequenceIdentity(SequenceModule): + """Simple SequenceModule for testing purposes""" + + def __init__(self, d_model, dropout=0.0, **kwargs): + """Default interface for SequenceModule + + d_model: input dimension (sometimes denoted H for hidden dimension) + transposed: if True, inputs have axis ordering (B, H, L) instead of (B, H, L) + """ + super().__init__() + self.d_model = d_model + self.d_output = d_model + + + def forward(self, x, state=None): + return x, state + + def default_state(self, *batch_shape, device=None): + return None + + def step(self, x, state=None, **kwargs): + return x, state diff --git a/src/clm/module_library/sequence_residual_block.py b/src/clm/module_library/sequence_residual_block.py new file mode 100644 index 00000000..ba5551a9 --- /dev/null +++ b/src/clm/module_library/sequence_residual_block.py @@ -0,0 +1,148 @@ +from torch import nn + +from functools import partial + +from .util_modules import Normalization, StochasticDepth, DropoutNd +from .sequence_module import SequenceModule +from .s4 import S4 +from .ff import FF +from .pool import DownAvgPool +from .residual import Residual + + +class SequenceResidualBlock(SequenceModule): + def __init__( + self, + d_input, + i_layer=None, # Only needs to be passed into certain residuals like Decay + prenorm=True, + dropout=0.0, + tie_dropout=False, + transposed=False, + 
layer_config=None, # Config for black box module + residual=None, # Config for residual function + norm=None, # Config for normalization layer + pool=None, + drop_path=0.0, + ): + super().__init__() + + self.i_layer = i_layer + self.d_input = d_input + # self.layer = instantiate(registry.layer, layer, d_input) + # layer_config = layer.copy() + # layer_cls = registry.get_layer(layer["_name_"]) + layer_config = layer_config.copy() + if layer_config["_name_"] == "s4": + layer_cls = S4 + elif layer_config["_name_"] == "ff": + layer_cls = FF + layer_config.pop("_name_") + self.layer = layer_cls(d_input, **layer_config) + + self.prenorm = prenorm + self.transposed = transposed + + # Residual + # d_residual is the output dimension after residual + if residual is None: + self.residual = None + self.d_residual = self.layer.d_output + else: + # self.residual = instantiate( + # residual_registry, residual, i_layer, d_input, self.layer.d_output + # ) + self.residual = Residual(i_layer, d_input, self.layer.d_output) + # instantiate( + # residual_registry, residual, i_layer, d_input, self.layer.d_output + # ) + self.d_residual = self.residual.d_output + + # Normalization + d_norm = d_input if self.prenorm else self.d_residual + # We don't use config to directly instantiate since Normalization has some special cases + if norm is None: + self.norm = None + elif isinstance(norm, str): + self.norm = Normalization(d_norm, transposed=self.transposed, _name_=norm) + else: + self.norm = Normalization(d_norm, transposed=self.transposed, **norm) + + # Pool + if pool is not None: + self.pool = DownAvgPool(self.d_residual, transposed=self.transposed) + + # Dropout + dropout_cls = ( + partial(DropoutNd, transposed=self.transposed) + if tie_dropout + else nn.Dropout + ) + self.drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() + + # Stochastic depth + self.drop_path = ( + StochasticDepth(drop_path, mode="row") if drop_path > 0.0 else nn.Identity() + ) + + @property + def d_output(self): + return self.pool.d_output if self.pool is not None else self.d_residual + + @property + def d_state(self): + return self.layer.d_state + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def default_state(self, *args, **kwargs): + return self.layer.default_state(*args, **kwargs) + + def forward(self, x, state=None, **kwargs): + y = x + + # Pre-norm + if self.norm is not None and self.prenorm: + y = self.norm(y) + + # Black box layer + y, state = self.layer(y, state=state, **kwargs) + + # Residual + if self.residual is not None: + y = self.residual(x, self.drop_path(self.drop(y)), self.transposed) + # Post-norm + if self.norm is not None and not self.prenorm: + y = self.norm(y) + + # Pool + if self.pool is not None: + y, _ = self.pool(y) + + return y, state + + def step(self, x, state, **kwargs): + y = x + + # Pre-norm + if self.norm is not None and self.prenorm: + y = self.norm.step(y) + + # Black box layer + y, state = self.layer.step(y, state, **kwargs) + # Residual + if self.residual is not None: + y = self.residual( + x, y, transposed=self.transposed + ) # NOTE this would not work with concat residual function (catformer) + # Post-norm + if self.norm is not None and not self.prenorm: + y = self.norm.step(y) + + # Pool + if self.pool is not None: + y, _ = self.pool(y) + + return y, state diff --git a/src/clm/module_library/toeplitz.py b/src/clm/module_library/toeplitz.py new file mode 100644 index 00000000..5e382442 --- /dev/null +++ b/src/clm/module_library/toeplitz.py @@ -0,0 +1,156 
@@ +import torch +import torch.nn.functional as F + + +def construct_toeplitz(v, f=0.0): + """Explicit construction of Krylov matrix [v A @ v A^2 @ v ... A^{n-1} @ v] + where A = Z_f. This uses vectorized indexing and cumprod so it's much + faster than using the Krylov function. + Parameters: + v: the starting vector of size n or (rank, n). + f: real number + Returns: + K: Krylov matrix of size (n, n) or (rank, n, n). + """ + n = v.shape[-1] + a = torch.arange(n, device=v.device) + b = -a + indices = a[:, None] + b[None] + K = v[..., indices] + K[..., indices < 0] *= f + return K + + +def triangular_toeplitz_multiply_(u, v, sum=None): + n = u.shape[-1] + u_expand = F.pad(u, (0, n)) + v_expand = F.pad(v, (0, n)) + u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1) + v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1) + uv_f = u_f * v_f + if sum is not None: + uv_f = uv_f.sum(dim=sum) + output = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n] + return output + + +def triangular_toeplitz_multiply_padded_(u, v): + """Same as triangular_toeplitz_multiply but inputs and output assume to be 0-padded already.""" + n = u.shape[-1] + assert n % 2 == 0 + u_f = torch.fft.rfft(u, n=n, dim=-1) + v_f = torch.fft.rfft(v, n=n, dim=-1) + uv_f = u_f * v_f + output = torch.fft.irfft(uv_f, n=n, dim=-1) + output[..., n:] = 0 + return output + + +class TriangularToeplitzMult(torch.autograd.Function): + @staticmethod + def forward(ctx, u, v): + ctx.save_for_backward(u, v) + return triangular_toeplitz_multiply_(u, v) + + @staticmethod + def backward(ctx, grad): + u, v = ctx.saved_tensors + d_u = triangular_toeplitz_multiply_(grad.flip(-1), v).flip(-1) + d_v = triangular_toeplitz_multiply_(grad.flip(-1), u).flip(-1) + return d_u, d_v + + +class TriangularToeplitzMultFast(torch.autograd.Function): + @staticmethod + def forward(ctx, u, v): + n = u.shape[-1] + u_expand = F.pad(u, (0, n)) + v_expand = F.pad(v, (0, n)) + u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1) + v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1) + + ctx.save_for_backward(u_f, v_f) + + uv_f = u_f * v_f + output = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n] + return output + + @staticmethod + def backward(ctx, grad): + u_f, v_f = ctx.saved_tensors + n = grad.shape[-1] + g_expand = F.pad(grad.flip(-1), (0, n)) + g_f = torch.fft.rfft(g_expand, n=2 * n, dim=-1) + gu_f = g_f * u_f + gv_f = g_f * v_f + d_u = torch.fft.irfft(gv_f, n=2 * n, dim=-1)[..., :n] + d_v = torch.fft.irfft(gu_f, n=2 * n, dim=-1)[..., :n] + d_u = d_u.flip(-1) + d_v = d_v.flip(-1) + return d_u, d_v + + +class TriangularToeplitzMultPadded(torch.autograd.Function): + @staticmethod + def forward(ctx, u, v): + ctx.save_for_backward(u, v) + output = triangular_toeplitz_multiply_(u, v) + return output + + @staticmethod + def backward(ctx, grad): + u, v = ctx.saved_tensors + d_u = triangular_toeplitz_multiply_padded_(grad.flip(-1), v).flip(-1) + d_v = triangular_toeplitz_multiply_padded_(grad.flip(-1), u).flip(-1) + return d_u, d_v + + +class TriangularToeplitzMultPaddedFast(torch.autograd.Function): + """Trade off speed (20-25% faster) for more memory (20-25%)""" + + @staticmethod + def forward(ctx, u, v): + n = u.shape[-1] + u_f = torch.fft.rfft(u, n=n, dim=-1) + v_f = torch.fft.rfft(v, n=n, dim=-1) + + ctx.save_for_backward(u_f, v_f) + + uv_f = u_f * v_f + output = torch.fft.irfft(uv_f, n=n, dim=-1) + output[..., n // 2 :].zero_() + return output + + @staticmethod + def backward(ctx, grad): + u_f, v_f = ctx.saved_tensors + n = grad.shape[-1] + g_expand = F.pad(grad[..., : n // 2].flip(-1), 
(0, n // 2)) + g_f = torch.fft.rfft(g_expand, n=n, dim=-1) + gu_f = g_f * u_f + gv_f = g_f * v_f + d_u = torch.fft.irfft(gv_f, n=n, dim=-1) + d_v = torch.fft.irfft(gu_f, n=n, dim=-1) + d_u[..., n // 2 :].zero_() + d_v[..., n // 2 :].zero_() + d_u[..., : n // 2] = d_u[..., : n // 2].flip(-1) # TODO + d_v[..., : n // 2] = d_v[..., : n // 2].flip(-1) # TODO + return d_u, d_v + + +# triangular_toeplitz_multiply = triangular_toeplitz_multiply_ +triangular_toeplitz_multiply = TriangularToeplitzMult.apply +triangular_toeplitz_multiply_fast = TriangularToeplitzMultFast.apply +triangular_toeplitz_multiply_padded = TriangularToeplitzMultPadded.apply +triangular_toeplitz_multiply_padded_fast = TriangularToeplitzMultPaddedFast.apply + + +def causal_convolution(u, v, fast=True, pad=False): + if not pad and not fast: + return triangular_toeplitz_multiply(u, v) + if not pad and fast: + return triangular_toeplitz_multiply_fast(u, v) + if pad and not fast: + return triangular_toeplitz_multiply_padded(u, v) + if pad and fast: + return triangular_toeplitz_multiply_padded_fast(u, v) diff --git a/src/clm/module_library/util_modules.py b/src/clm/module_library/util_modules.py new file mode 100644 index 00000000..9f4666f3 --- /dev/null +++ b/src/clm/module_library/util_modules.py @@ -0,0 +1,318 @@ +import math +from functools import partial +import torch +from torch import nn +from einops import rearrange +from opt_einsum import contract + + +def get_initializer(name, activation=None): + if activation in [None, "id", "identity", "linear", "modrelu"]: + nonlinearity = "linear" + elif activation in ["relu", "tanh", "sigmoid"]: + nonlinearity = activation + elif activation in ["gelu", "swish", "silu"]: + nonlinearity = "relu" # Close to ReLU so approximate with ReLU's gain + else: + raise NotImplementedError( + f"get_initializer: activation {activation} not supported" + ) + + if name == "uniform": + initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity) + elif name == "normal": + initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity) + elif name == "xavier": + initializer = torch.nn.init.xavier_normal_ + elif name == "zero": + initializer = partial(torch.nn.init.constant_, val=0) + elif name == "one": + initializer = partial(torch.nn.init.constant_, val=1) + else: + raise NotImplementedError( + f"get_initializer: initializer type {name} not supported" + ) + + return initializer + + +def Activation(activation=None, size=None, dim=-1): + if activation in [None, "id", "identity", "linear"]: + return nn.Identity() + elif activation == "tanh": + return nn.Tanh() + elif activation == "relu": + return nn.ReLU() + elif activation == "gelu": + return nn.GELU() + elif activation in ["swish", "silu"]: + return nn.SiLU() + elif activation == "glu": + return nn.GLU(dim=dim) + elif activation == "sigmoid": + return nn.Sigmoid() + elif activation == "softplus": + return nn.Softplus() + else: + raise NotImplementedError( + "hidden activation '{}' is not implemented".format(activation) + ) + + +class TransposedLinear(nn.Module): + """Linear module on the second-to-last dimension + Assumes shape (B, D, L), where L can be 1 or more axis + """ + + def __init__(self, d_input, d_output, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.empty(d_output, d_input)) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # nn.Linear default init + # nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent + + if bias: + self.bias = 
nn.Parameter(torch.empty(d_output)) + bound = 1 / math.sqrt(d_input) + nn.init.uniform_(self.bias, -bound, bound) + setattr(self.bias, "_optim", {"weight_decay": 0.0}) + else: + self.bias = 0.0 + + def forward(self, x): + num_axis = len(x.shape[2:]) # num_axis in L, for broadcasting bias + y = contract("b u ..., v u -> b v ...", x, self.weight) + self.bias.view( + -1, *[1] * num_axis + ) + return y + + +def LinearActivation( + d_input, + d_output, + bias=True, + zero_bias_init=False, + transposed=False, + initializer=None, + activation=None, + activate=False, # Apply activation as part of this module + weight_norm=False, + **kwargs, +): + """Returns a linear nn.Module with control over axes order, initialization, and activation""" + + # Construct core module + # linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear + linear_cls = TransposedLinear if transposed else nn.Linear + if activation == "glu": + d_output *= 2 + linear = linear_cls(d_input, d_output, bias=bias, **kwargs) + + # Initialize weight + if initializer is not None: + get_initializer(initializer, activation)(linear.weight) + + # Initialize bias + if bias and zero_bias_init: + nn.init.zeros_(linear.bias) + + # Weight norm + if weight_norm: + linear = nn.utils.weight_norm(linear) + + if activate and activation is not None: + activation = Activation(activation, d_output, dim=1 if transposed else -1) + linear = nn.Sequential(linear, activation) + return linear + + +class DropoutNd(nn.Module): + def __init__(self, p: float = 0.5, tie=True, transposed=True): + """ + tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) + """ + super().__init__() + if p < 0 or p >= 1: + raise ValueError( + "dropout probability has to be in [0, 1), " "but got {}".format(p) + ) + self.p = p + self.tie = tie + self.transposed = transposed + self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p) + + def forward(self, X): + """X: (batch, dim, lengths...)""" + if self.training: + if not self.transposed: + X = rearrange(X, "b d ... -> b ... d") + # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow + mask_shape = X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape + # mask = self.binomial.sample(mask_shape) + mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p + X = X * mask * (1.0 / (1 - self.p)) + if not self.transposed: + X = rearrange(X, "b ... 
d -> b d ...") + return X + return X + + +class Normalization(nn.Module): + def __init__( + self, + d, + transposed=False, # Length dimension is -1 or -2 + _name_="layer", + **kwargs, + ): + super().__init__() + self.transposed = transposed + self._name_ = _name_ + + if _name_ == "layer": + self.channel = True # Normalize over channel dimension + if self.transposed: + self.norm = TransposedLN(d, **kwargs) + else: + self.norm = nn.LayerNorm(d, **kwargs) + elif _name_ == "instance": + self.channel = False + norm_args = {"affine": False, "track_running_stats": False} + norm_args.update(kwargs) + self.norm = nn.InstanceNorm1d( + d, **norm_args + ) # (True, True) performs very poorly + elif _name_ == "batch": + self.channel = False + norm_args = {"affine": True, "track_running_stats": True} + norm_args.update(kwargs) + self.norm = nn.BatchNorm1d(d, **norm_args) + elif _name_ == "group": + self.channel = False + self.norm = nn.GroupNorm(1, d, *kwargs) + elif _name_ == "none": + self.channel = True + self.norm = nn.Identity() + else: + raise NotImplementedError + + def forward(self, x): + # Handle higher dimension logic + shape = x.shape + if self.transposed: + x = rearrange(x, "b d ... -> b d (...)") + else: + x = rearrange(x, "b ... d -> b (...)d ") + + # The cases of LayerNorm / no normalization are automatically handled in all cases + # Instance/Batch Norm work automatically with transposed axes + if self.channel or self.transposed: + x = self.norm(x) + else: + x = x.transpose(-1, -2) + x = self.norm(x) + x = x.transpose(-1, -2) + + x = x.view(shape) + return x + + def step(self, x, **kwargs): + assert self._name_ in ["layer", "none"] + if self.transposed: + x = x.unsqueeze(-1) + x = self.forward(x) + if self.transposed: + x = x.squeeze(-1) + return x + + +class TransposedLN(nn.Module): + """LayerNorm module over second dimension + Assumes shape (B, D, L), where L can be 1 or more axis + + This is slow and a dedicated CUDA/Triton implementation shuld provide substantial end-to-end speedup + """ + + def __init__(self, d, scalar=True): + super().__init__() + self.scalar = scalar + if self.scalar: + self.m = nn.Parameter(torch.zeros(1)) + self.s = nn.Parameter(torch.ones(1)) + setattr(self.m, "_optim", {"weight_decay": 0.0}) + setattr(self.s, "_optim", {"weight_decay": 0.0}) + else: + self.ln = nn.LayerNorm(d) + + def forward(self, x): + if self.scalar: + # calc. stats over D dim / channels + s, m = torch.std_mean(x, dim=1, unbiased=False, keepdim=True) + y = (self.s / s) * (x - m + self.m) + else: + # move channel to last axis, apply layer_norm, then move channel back to second axis + _x = self.ln(rearrange(x, "b d ... -> b ... d")) + y = rearrange(_x, "b ... d -> b d ...") + return y + + +def stochastic_depth(input: torch.tensor, p: float, mode: str, training: bool = True): + """ + Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth" + `_ used for randomly dropping residual + branches of residual architectures. + + Args: + input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one + being its batch i.e. a batch with ``N`` rows. + p (float): probability of the input to be zeroed. + mode (str): ``"batch"`` or ``"row"``. + ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes + randomly selected rows from the batch. + training: apply stochastic depth if is ``True``. Default: ``True`` + + Returns: + Tensor[N, ...]: The randomly zeroed tensor. 
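+
+    Example (illustrative): with p=0.2 and mode="row", each row of the batch is
+    zeroed independently with probability 0.2, and surviving rows are rescaled
+    by 1 / 0.8 so the expected activation is preserved.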
+ """ + if p < 0.0 or p > 1.0: + raise ValueError( + "drop probability has to be between 0 and 1, but got {}".format(p) + ) + if mode not in ["batch", "row"]: + raise ValueError( + "mode has to be either 'batch' or 'row', but got {}".format(mode) + ) + if not training or p == 0.0: + return input + + survival_rate = 1.0 - p + if mode == "row": + size = [input.shape[0]] + [1] * (input.ndim - 1) + else: + size = [1] * input.ndim + noise = torch.empty(size, dtype=input.dtype, device=input.device) + noise = noise.bernoulli_(survival_rate).div_(survival_rate) + return input * noise + + +class StochasticDepth(nn.Module): + """ + See :func:`stochastic_depth`. + """ + + def __init__(self, p: float, mode: str) -> None: + # TODO(karan): need to upgrade to torchvision==0.11.0 to use StochasticDepth directly + # from torchvision.ops import StochasticDepth + super().__init__() + self.p = p + self.mode = mode + + def forward(self, input): + return stochastic_depth(input, self.p, self.mode, self.training) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += "p=" + str(self.p) + tmpstr += ", mode=" + str(self.mode) + tmpstr += ")" + return tmpstr diff --git a/src/clm/src/__init__.py b/src/clm/src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/clm/src/callbacks/norms.py b/src/clm/src/callbacks/norms.py new file mode 100644 index 00000000..a6d8b6c3 --- /dev/null +++ b/src/clm/src/callbacks/norms.py @@ -0,0 +1,39 @@ +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.parsing import AttributeDict +from omegaconf import OmegaConf + +class TrackNorms(pl.Callback): + + # TODO do callbacks happen before or after the method in the main LightningModule? + # @rank_zero_only # needed? + def on_after_training_step(self, batch, batch_idx, trainer: pl.Trainer, pl_module: pl.LightningModule): + # Log extra metrics + metrics = {} + + if hasattr(pl_module, "_grad_norms"): + metrics.update(pl_module._grad_norms) + + self.log_dict( + metrics, + on_step=True, + on_epoch=False, + prog_bar=False, + add_dataloader_idx=False, + sync_dist=True, + ) + + + def on_after_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule): + # example to inspect gradient information in tensorboard + if OmegaConf.select(trainer.hparams, 'trainer.track_grad_norms'): # TODO dot notation should work with omegaconf? 
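+            # Store the mean squared gradient of each parameter under
+            # "grad_norm.<name>"; on_after_training_step above logs these values.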
+ norms = {} + for name, p in pl_module.named_parameters(): + if p.grad is None: + continue + + # param_norm = float(p.grad.data.norm(norm_type)) + param_norm = torch.mean(p.grad.data ** 2) + norms[f"grad_norm.{name}"] = param_norm + pl_module._grad_norms = norms + diff --git a/src/clm/src/callbacks/params.py b/src/clm/src/callbacks/params.py new file mode 100644 index 00000000..f3ddd1ff --- /dev/null +++ b/src/clm/src/callbacks/params.py @@ -0,0 +1,37 @@ +from typing import Any + +import pytorch_lightning as pl +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.parsing import AttributeDict + + +class ParamsLog(pl.Callback): + """ Log the number of parameters of the model """ + def __init__( + self, + total: bool = True, + trainable: bool = True, + fixed: bool = True, + ): + super().__init__() + self._log_stats = AttributeDict( + { + 'total_params_log': total, + 'trainable_params_log': trainable, + 'non_trainable_params_log': fixed, + } + ) + + @rank_zero_only + def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + logs = {} + if self._log_stats.total_params_log: + logs["params/total"] = sum(p.numel() for p in pl_module.parameters()) + if self._log_stats.trainable_params_log: + logs["params/trainable"] = sum(p.numel() for p in pl_module.parameters() + if p.requires_grad) + if self._log_stats.non_trainable_params_log: + logs["params/fixed"] = sum(p.numel() for p in pl_module.parameters() + if not p.requires_grad) + if trainer.logger: + trainer.logger.log_hyperparams(logs) diff --git a/src/clm/src/callbacks/progressive_resizing.py b/src/clm/src/callbacks/progressive_resizing.py new file mode 100644 index 00000000..85638db6 --- /dev/null +++ b/src/clm/src/callbacks/progressive_resizing.py @@ -0,0 +1,118 @@ +import numpy as np +from pytorch_lightning.callbacks import Callback + +import clm.src.utils as utils +from clm.src.utils import registry + + +class ProgressiveResizing(Callback): + + def __init__(self, stage_params: list): + """ + stage_params is a list of dicts + e.g. 
stage_params = [ + {'resolution': 4, 'epochs': 50}, # 32 x 32 + {'resolution': 2, 'epochs': 30}, # 64 x 64 + {'resolution': 1, 'epochs': 20}, # 128 x 128 + ] + """ + super().__init__() + assert len(stage_params) > 0, 'No stages specified' + assert all([{'resolution', 'epochs'} <= set(stage.keys()) for stage in stage_params]), \ + 'stage_params must contain keys: resolution and epochs' + + self.stage_params = stage_params + self.stage_epochs_cume = np.cumsum([stage['epochs'] for stage in stage_params]) + + self._current_stage = 0 + + def _verify_stages(self, trainer, model): + # Double-check that stage parameters are correct, otherwise we'll fail in the middle of training + for stage in self.stage_params: + if hasattr(stage, 'scheduler'): + # Verify that we can actually create the scheduler when we need to update it in each stage + scheduler = utils.instantiate(registry.scheduler, {**model.hparams.scheduler, **stage['scheduler']}, trainer.optimizers[0]) + del scheduler + + def on_train_start(self, trainer, model) -> None: + # Verify all the stage parameters are correct + self._verify_stages(trainer, model) + + print(f"Training starts at {trainer.current_epoch}") + if trainer.current_epoch == 0: + # Update the model to the first stage + self._update_to_current_stage(trainer, model) + else: + # Preemption or resumption of progressive resizing + # Update the stage to the current one + self._current_stage = int(np.searchsorted(self.stage_epochs_cume - 1, trainer.current_epoch)) + self._starting_stage = np.any(trainer.current_epoch == self.stage_epochs_cume) + + print("Progressive Resizing: Restarting at Stage {}".format(self._current_stage)) + if self._starting_stage: + self._update_lr_scheduler(trainer, model) + + # Set the dataloader and model + self._update_dataloaders(trainer, model) + self._update_model(trainer, model) + + return super().on_train_start(trainer, model) + + def _update_lr_scheduler(self, trainer, model): + if not hasattr(self.stage_params[self._current_stage], 'scheduler'): + # No scheduler specified, so don't update the current scheduler + return + + assert len(trainer.lr_schedulers) == 1 + # Reinitialize the scheduler + # We don't need to carry over information from the last scheduler e.g. 
the last_epoch property, + # because that will mess with the new scheduler when we step it + hparams = {**model.hparams.scheduler, **self.stage_params[self._current_stage]['scheduler']} + + # Note that passing in the optimizer below is okay: the scheduler will be reinitialized and doesn't seem to inherit any current lr info from the optimizer + trainer.lr_schedulers[0]['scheduler'] = utils.instantiate(registry.scheduler, hparams, trainer.optimizers[0]) + + print("\tChanged scheduler to {}".format(hparams)) + + def _update_dataloaders(self, trainer, model): + # Set the train resolution and reset the dataloader + model.hparams.loader.train_resolution = self.stage_params[self._current_stage]['resolution'] + trainer.reset_train_dataloader(model) + + print('\tChanged resolution to {}'.format(self.stage_params[self._current_stage]['resolution'])) + + def _update_model(self, trainer, model): + if not hasattr(self.stage_params[self._current_stage], 'bandlimit'): + return + + # Update the bandlimit value for the model: this is a hack to make sure the model is updated + # Iterate over all the modules + for module in model.modules(): + if hasattr(module, 'bandlimit'): + module.bandlimit = self.stage_params[self._current_stage]['bandlimit'] + + print('\tChanged bandlimit to {}'.format(self.stage_params[self._current_stage]['bandlimit'])) + + def _update_to_current_stage(self, trainer, model): + print("Progressive Resizing: Moving to Stage {}".format(self._current_stage)) + # Update the train dataloader, model and scheduler + self._update_dataloaders(trainer, model) + self._update_model(trainer, model) + self._update_lr_scheduler(trainer, model) + + + def on_train_epoch_end(self, trainer, model): + """ + Check to see if new stage is reached for the next epoch, and if so, prepare the new stage by + changing the dataloader. + + (We do next epoch so that the dataloader is prepared before the next epoch) + """ + next_epoch = trainer.current_epoch + 1 + + # Check if stage should be increased + if next_epoch >= self.stage_epochs_cume[self._current_stage] and self._current_stage < len(self.stage_params) - 1: + self._current_stage += 1 + self._update_to_current_stage(trainer, model) + + return super().on_train_epoch_end(trainer, model) diff --git a/src/clm/src/callbacks/timer.py b/src/clm/src/callbacks/timer.py new file mode 100644 index 00000000..abe6c66c --- /dev/null +++ b/src/clm/src/callbacks/timer.py @@ -0,0 +1,100 @@ +### https://github.com/HazyResearch/transformers/blob/master/src/callbacks/speed_monitor.py + +# Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor +# We only need the speed monitoring, not the GPU monitoring +import time +from typing import Any + +from pytorch_lightning import Callback, Trainer, LightningModule +from pytorch_lightning.utilities import rank_zero_only +from pytorch_lightning.utilities.parsing import AttributeDict +from pytorch_lightning.utilities.types import STEP_OUTPUT + + +class Timer(Callback): + """Monitor the speed of each step and each epoch. 
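+
+    Logged metrics (in seconds): timer/step, timer/inter_step, timer/epoch
+    and timer/validation.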
+ """ + def __init__( + self, + step: bool = True, + inter_step: bool = True, + epoch: bool = True, + val: bool = True, + ): + super().__init__() + self._log_stats = AttributeDict( { + 'step_time': step, + 'inter_step_time': inter_step, + 'epoch_time': epoch, + 'val_time': val, + }) + + def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self._snap_epoch_time = None + + def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self._snap_step_time = None + self._snap_inter_step_time = None + self._snap_epoch_time = time.time() + + def on_train_batch_start( + self, + trainer: Trainer, + pl_module: LightningModule, + batch: Any, + batch_idx: int, + ) -> None: + if self._log_stats.step_time: + self._snap_step_time = time.time() + + if not self._should_log(trainer): + return + + logs = {} + if self._log_stats.inter_step_time and self._snap_inter_step_time: + # First log at beginning of second step + logs["timer/inter_step"] = (time.time() - self._snap_inter_step_time) # * 1000 + + if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) + + @rank_zero_only + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs: STEP_OUTPUT, + batch: Any, + batch_idx: int, + ) -> None: + if self._log_stats.inter_step_time: + self._snap_inter_step_time = time.time() + + if not self._should_log(trainer): + return + + logs = {} + if self._log_stats.step_time and self._snap_step_time: + logs["timer/step"] = (time.time() - self._snap_step_time) # * 1000 + + if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) + + @rank_zero_only + def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None: + logs = {} + if self._log_stats.epoch_time and self._snap_epoch_time: + logs["timer/epoch"] = time.time() - self._snap_epoch_time + if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) + + def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: + self._snap_val_time = time.time() + + @rank_zero_only + def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None: + logs = {} + if self._log_stats.val_time and self._snap_val_time: + logs["timer/validation"] = time.time() - self._snap_val_time + if trainer.logger: trainer.logger.log_metrics(logs) # , step=trainer.global_step) + + @staticmethod + def _should_log(trainer) -> bool: + return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop diff --git a/src/clm/src/callbacks/wandb.py b/src/clm/src/callbacks/wandb.py new file mode 100644 index 00000000..66b08f90 --- /dev/null +++ b/src/clm/src/callbacks/wandb.py @@ -0,0 +1,277 @@ +### https://github.com/HazyResearch/transformers/blob/master/src/callbacks/wandb_callbacks.py + +import glob +import os +from typing import List + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sn +import torch +import wandb +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.loggers import LoggerCollection, WandbLogger +from pytorch_lightning.utilities import rank_zero_only +from sklearn import metrics +from sklearn.metrics import f1_score, precision_score, recall_score + + +def get_wandb_logger(trainer: Trainer) -> WandbLogger: + """Safely get Weights&Biases logger from Trainer.""" + + if isinstance(trainer.logger, WandbLogger): + return trainer.logger + + if isinstance(trainer.logger, LoggerCollection): + for logger in 
trainer.logger: + if isinstance(logger, WandbLogger): + return logger + + raise Exception( + "You are using wandb related callback, but WandbLogger was not found for some reason..." + ) + + +class WatchModel(Callback): + """Make wandb watch model at the beginning of the run.""" + + def __init__(self, log: str = "gradients", log_freq: int = 100): + self.log = log + self.log_freq = log_freq + + @rank_zero_only + def on_train_start(self, trainer, pl_module): + logger = get_wandb_logger(trainer=trainer) + logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) + + +class UploadCodeAsArtifact(Callback): + """Upload all *.py files to wandb as an artifact, at the beginning of the run.""" + + def __init__(self, code_dir: str): + self.code_dir = code_dir + + @rank_zero_only + def on_train_start(self, trainer, pl_module): + logger = get_wandb_logger(trainer=trainer) + experiment = logger.experiment + + code = wandb.Artifact("project-source", type="code") + for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True): + code.add_file(path) + + experiment.log_artifact(code) + + +class UploadCheckpointsAsArtifact(Callback): + """Upload checkpoints to wandb as an artifact, at the end of run.""" + + def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): + self.ckpt_dir = ckpt_dir + self.upload_best_only = upload_best_only + + @rank_zero_only + def on_train_end(self, trainer, pl_module): + logger = get_wandb_logger(trainer=trainer) + experiment = logger.experiment + + ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") + + if self.upload_best_only: + ckpts.add_file(trainer.checkpoint_callback.best_model_path) + else: + for path in glob.glob(os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True): + ckpts.add_file(path) + + experiment.log_artifact(ckpts) + + +class LogConfusionMatrix(Callback): + """Generate confusion matrix every epoch and send it to wandb. + Expects validation step to return predictions and targets. 
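+
+    The matrix is computed with sklearn.metrics.confusion_matrix and logged to
+    wandb as an image under "confusion_matrix/<experiment name>".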
+ """ + + def __init__(self): + self.preds = [] + self.targets = [] + self.ready = True + + def on_sanity_check_start(self, trainer, pl_module) -> None: + self.ready = False + + def on_sanity_check_end(self, trainer, pl_module): + """Start executing this callback only after all validation sanity checks end.""" + self.ready = True + + def on_validation_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ): + """Gather data from single batch.""" + if self.ready: + self.preds.append(outputs["preds"]) + self.targets.append(outputs["targets"]) + + def on_validation_epoch_end(self, trainer, pl_module): + """Generate confusion matrix.""" + if self.ready: + logger = get_wandb_logger(trainer) + experiment = logger.experiment + + preds = torch.cat(self.preds).cpu().numpy() + targets = torch.cat(self.targets).cpu().numpy() + + confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds) + + # set figure size + plt.figure(figsize=(14, 8)) + + # set labels size + sn.set(font_scale=1.4) + + # set font size + sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g") + + # names should be uniqe or else charts from different experiments in wandb will overlap + experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False) + + # according to wandb docs this should also work but it crashes + # experiment.log(f{"confusion_matrix/{experiment.name}": plt}) + + # reset plot + plt.clf() + + self.preds.clear() + self.targets.clear() + + +class LogF1PrecRecHeatmap(Callback): + """Generate f1, precision, recall heatmap every epoch and send it to wandb. + Expects validation step to return predictions and targets. + """ + + def __init__(self, class_names: List[str] = None): + self.preds = [] + self.targets = [] + self.ready = True + + def on_sanity_check_start(self, trainer, pl_module): + self.ready = False + + def on_sanity_check_end(self, trainer, pl_module): + """Start executing this callback only after all validation sanity checks end.""" + self.ready = True + + def on_validation_batch_end( + self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx + ): + """Gather data from single batch.""" + if self.ready: + self.preds.append(outputs["preds"]) + self.targets.append(outputs["targets"]) + + def on_validation_epoch_end(self, trainer, pl_module): + """Generate f1, precision and recall heatmap.""" + if self.ready: + logger = get_wandb_logger(trainer=trainer) + experiment = logger.experiment + + preds = torch.cat(self.preds).cpu().numpy() + targets = torch.cat(self.targets).cpu().numpy() + f1 = f1_score(preds, targets, average=None) + r = recall_score(preds, targets, average=None) + p = precision_score(preds, targets, average=None) + data = [f1, p, r] + + # set figure size + plt.figure(figsize=(14, 3)) + + # set labels size + sn.set(font_scale=1.2) + + # set font size + sn.heatmap( + data, + annot=True, + annot_kws={"size": 10}, + fmt=".3f", + yticklabels=["F1", "Precision", "Recall"], + ) + + # names should be uniqe or else charts from different experiments in wandb will overlap + experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False) + + # reset plot + plt.clf() + + self.preds.clear() + self.targets.clear() + + +class LogImagePredictions(Callback): + """Logs a validation batch and their predictions to wandb. 
+ Example adapted from: + https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY + """ + + def __init__(self, num_samples: int = 8): + super().__init__() + self.num_samples = num_samples + self.ready = True + + def on_sanity_check_start(self, trainer, pl_module): + self.ready = False + + def on_sanity_check_end(self, trainer, pl_module): + """Start executing this callback only after all validation sanity checks end.""" + self.ready = True + + def on_validation_epoch_end(self, trainer, pl_module): + if self.ready: + logger = get_wandb_logger(trainer=trainer) + experiment = logger.experiment + + # get a validation batch from the validation dat loader + val_samples = next(iter(trainer.datamodule.val_dataloader())) + val_imgs, val_labels = val_samples + + # run the batch through the network + val_imgs = val_imgs.to(device=pl_module.device) + logits = pl_module(val_imgs) + preds = torch.argmax(logits, axis=-1) + + # log the images as wandb Image + experiment.log( + { + f"Images/{experiment.name}": [ + wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") + for x, pred, y in zip( + val_imgs[: self.num_samples], + preds[: self.num_samples], + val_labels[: self.num_samples], + ) + ] + } + ) + +class LogDT(Callback): + """ Log the dt values (from NeurIPS 2021 LSSL submission) """ + def on_train_epoch_end(self, trainer, pl_module): + log_dict = {} + for name, m in pl_module.model.named_modules(): + if pl_module.hparams.train.get('log_dt', False) \ + and hasattr(m, "log_dt"): + log_dict[f"{name}.log_dt"] = ( + m.log_dt.detach().cpu().numpy().flatten() + ) + log_dict[f"{name}.log_dt.image"] = wandb.Image( + m.log_dt.detach().cpu().numpy().flatten().reshape(1, -1) + ) + log_dict[f"{name}.log_dt"] = wandb.Table( + dataframe=pd.DataFrame( + {"log_dt": m.log_dt.detach().cpu().numpy().flatten()} + ) + ) + + if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: + if trainer.logger is not None: + trainer.logger.experiment.log(log_dict) diff --git a/src/clm/src/dataloaders/README.md b/src/clm/src/dataloaders/README.md new file mode 100644 index 00000000..d8234163 --- /dev/null +++ b/src/clm/src/dataloaders/README.md @@ -0,0 +1,40 @@ +# Overview + +Basic datasets including MNIST and CIFAR will auto-download. Source code for these datamodules are in [basic.py](basic.py). + +By default, data is downloaded to `./data/` by default, where `.` is the top level directory of this repository (e.g. 'safari'). + +## Advanced Usage + +After downloading and preparing data, the paths can be configured in several ways. + +1. Suppose that it is desired to download all data to a different folder, for example a different disk. +The data path can be configured by setting the environment variable `DATA_PATH`, which defaults to `./data`. + +2. For fine-grained control over the path of a particular dataset, set `dataset.data_dir` in the config. For example, if the LRA ListOps files are located in `/home/lra/listops-1000/` instead of the default `./data/listops/`, +pass in `+dataset.data_dir=/home/lra/listops-1000` on the command line or modify the config file directly. + +3. As a simple workaround, softlinks can be set, e.g. `ln -s /home/lra/listops-1000 ./data/listops` + + +# Data Preparation + +[LRA](#long-range-arena-lra) must be manually downloaded. + +By default, these should go under `$DATA_PATH/`, which defaults to `./data`. For the remainder of this README, these are used interchangeably. 
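+
+As a quick sanity check, the effective data root can be printed with a short
+Python snippet. This is a sketch that mirrors the `DATA_PATH` resolution in
+`src/clm/src/dataloaders/base.py`; the `./data` fallback is assumed to be the
+repository-root default described above.
+
+```python
+import os
+from pathlib import Path
+
+# $DATA_PATH wins if set; otherwise fall back to ./data at the repository root.
+data_path = os.getenv("DATA_PATH")
+data_root = Path(data_path).absolute() if data_path else Path("./data").absolute()
+print("LRA files are expected under:", data_root / "listops")
+```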
+ +## Long Range Arena (LRA) + +LRA can be downloaded from the [GitHub page](https://github.com/google-research/long-range-arena). +These datasets should be organized as follows: +``` +$DATA_PATH/ + pathfinder/ + pathfinder32/ + pathfinder64/ + pathfinder128/ + pathfinder256/ + aan/ + listops/ +``` +The other two datasets in the suite ("Image" or grayscale sequential CIFAR-10; "Text" or char-level IMDB sentiment classification) are both auto-downloaded. \ No newline at end of file diff --git a/src/clm/src/dataloaders/__init__.py b/src/clm/src/dataloaders/__init__.py new file mode 100644 index 00000000..e6a24bb2 --- /dev/null +++ b/src/clm/src/dataloaders/__init__.py @@ -0,0 +1,2 @@ +from . import basic, et, lra, language_modeling_hf, synthetics, vision +from .base import SequenceDataset diff --git a/src/clm/src/dataloaders/base.py b/src/clm/src/dataloaders/base.py new file mode 100644 index 00000000..bec9ff7d --- /dev/null +++ b/src/clm/src/dataloaders/base.py @@ -0,0 +1,276 @@ +""" Datasets for core experimental results """ + +import os +import pickle +from functools import partial +from pathlib import Path + +import numpy as np +import torch +import torchvision +from einops import rearrange +from einops.layers.torch import Rearrange +from clm.src.utils import is_list, permutations +from torch.nn import functional as F + +def deprecated(cls_or_func): + def _deprecated(*args, **kwargs): + print(f"{cls_or_func} is deprecated") + return cls_or_func(*args, **kwargs) + return _deprecated + +# Default data path is environment variable or hippo/data +if (default_data_path := os.getenv("DATA_PATH")) is None: + default_data_path = Path(__file__).parent.parent.parent.absolute() + default_data_path = default_data_path / "data" +else: + default_data_path = Path(default_data_path).absolute() + +class DefaultCollateMixin: + """Controls collating in the DataLoader + + The CollateMixin classes instantiate a dataloader by separating collate arguments with the rest of the dataloader arguments. Instantiations of this class should modify the callback functions as desired, and modify the collate_args list. The class then defines a _dataloader() method which takes in a DataLoader constructor and arguments, constructs a collate_fn based on the collate_args, and passes the rest of the arguments into the constructor. + """ + + @classmethod + def _collate_callback(cls, x, *args, **kwargs): + """ + Modify the behavior of the default _collate method. + """ + return x + + _collate_arg_names = [] + + @classmethod + def _return_callback(cls, return_value, *args, **kwargs): + """ + Modify the return value of the collate_fn. 
+ Assign a name to each element of the returned tuple beyond the (x, y) pairs + See InformerSequenceDataset for an example of this being used + """ + x, y, *z = return_value + assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset" + return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)} + + @classmethod + def _collate(cls, batch, *args, **kwargs): + # From https://github.com/pyforch/pytorch/blob/master/torch/utils/data/_utils/collate.py + elem = batch[0] + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum(x.numel() for x in batch) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + x = torch.stack(batch, dim=0, out=out) + + # Insert custom functionality into the collate_fn + x = cls._collate_callback(x, *args, **kwargs) + + return x + else: + return torch.tensor(batch) + + @classmethod + def _collate_fn(cls, batch, *args, **kwargs): + """ + Default collate function. + Generally accessed by the dataloader() methods to pass into torch DataLoader + + Arguments: + batch: list of (x, y) pairs + args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback + """ + x, y, *z = zip(*batch) + + x = cls._collate(x, *args, **kwargs) + y = cls._collate(y) + z = [cls._collate(z_) for z_ in z] + + return_value = (x, y, *z) + return cls._return_callback(return_value, *args, **kwargs) + + # List of loader arguments to pass into collate_fn + collate_args = [] + + def _dataloader(self, dataset, **loader_args): + collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args} + loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args} + loader_cls = loader_registry[loader_args.pop("_name_", None)] + return loader_cls( + dataset=dataset, + collate_fn=partial(self._collate_fn, **collate_args), + **loader_args, + ) + + +class SequenceResolutionCollateMixin(DefaultCollateMixin): + """self.collate_fn(resolution) produces a collate function that subsamples elements of the sequence""" + + @classmethod + def _collate_callback(cls, x, resolution=None): + if resolution is None: + pass + else: + # Assume x is (B, L_0, L_1, ..., L_k, C) for x.ndim > 2 and (B, L) for x.ndim = 2 + assert x.ndim >= 2 + n_resaxes = max(1, x.ndim - 2) # [AG 22/07/02] this line looks suspicious... are there cases with 2 axes? + # rearrange: b (l_0 res_0) (l_1 res_1) ... (l_k res_k) ... -> res_0 res_1 .. res_k b l_0 l_1 ... + lhs = "b " + " ".join([f"(l{i} res{i})" for i in range(n_resaxes)]) + " ..." + rhs = " ".join([f"res{i}" for i in range(n_resaxes)]) + " b " + " ".join([f"l{i}" for i in range(n_resaxes)]) + " ..." 
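+            # After the rearrange below, indexing the leading res axes at 0 keeps
+            # every `resolution`-th element along each length axis (plain subsampling).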
+ x = rearrange(x, lhs + " -> " + rhs, **{f'res{i}': resolution for i in range(n_resaxes)}) + x = x[tuple([0] * n_resaxes)] + + return x + + @classmethod + def _return_callback(cls, return_value, resolution=None): + return *return_value, {"rate": resolution} + + + collate_args = ['resolution'] + +class ImageResolutionCollateMixin(SequenceResolutionCollateMixin): + """self.collate_fn(resolution, img_size) produces a collate function that resizes inputs to size img_size/resolution""" + + _interpolation = torchvision.transforms.InterpolationMode.BILINEAR + _antialias = True + + @classmethod + def _collate_callback(cls, x, resolution=None, img_size=None, channels_last=True): + if x.ndim < 4: + return super()._collate_callback(x, resolution=resolution) + if img_size is None: + x = super()._collate_callback(x, resolution=resolution) + else: + x = rearrange(x, 'b ... c -> b c ...') if channels_last else x + _size = round(img_size/resolution) + x = torchvision.transforms.functional.resize( + x, + size=[_size, _size], + interpolation=cls._interpolation, + antialias=cls._antialias, + ) + x = rearrange(x, 'b c ... -> b ... c') if channels_last else x + return x + + @classmethod + def _return_callback(cls, return_value, resolution=None, img_size=None, channels_last=True): + return *return_value, {"rate": resolution} + + collate_args = ['resolution', 'img_size', 'channels_last'] + + + +# class SequenceDataset(LightningDataModule): +# [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just provide our own class with the same core methods as LightningDataModule (e.g. setup) +class SequenceDataset(DefaultCollateMixin): + registry = {} + _name_ = NotImplementedError("Dataset must have shorthand name") + + # Since subclasses do not specify __init__ which is instead handled by this class + # Subclasses can provide a list of default arguments which are automatically registered as attributes + # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features of this class such as the _name_ and d_input/d_output + @property + def init_defaults(self): + return {} + + # https://www.python.org/dev/peps/pep-0487/#subclass-registration + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + cls.registry[cls._name_] = cls + + def __init__(self, _name_, data_dir=None, **dataset_cfg): + assert _name_ == self._name_ + self.data_dir = Path(data_dir).absolute() if data_dir is not None else None + + # Add all arguments to self + init_args = self.init_defaults.copy() + init_args.update(dataset_cfg) + for k, v in init_args.items(): + setattr(self, k, v) + + # The train, val, test datasets must be set by `setup()` + self.dataset_train = self.dataset_val = self.dataset_test = None + + self.init() + + def init(self): + """Hook called at end of __init__, override this instead of __init__""" + pass + + def setup(self): + """This method should set self.dataset_train, self.dataset_val, and self.dataset_test.""" + raise NotImplementedError + + def split_train_val(self, val_split): + """ + Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair. 
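+
+        The split is seeded via getattr(self, "seed", 42), so train/val membership
+        is reproducible across runs.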
+ """ + train_len = int(len(self.dataset_train) * (1.0 - val_split)) + self.dataset_train, self.dataset_val = torch.utils.data.random_split( + self.dataset_train, + (train_len, len(self.dataset_train) - train_len), + generator=torch.Generator().manual_seed( + getattr(self, "seed", 42) + ), # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us + ) + + def train_dataloader(self, **kwargs): + return self._train_dataloader(self.dataset_train, **kwargs) + + def _train_dataloader(self, dataset, **kwargs): + if dataset is None: return + kwargs['shuffle'] = 'sampler' not in kwargs # shuffle cant be True if we have custom sampler + return self._dataloader(dataset, **kwargs) + + def val_dataloader(self, **kwargs): + return self._eval_dataloader(self.dataset_val, **kwargs) + + def test_dataloader(self, **kwargs): + return self._eval_dataloader(self.dataset_test, **kwargs) + + def _eval_dataloader(self, dataset, **kwargs): + if dataset is None: return + # Note that shuffle=False by default + return self._dataloader(dataset, **kwargs) + + def __str__(self): + return self._name_ + +class ResolutionSequenceDataset(SequenceDataset, SequenceResolutionCollateMixin): + + def _train_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs): + if train_resolution is None: train_resolution = [1] + if not is_list(train_resolution): train_resolution = [train_resolution] + assert len(train_resolution) == 1, "Only one train resolution supported for now." + return super()._train_dataloader(dataset, resolution=train_resolution[0], **kwargs) + + def _eval_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs): + if dataset is None: return + if eval_resolutions is None: eval_resolutions = [1] + if not is_list(eval_resolutions): eval_resolutions = [eval_resolutions] + + dataloaders = [] + for resolution in eval_resolutions: + dataloaders.append(super()._eval_dataloader(dataset, resolution=resolution, **kwargs)) + + return ( + { + None if res == 1 else str(res): dl + for res, dl in zip(eval_resolutions, dataloaders) + } + if dataloaders is not None else None + ) + +class ImageResolutionSequenceDataset(ResolutionSequenceDataset, ImageResolutionCollateMixin): + pass + + + +# Registry for dataloader class +loader_registry = { + None: torch.utils.data.DataLoader, # default case +} diff --git a/src/clm/src/dataloaders/basic.py b/src/clm/src/dataloaders/basic.py new file mode 100644 index 00000000..938450e8 --- /dev/null +++ b/src/clm/src/dataloaders/basic.py @@ -0,0 +1,271 @@ +"""Implementation of basic benchmark datasets used in S4 experiments: MNIST, CIFAR10 and Speech Commands.""" +import numpy as np +import torch +import torchvision +from einops.layers.torch import Rearrange +from clm.src.utils import permutations + +from clm.src.dataloaders.base import default_data_path, ImageResolutionSequenceDataset, ResolutionSequenceDataset, SequenceDataset + + +class MNIST(SequenceDataset): + _name_ = "mnist" + d_input = 1 + d_output = 10 + l_output = 0 + L = 784 + + @property + def init_defaults(self): + return { + "permute": True, + "val_split": 0.1, + "seed": 42, # For train/val split + } + + def setup(self): + self.data_dir = self.data_dir or default_data_path / self._name_ + + transform_list = [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Lambda(lambda x: x.view(self.d_input, self.L).t()), + ] # (L, d_input) + if self.permute: + # below is another permutation that other works have used + # permute = 
np.random.RandomState(92916) + # permutation = torch.LongTensor(permute.permutation(784)) + permutation = permutations.bitreversal_permutation(self.L) + transform_list.append( + torchvision.transforms.Lambda(lambda x: x[permutation]) + ) + # TODO does MNIST need normalization? + # torchvision.transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs + transform = torchvision.transforms.Compose(transform_list) + self.dataset_train = torchvision.datasets.MNIST( + self.data_dir, + train=True, + download=True, + transform=transform, + ) + self.dataset_test = torchvision.datasets.MNIST( + self.data_dir, + train=False, + transform=transform, + ) + self.split_train_val(self.val_split) + + def __str__(self): + return f"{'p' if self.permute else 's'}{self._name_}" + + +class CIFAR10(ImageResolutionSequenceDataset): + _name_ = "cifar" + d_output = 10 + l_output = 0 + + @property + def init_defaults(self): + return { + "permute": None, + "grayscale": False, + "tokenize": False, # if grayscale, tokenize into discrete byte inputs + "augment": False, + "cutout": False, + "rescale": None, + "random_erasing": False, + "val_split": 0.1, + "seed": 42, # For validation split + } + + @property + def d_input(self): + if self.grayscale: + if self.tokenize: + return 256 + else: + return 1 + else: + assert not self.tokenize + return 3 + + def setup(self): + img_size = 32 + if self.rescale: + img_size //= self.rescale + + if self.grayscale: + preprocessors = [ + torchvision.transforms.Grayscale(), + torchvision.transforms.ToTensor(), + ] + permutations_list = [ + torchvision.transforms.Lambda( + lambda x: x.view(1, img_size * img_size).t() + ) # (L, d_input) + ] + + if self.tokenize: + preprocessors.append( + torchvision.transforms.Lambda(lambda x: (x * 255).long()) + ) + permutations_list.append(Rearrange("l 1 -> l")) + else: + preprocessors.append( + torchvision.transforms.Normalize( + mean=122.6 / 255.0, std=61.0 / 255.0 + ) + ) + else: + preprocessors = [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize( + (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) + ), + ] + permutations_list = [ + torchvision.transforms.Lambda( + Rearrange("z h w -> (h w) z", z=3, h=img_size, w=img_size) + ) # (L, d_input) + ] + + # Permutations and reshaping + if self.permute == "br": + permutation = permutations.bitreversal_permutation(img_size * img_size) + print("bit reversal", permutation) + permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) + elif self.permute == "snake": + permutation = permutations.snake_permutation(img_size, img_size) + print("snake", permutation) + permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) + elif self.permute == "hilbert": + permutation = permutations.hilbert_permutation(img_size) + print("hilbert", permutation) + permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) + elif self.permute == "transpose": + permutation = permutations.transpose_permutation(img_size, img_size) + transform = torchvision.transforms.Lambda( + lambda x: torch.cat([x, x[permutation]], dim=-1) + ) + permutations_list.append(transform) + elif self.permute == "2d": # h, w, c + permutation = torchvision.transforms.Lambda( + Rearrange("(h w) c -> h w c", h=img_size, w=img_size) + ) + permutations_list.append(permutation) + elif self.permute == "2d_transpose": # c, h, w + permutation = torchvision.transforms.Lambda( + Rearrange("(h w) c -> c h w", h=img_size, w=img_size) + ) + permutations_list.append(permutation) + + 
# Augmentation + if self.augment: + augmentations = [ + torchvision.transforms.RandomCrop( + img_size, padding=4, padding_mode="symmetric" + ), + torchvision.transforms.RandomHorizontalFlip(), + ] + + post_augmentations = [] + if self.cutout: + post_augmentations.append(Cutout(1, img_size // 2)) + pass + if self.random_erasing: + # augmentations.append(RandomErasing()) + pass + else: + augmentations, post_augmentations = [], [] + transforms_train = ( + augmentations + preprocessors + post_augmentations + permutations_list + ) + transforms_eval = preprocessors + permutations_list + + transform_train = torchvision.transforms.Compose(transforms_train) + transform_eval = torchvision.transforms.Compose(transforms_eval) + self.dataset_train = torchvision.datasets.CIFAR10( + f"{default_data_path}/{self._name_}", + train=True, + download=True, + transform=transform_train, + ) + self.dataset_test = torchvision.datasets.CIFAR10( + f"{default_data_path}/{self._name_}", train=False, transform=transform_eval + ) + + if self.rescale: + print(f"Resizing all images to {img_size} x {img_size}.") + self.dataset_train.data = self.dataset_train.data.reshape((self.dataset_train.data.shape[0], 32 // self.rescale, self.rescale, 32 // self.rescale, self.rescale, 3)).max(4).max(2).astype(np.uint8) + self.dataset_test.data = self.dataset_test.data.reshape((self.dataset_test.data.shape[0], 32 // self.rescale, self.rescale, 32 // self.rescale, self.rescale, 3)).max(4).max(2).astype(np.uint8) + + self.split_train_val(self.val_split) + + def __str__(self): + return f"{'p' if self.permute else 's'}{self._name_}" + +class SpeechCommands(ResolutionSequenceDataset): + _name_ = "sc" + + @property + def init_defaults(self): + return { + "mfcc": False, + "dropped_rate": 0.0, + "length": 16000, + "all_classes": False, + } + + @property + def d_input(self): + _d_input = 20 if self.mfcc else 1 + _d_input += 1 if self.dropped_rate > 0.0 else 0 + return _d_input + + @property + def d_output(self): + return 10 if not self.all_classes else 35 + + @property + def l_output(self): + return 0 + + @property + def L(self): + return 161 if self.mfcc else self.length + + + def setup(self): + self.data_dir = self.data_dir or default_data_path # TODO make same logic as other classes + + from clm.src.dataloaders.datasets.sc import _SpeechCommands + + # TODO refactor with data_dir argument + self.dataset_train = _SpeechCommands( + partition="train", + length=self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=self.data_dir, + all_classes=self.all_classes, + ) + + self.dataset_val = _SpeechCommands( + partition="val", + length=self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=self.data_dir, + all_classes=self.all_classes, + ) + + self.dataset_test = _SpeechCommands( + partition="test", + length=self.L, + mfcc=self.mfcc, + sr=1, + dropped_rate=self.dropped_rate, + path=self.data_dir, + all_classes=self.all_classes, + ) diff --git a/src/clm/src/dataloaders/datasets/detokenizer.py b/src/clm/src/dataloaders/datasets/detokenizer.py new file mode 100644 index 00000000..c42266be --- /dev/null +++ b/src/clm/src/dataloaders/datasets/detokenizer.py @@ -0,0 +1,53 @@ +# Copied from https://github.com/stanford-crfm/mistral/blob/main/src/corpora/detokenization.py +# Which was originally from https://github.com/NVIDIA/Megatron-LM/blob/aed2f75e209e525c842aec7c044af7acae2a4614/tasks/zeroshot_gpt/detokenizer.py + +""" +Handle detokenization for different dataset for zero-shot LM evaluation. 
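+
+Example (illustrative): wikitext_detokenize("a @-@ b , c") returns "a-b, c".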
+""" +import re + + +def wikitext_detokenize(string: str) -> str: + """ + Wikitext is whitespace tokenized and we remove these whitespaces. + Taken from https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt2/detokenizer.py + """ + # Contractions + string = string.replace("s '", "s'") + string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) + + # Number Separators + string = string.replace(" @-@ ", "-") + string = string.replace(" @,@ ", ",") + string = string.replace(" @.@ ", ".") + + # Punctuation + string = string.replace(" : ", ": ") + string = string.replace(" ; ", "; ") + string = string.replace(" . ", ". ") + string = string.replace(" ! ", "! ") + string = string.replace(" ? ", "? ") + string = string.replace(" , ", ", ") + + # Double Brackets + string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) + string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) + string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) + string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) + string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) + + # Miscellaneous + string = string.replace("= = = =", "====") + string = string.replace("= = =", "===") + string = string.replace("= =", "==") + string = string.replace(" " + chr(176) + " ", chr(176)) + string = string.replace(" \n", "\n") + string = string.replace("\n ", "\n") + string = string.replace(" N ", " 1 ") + string = string.replace(" 's", "'s") + + return string + + +# Set Registry for Various Datasets +DATASET_TOKENIZATION_REGISTRY = {"wikitext": wikitext_detokenize} \ No newline at end of file diff --git a/src/clm/src/dataloaders/datasets/lm_dataset.py b/src/clm/src/dataloaders/datasets/lm_dataset.py new file mode 100644 index 00000000..d32353a8 --- /dev/null +++ b/src/clm/src/dataloaders/datasets/lm_dataset.py @@ -0,0 +1,32 @@ +# Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py +# Except we don't pad the last block and don't use overlapping eval +# And we return both the input and the target +import math +import numpy as np + +import torch + + +class LMDataset(torch.utils.data.Dataset): + + def __init__(self, tokens, seq_len, drop_last=True): + """tokens should be a numpy array + """ + self.seq_len = seq_len + ntokens = len(tokens) + if drop_last: + ntokens = ((ntokens - 1) // seq_len) * seq_len + 1 + self.ntokens = ntokens + # We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset, + # and slicing would load it to memory. + self.tokens = tokens + self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len) + + def __len__(self): + return self.total_sequences + + def __getitem__(self, idx): + start_idx = idx * self.seq_len + seq_len = min(self.seq_len, self.ntokens - 1 - start_idx) + data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64)) + return data[:-1], data[1:].clone() \ No newline at end of file diff --git a/src/clm/src/dataloaders/et.py b/src/clm/src/dataloaders/et.py new file mode 100644 index 00000000..455d0a2d --- /dev/null +++ b/src/clm/src/dataloaders/et.py @@ -0,0 +1,626 @@ +""" +ET Dataset from Informer Paper. 
+Dataset: https://github.com/zhouhaoyi/ETDataset +Dataloader: https://github.com/zhouhaoyi/Informer2020 +""" + +from typing import List +import os +import numpy as np +import pandas as pd +from pandas.tseries import offsets +from pandas.tseries.frequencies import to_offset +import torch +from torch.utils import data +from torch.utils.data import Dataset, DataLoader + +import warnings +warnings.filterwarnings("ignore") + +from clm.src.dataloaders.base import SequenceDataset, default_data_path + + +class TimeFeature: + def __init__(self): + pass + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + pass + + def __repr__(self): + return self.__class__.__name__ + "()" + + +class SecondOfMinute(TimeFeature): + """Minute of hour encoded as value between [-0.5, 0.5]""" + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.second / 59.0 - 0.5 + + +class MinuteOfHour(TimeFeature): + """Minute of hour encoded as value between [-0.5, 0.5]""" + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.minute / 59.0 - 0.5 + + +class HourOfDay(TimeFeature): + """Hour of day encoded as value between [-0.5, 0.5]""" + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.hour / 23.0 - 0.5 + + +class DayOfWeek(TimeFeature): + """Hour of day encoded as value between [-0.5, 0.5]""" + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return index.dayofweek / 6.0 - 0.5 + + +class DayOfMonth(TimeFeature): + """Day of month encoded as value between [-0.5, 0.5]""" + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.day - 1) / 30.0 - 0.5 + + +class DayOfYear(TimeFeature): + """Day of year encoded as value between [-0.5, 0.5]""" + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.dayofyear - 1) / 365.0 - 0.5 + + +class MonthOfYear(TimeFeature): + """Month of year encoded as value between [-0.5, 0.5]""" + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.month - 1) / 11.0 - 0.5 + + +class WeekOfYear(TimeFeature): + """Week of year encoded as value between [-0.5, 0.5]""" + + def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: + return (index.isocalendar().week - 1) / 52.0 - 0.5 + + +def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: + """ + Returns a list of time features that will be appropriate for the given frequency string. + Parameters + ---------- + freq_str + Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 
+ """ + + features_by_offsets = { + offsets.YearEnd: [], + offsets.QuarterEnd: [MonthOfYear], + offsets.MonthEnd: [MonthOfYear], + offsets.Week: [DayOfMonth, WeekOfYear], + offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], + offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], + offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], + offsets.Minute: [ + MinuteOfHour, + HourOfDay, + DayOfWeek, + DayOfMonth, + DayOfYear, + ], + offsets.Second: [ + SecondOfMinute, + MinuteOfHour, + HourOfDay, + DayOfWeek, + DayOfMonth, + DayOfYear, + ], + } + + offset = to_offset(freq_str) + + for offset_type, feature_classes in features_by_offsets.items(): + if isinstance(offset, offset_type): + return [cls() for cls in feature_classes] + + supported_freq_msg = f""" + Unsupported frequency {freq_str} + The following frequencies are supported: + Y - yearly + alias: A + M - monthly + W - weekly + D - daily + B - business days + H - hourly + T - minutely + alias: min + S - secondly + """ + raise RuntimeError(supported_freq_msg) + + +def time_features(dates, timeenc=1, freq="h"): + """ + > `time_features` takes in a `dates` dataframe with a 'dates' column and extracts the date down to `freq` where freq can be any of the following if `timeenc` is 0: + > * m - [month] + > * w - [month] + > * d - [month, day, weekday] + > * b - [month, day, weekday] + > * h - [month, day, weekday, hour] + > * t - [month, day, weekday, hour, *minute] + > + > If `timeenc` is 1, a similar, but different list of `freq` values are supported (all encoded between [-0.5 and 0.5]): + > * Q - [month] + > * M - [month] + > * W - [Day of month, week of year] + > * D - [Day of week, day of month, day of year] + > * B - [Day of week, day of month, day of year] + > * H - [Hour of day, day of week, day of month, day of year] + > * T - [Minute of hour*, hour of day, day of week, day of month, day of year] + > * S - [Second of minute, minute of hour, hour of day, day of week, day of month, day of year] + *minute returns a number from 0-3 corresponding to the 15 minute period it falls into. 
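+
+    Example (illustrative): with timeenc=1 and freq="h", the output has shape
+    (len(dates), 4): hour of day, day of week, day of month, day of year.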
+ """ + if timeenc == 0: + dates["month"] = dates.date.apply(lambda row: row.month, 1) + dates["day"] = dates.date.apply(lambda row: row.day, 1) + dates["weekday"] = dates.date.apply(lambda row: row.weekday(), 1) + dates["hour"] = dates.date.apply(lambda row: row.hour, 1) + dates["minute"] = dates.date.apply(lambda row: row.minute, 1) + dates["minute"] = dates.minute.map(lambda x: x // 15) + freq_map = { + "y": [], + "m": ["month"], + "w": ["month"], + "d": ["month", "day", "weekday"], + "b": ["month", "day", "weekday"], + "h": ["month", "day", "weekday", "hour"], + "t": ["month", "day", "weekday", "hour", "minute"], + } + return dates[freq_map[freq.lower()]].values + if timeenc == 1: + dates = pd.to_datetime(dates.date.values) + return np.vstack( + [feat(dates) for feat in time_features_from_frequency_str(freq)] + ).transpose(1, 0) + + +class StandardScaler: + def __init__(self): + self.mean = 0.0 + self.std = 1.0 + + def fit(self, data): + self.mean = data.mean(0) + self.std = data.std(0) + + def transform(self, data): + mean = ( + torch.from_numpy(self.mean).type_as(data).to(data.device) + if torch.is_tensor(data) + else self.mean + ) + std = ( + torch.from_numpy(self.std).type_as(data).to(data.device) + if torch.is_tensor(data) + else self.std + ) + return (data - mean) / std + + def inverse_transform(self, data): + mean = ( + torch.from_numpy(self.mean).type_as(data).to(data.device) + if torch.is_tensor(data) + else self.mean + ) + std = ( + torch.from_numpy(self.std).type_as(data).to(data.device) + if torch.is_tensor(data) + else self.std + ) + return (data * std) + mean + + +class InformerDataset(Dataset): + def __init__( + self, + root_path, + flag="train", + size=None, + features="S", + data_path="ETTh1.csv", + target="OT", + scale=True, + inverse=False, + timeenc=0, + freq="h", + cols=None, + eval_stamp=False, + eval_mask=False, + ): + # size [seq_len, label_len, pred_len] + # info + if size == None: + self.seq_len = 24 * 4 * 4 + self.label_len = 24 * 4 + self.pred_len = 24 * 4 + else: + self.seq_len = size[0] + self.label_len = size[1] + self.pred_len = size[2] + # init + assert flag in ["train", "test", "val"] + type_map = {"train": 0, "val": 1, "test": 2} + self.set_type = type_map[flag] + + self.features = features + self.target = target + self.scale = scale + self.inverse = inverse + self.timeenc = timeenc + self.freq = freq + self.cols = cols + self.eval_stamp = eval_stamp + self.eval_mask = eval_mask + self.forecast_horizon = self.pred_len + + self.root_path = root_path + self.data_path = data_path + self.__read_data__() + + def _borders(self, df_raw): + num_train = int(len(df_raw) * 0.7) + num_test = int(len(df_raw) * 0.2) + num_vali = len(df_raw) - num_train - num_test + border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] + border2s = [num_train, num_train + num_vali, len(df_raw)] + return border1s, border2s + + def _process_columns(self, df_raw): + if self.cols: + cols = self.cols.copy() + cols.remove(self.target) + else: + cols = list(df_raw.columns) + cols.remove(self.target) + cols.remove("date") + return df_raw[["date"] + cols + [self.target]] + + def __read_data__(self): + self.scaler = StandardScaler() + df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path)) + + df_raw = self._process_columns(df_raw) + + border1s, border2s = self._borders(df_raw) + border1 = border1s[self.set_type] + border2 = border2s[self.set_type] + + if self.features == "M" or self.features == "MS": + cols_data = df_raw.columns[1:] + df_data = 
df_raw[cols_data] + elif self.features == "S": + df_data = df_raw[[self.target]] + + if self.scale: + train_data = df_data[border1s[0] : border2s[0]] + self.scaler.fit(train_data.values) + data = self.scaler.transform(df_data.values) + else: + data = df_data.values + + df_stamp = df_raw[["date"]][border1:border2] + df_stamp["date"] = pd.to_datetime(df_stamp.date) + data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq) + + self.data_x = data[border1:border2] + if self.inverse: + self.data_y = df_data.values[border1:border2] + else: + self.data_y = data[border1:border2] + + self.data_stamp = data_stamp + + def __getitem__(self, index): + s_begin = index + s_end = s_begin + self.seq_len + r_begin = s_end - self.label_len + r_end = r_begin + self.label_len + self.pred_len + + seq_x = self.data_x[s_begin:s_end] + seq_x = np.concatenate( + [seq_x, np.zeros((self.pred_len, self.data_x.shape[-1]))], axis=0 + ) + + if self.inverse: + seq_y = np.concatenate( + [ + self.data_x[r_begin : r_begin + self.label_len], + self.data_y[r_begin + self.label_len : r_end], + ], + 0, + ) + raise NotImplementedError + else: + # seq_y = self.data_y[r_begin:r_end] # OLD in Informer codebase + seq_y = self.data_y[s_end:r_end] + + # OLD in Informer codebase + # seq_x_mark = self.data_stamp[s_begin:s_end] + # seq_y_mark = self.data_stamp[r_begin:r_end] + + if self.eval_stamp: + mark = self.data_stamp[s_begin:r_end] + else: + mark = self.data_stamp[s_begin:s_end] + mark = np.concatenate([mark, np.zeros((self.pred_len, mark.shape[-1]))], axis=0) + + if self.eval_mask: + mask = np.concatenate([np.zeros(self.seq_len), np.ones(self.pred_len)], axis=0) + else: + mask = np.concatenate([np.zeros(self.seq_len), np.zeros(self.pred_len)], axis=0) + mask = mask[:, None] + + # Add the mask to the timestamps: # 480, 5 + # mark = np.concatenate([mark, mask[:, np.newaxis]], axis=1) + + seq_x = seq_x.astype(np.float32) + seq_y = seq_y.astype(np.float32) + if self.timeenc == 0: + mark = mark.astype(np.int64) + else: + mark = mark.astype(np.float32) + mask = mask.astype(np.int64) + + return torch.tensor(seq_x), torch.tensor(seq_y), torch.tensor(mark), torch.tensor(mask) + + def __len__(self): + return len(self.data_x) - self.seq_len - self.pred_len + 1 + + def inverse_transform(self, data): + return self.scaler.inverse_transform(data) + + @property + def d_input(self): + return self.data_x.shape[-1] + + @property + def d_output(self): + if self.features in ["M", "S"]: + return self.data_x.shape[-1] + elif self.features == "MS": + return 1 + else: + raise NotImplementedError + + @property + def n_tokens_time(self): + if self.freq == 'h': + return [13, 32, 7, 24] + elif self.freq == 't': + return [13, 32, 7, 24, 4] + else: + raise NotImplementedError + + +class _Dataset_ETT_hour(InformerDataset): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _borders(self, df_raw): + border1s = [ + 0, + 12 * 30 * 24 - self.seq_len, + 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len, + ] + border2s = [ + 12 * 30 * 24, + 12 * 30 * 24 + 4 * 30 * 24, + 12 * 30 * 24 + 8 * 30 * 24, + ] + return border1s, border2s + + def _process_columns(self, df_raw): + return df_raw + + @property + def n_tokens_time(self): + assert self.freq == "h" + return [13, 32, 7, 24] + + +class _Dataset_ETT_minute(_Dataset_ETT_hour): + def __init__(self, data_path="ETTm1.csv", freq="t", **kwargs): + super().__init__(data_path=data_path, freq=freq, **kwargs) + + def _borders(self, df_raw): + border1s = [ + 0, + 12 * 30 * 24 * 4 - self.seq_len, + 12 * 
30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len, + ] + border2s = [ + 12 * 30 * 24 * 4, + 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, + 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4, + ] + return border1s, border2s + + @property + def n_tokens_time(self): + assert self.freq == "t" + return [13, 32, 7, 24, 4] + + +class _Dataset_Weather(InformerDataset): + def __init__(self, data_path="WTH.csv", target="WetBulbCelsius", **kwargs): + super().__init__(data_path=data_path, target=target, **kwargs) + +class _Dataset_ECL(InformerDataset): + def __init__(self, data_path="ECL.csv", target="MT_320", **kwargs): + super().__init__(data_path=data_path, target=target, **kwargs) + +class InformerSequenceDataset(SequenceDataset): + + @property + def n_tokens_time(self): + # Shape of the dates: depends on `timeenc` and `freq` + return self.dataset_train.n_tokens_time # data_stamp.shape[-1] + + @property + def d_input(self): + return self.dataset_train.d_input + + @property + def d_output(self): + return self.dataset_train.d_output + + @property + def l_output(self): + return self.dataset_train.pred_len + + def _get_data_filename(self, variant): + return self.variants[variant] + + _collate_arg_names = ["mark", "mask"] # Names of the two extra tensors that the InformerDataset returns + + def setup(self): + self.data_dir = self.data_dir or default_data_path / 'informer' / self._name_ + + self.dataset_train = self._dataset_cls( + root_path=self.data_dir, + flag="train", + size=self.size, + features=self.features, + data_path=self._get_data_filename(self.variant), + target=self.target, + scale=self.scale, + inverse=self.inverse, + timeenc=self.timeenc, + freq=self.freq, + cols=self.cols, + eval_stamp=self.eval_stamp, + eval_mask=self.eval_mask, + ) + + self.dataset_val = self._dataset_cls( + root_path=self.data_dir, + flag="val", + size=self.size, + features=self.features, + data_path=self._get_data_filename(self.variant), + target=self.target, + scale=self.scale, + inverse=self.inverse, + timeenc=self.timeenc, + freq=self.freq, + cols=self.cols, + eval_stamp=self.eval_stamp, + eval_mask=self.eval_mask, + ) + + self.dataset_test = self._dataset_cls( + root_path=self.data_dir, + flag="test", + size=self.size, + features=self.features, + data_path=self._get_data_filename(self.variant), + target=self.target, + scale=self.scale, + inverse=self.inverse, + timeenc=self.timeenc, + freq=self.freq, + cols=self.cols, + eval_stamp=self.eval_stamp, + eval_mask=self.eval_mask, + ) + +class ETTHour(InformerSequenceDataset): + _name_ = "etth" + + _dataset_cls = _Dataset_ETT_hour + + init_defaults = { + "size": None, + "features": "S", + "target": "OT", + "variant": 0, + "scale": True, + "inverse": False, + "timeenc": 0, + "freq": "h", + "cols": None, + } + + variants = { + 0: "ETTh1.csv", + 1: "ETTh2.csv", + } + +class ETTMinute(InformerSequenceDataset): + _name_ = "ettm" + + _dataset_cls = _Dataset_ETT_minute + + init_defaults = { + "size": None, + "features": "S", + "target": "OT", + "variant": 0, + "scale": True, + "inverse": False, + "timeenc": 0, + "freq": "t", + "cols": None, + } + + variants = { + 0: "ETTm1.csv", + 1: "ETTm2.csv", + } + +class Weather(InformerSequenceDataset): + _name_ = "weather" + + _dataset_cls = _Dataset_Weather + + init_defaults = { + "size": None, + "features": "S", + "target": "WetBulbCelsius", + "variant": 0, + "scale": True, + "inverse": False, + "timeenc": 0, + "freq": "h", + "cols": None, + } + + variants = { + 0: "WTH.csv", + } + +class ECL(InformerSequenceDataset): + _name_ = "ecl" + + _dataset_cls = 
_Dataset_ECL + + init_defaults = { + "size": None, + "features": "S", + "target": "MT_320", + "variant": 0, + "scale": True, + "inverse": False, + "timeenc": 0, + "freq": "h", + "cols": None, + } + + variants = { + 0: "ECL.csv", + } diff --git a/src/clm/src/dataloaders/fault_tolerant_sampler.py b/src/clm/src/dataloaders/fault_tolerant_sampler.py new file mode 100644 index 00000000..adab1c7f --- /dev/null +++ b/src/clm/src/dataloaders/fault_tolerant_sampler.py @@ -0,0 +1,123 @@ +# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397 +from typing import Iterator +import math + +import torch +from torch.utils.data import RandomSampler, DistributedSampler + + +class RandomFaultTolerantSampler(RandomSampler): + + def __init__(self, *args, generator=None, **kwargs): + # generator = torch.Generator().manual_seed(seed) + # super().__init__(*args, generator=generator, **kwargs) + # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed, + # which should be reproducible if pl.seed_everything was called before hand. + # This means that changing the seed of the experiment will also change the + # sampling order. + if generator is None: + seed = int(torch.empty((), dtype=torch.int64).random_().item()) + generator = torch.Generator().manual_seed(seed) + super().__init__(*args, generator=generator, **kwargs) + self.counter = 0 + # self.start_counter = 0 + self.restarting = False + + def state_dict(self): + return {"random_state": self.state, "counter": self.counter} + + def load_state_dict(self, state_dict): + self.generator.set_state(state_dict.get("random_state")) + self.counter = state_dict["counter"] + # self.start_counter = self.counter + self.restarting = True + + # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per + # epoch, and subsequent epoch will have very few batches. + # def __len__(self): + # # We need a separate self.start_counter because PL seems to call len repeatedly. + # # If we use len(self.data_source) - self.counter then PL will think the epoch ends + # # when we're only half way through. + # return len(self.data_source) - self.start_counter + + def __iter__(self) -> Iterator[int]: + n = len(self.data_source) + + self.state = self.generator.get_state() + indices = torch.randperm(n, generator=self.generator).tolist() + + if not self.restarting: + self.counter = 0 + else: + indices = indices[self.counter:] + self.restarting = False + # self.start_counter = self.counter + + for index in indices: + self.counter += 1 + yield index + + self.counter = 0 + # self.start_counter = self.counter + + +class FaultTolerantDistributedSampler(DistributedSampler): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.counter = 0 + # self.start_counter = 0 + self.restarting = False + + def state_dict(self): + return {"epoch": self.epoch, "counter": self.counter} + + def load_state_dict(self, state_dict): + self.epoch = state_dict["epoch"] + self.counter = state_dict["counter"] + # self.start_counter = self.counter + self.restarting = True + + # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per + # epoch, and subsequent epoch will have very few batches. 
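Both samplers in this file expose state_dict/load_state_dict so that an interrupted epoch can be resumed with the same permutation. A minimal resume sketch for the non-distributed sampler; the toy list of ten items and the 3-index split are illustrative assumptions, not values used anywhere in the patch:

from itertools import islice
from clm.src.dataloaders.fault_tolerant_sampler import RandomFaultTolerantSampler

data = list(range(10))
sampler = RandomFaultTolerantSampler(data)

it = iter(sampler)
first = list(islice(it, 3))            # consume part of an epoch
state = sampler.state_dict()           # {'random_state': ..., 'counter': 3}

resumed = RandomFaultTolerantSampler(data)
resumed.load_state_dict(state)         # marks the sampler as restarting
rest = list(resumed)                   # remaining 7 indices of the same permutation

assert sorted(first + rest) == data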
+ # def __len__(self) -> int: + # return self.num_samples - self.start_counter + + def __iter__(self): + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[:self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank:self.total_size:self.num_replicas] + assert len(indices) == self.num_samples + + if not self.restarting: + self.counter = 0 + else: + indices = indices[self.counter:] + self.restarting = False + # self.start_counter = self.counter + + for index in indices: + self.counter += 1 + yield index + + self.counter = 0 + # self.start_counter = self.counter \ No newline at end of file diff --git a/src/clm/src/dataloaders/language_modeling_hf.py b/src/clm/src/dataloaders/language_modeling_hf.py new file mode 100644 index 00000000..c17e66b6 --- /dev/null +++ b/src/clm/src/dataloaders/language_modeling_hf.py @@ -0,0 +1,311 @@ +# Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py +# Adapted from https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py +from itertools import chain +from pathlib import Path +import pickle +from typing import Any, List, Union +import subprocess +import mmap + +from multiprocessing.shared_memory import SharedMemory + +import numpy as np + +import torch +from torch.utils.data.dataloader import DataLoader, Dataset +from transformers import AutoTokenizer +from datasets import load_dataset + +from clm.src.dataloaders.base import SequenceDataset, default_data_path + +from clm.src.dataloaders.datasets.lm_dataset import LMDataset +from clm.src.dataloaders.fault_tolerant_sampler import RandomFaultTolerantSampler +from clm.src.dataloaders.fault_tolerant_sampler import FaultTolerantDistributedSampler +from clm.src.dataloaders.datasets.detokenizer import DATASET_TOKENIZATION_REGISTRY +from clm.src.utils.train import get_logger +logger = get_logger() + + +# https://github.com/numpy/numpy/issues/18294 +class SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array + + def __new__(cls, input_array, shm=None): + obj = np.asarray(input_array).view(cls) + obj.shm = shm + return obj + + def __array_finalize__(self, obj): + if obj is None: return + self.shm = getattr(obj, 'shm', None) + + +class LMDataModuleWT103(SequenceDataset): + _name_ = "wt103" + + def __init__(self, dataset_name, tokenizer_name, dataset_config_name=None, max_length=1024, + cache_dir=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True, + detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1, + shuffle=False, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False, + fast_forward_epochs=None, fast_forward_batches=None, + use_shmem=True, *args, **kwargs): + self.dataset_name = dataset_name + 
self.dataset_config_name = dataset_config_name + self.tokenizer_name = tokenizer_name + self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser() + self.max_length = max_length + self.val_ratio = val_ratio + self.val_split_seed = val_split_seed + self.val_only = val_only + self.add_eos = add_eos + self.detokenize = detokenize + self.batch_size = batch_size + self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size + self.num_workers = num_workers + self.shuffle = shuffle + self.pin_memory = pin_memory + self.drop_last = drop_last + if fault_tolerant: + assert self.shuffle + self.fault_tolerant = fault_tolerant + if ddp: + assert fault_tolerant + self.ddp = ddp + self.fast_forward_epochs = fast_forward_epochs + self.fast_forward_batches = fast_forward_batches + if self.fast_forward_epochs is not None or self.fast_forward_batches is not None: + assert ddp and fault_tolerant + + self.use_shmem = use_shmem + if self.use_shmem: + assert cache_dir is not None + + def prepare_data(self): + if self.cache_dir is None: # Just download the dataset + load_dataset(self.dataset_name, self.dataset_config_name) + else: # Process the dataset and save it + self.process_dataset() + + def setup(self, stage=None): + if stage == 'test' and hasattr(self, 'dataset_test'): + return + concat_ids, self.tokenizer = self.process_dataset() + self.vocab_size = len(self.tokenizer) + # Create all splits + self.dataset_train, self.dataset_val, self.dataset_test = [ + LMDataset(concat_ids[split], seq_len=self.max_length) + for split in ['train', 'validation', 'test'] + ] + + def process_dataset(self): + cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name + if cache_dir is not None: + if cache_dir.is_dir(): + return self._load_from_cache(cache_dir) + + raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name) + # https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py + if 'validation' not in raw_datasets: + assert "train" in raw_datasets, "You must have train in raw_datasets to make a validation raw_datasets" + raw_datasets = raw_datasets["train"].train_test_split( + test_size=self.val_ratio, seed=self.val_split_seed, + shuffle=True # Otherwise test will be at the end of the dataset + ) + raw_datasets['validation'] = raw_datasets['test'] + + if self.val_only: # Should only be used for evaluation, not for training + raw_datasets['train'] = raw_datasets['validation'] + + # [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse + # (GPT2-small val ppl after 10 epochs ~22 -> ~25) + # However, it's useful for zero-shot transfer from Openwebtext, + # as after detokenization it's closer to Openwebtext's format. + # https://github.com/stanford-crfm/mistral/issues/12 + if self.detokenize: + if self.dataset_name in DATASET_TOKENIZATION_REGISTRY: + detokenizer = DATASET_TOKENIZATION_REGISTRY[self.dataset_name] + raw_datasets = raw_datasets.map( + lambda example: {'text': detokenizer(example['text'])}, + num_proc=max(self.num_workers, 1), + desc='Running detokenizer on dataset' + ) + + tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True) + # Preprocessing the datasets. + # First we tokenize all the texts. 
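The map defined below appends an EOS token to every non-empty text and flattens all token ids into one contiguous array per split. A toy stand-in for that idea; the whitespace tokenizer and tiny vocabulary here are assumptions, not the real HuggingFace tokenizer:

from itertools import chain
import numpy as np

eos_id = 0
vocab = {"hello": 1, "world": 2, "again": 3}
texts = ["hello world", "", "hello again"]

def toy_tokenize(text):
    ids = [vocab[w] for w in text.split()]
    return ids + [eos_id] if text else ids   # append EOS only to non-empty texts

concat_ids = np.fromiter(chain(*(toy_tokenize(t) for t in texts)), dtype=np.uint16)
assert concat_ids.tolist() == [1, 2, 0, 1, 3, 0]   # one flat stream, ready for fixed-length chunking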
+ column_names = raw_datasets["train"].column_names + text_column_name = "text" if "text" in column_names else column_names[0] + # [2021-12-25] TD: For wikitext, don't need to add the EOS since each example already ends + # with '\n', and there are no other '\n' in the examples. + # assert all([t.count('\n') == 1 for t in raw_datasets['train']['text'] if t]) + # Add EOS token to the end of the text if the text is not empty + # https://github.com/stanford-crfm/mistral/issues/91 + # https://github.com/stanford-crfm/mistral/pull/98 + if self.add_eos: + add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq + add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] + tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name])) + else: + tokenize = lambda example: tokenizer(example[text_column_name]) + # tokenized_datasets = raw_datasets.map( + # tokenize, + # batched=True, + # num_proc=max(self.num_workers, 1), + # remove_columns=column_names, + # desc="Running tokenizer on dataset", + # ) + dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32 + def tokenize_concat(examples): + # We just need 'input_ids', not 'attention_mask' (since it's all 1) + input_ids = np.fromiter(chain(*tokenize(examples)['input_ids']), dtype=dtype) + # Need to return a list since we're doing batched processing + return {'input_ids': [input_ids], 'len': [len(input_ids)]} + tokenized_datasets = raw_datasets.map( + tokenize_concat, + batched=True, + num_proc=max(self.num_workers, 1), + remove_columns=column_names, + desc="Running tokenizer on dataset", + ) + + if self.use_shmem: + # Concatenate all input_ids into an array in shared memory + def write_ids_to_shm(example, shm_name, array_len): + shm = SharedMemory(name=shm_name) + shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf) + start_idx = example['len_offset'] - len(example['input_ids']) + shm_arr[start_idx:example['len_offset']] = example['input_ids'] + shm.close() + concat_ids = {} + for name, ds in tokenized_datasets.items(): + tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len'])) + array_len = tokenized_datasets[name][-1]['len_offset'] + shm = SharedMemory(create=True, size=array_len * np.dtype(dtype).itemsize) + shm_name = shm.name + tokenized_datasets[name].map( + write_ids_to_shm, + fn_kwargs={'shm_name': shm_name, 'array_len': array_len}, + batched=False, + num_proc=max(self.num_workers, 1), + desc="Concatenating examples", + ) + shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf) + # We need to keep a reference to the shared memory, otherwise it gets garbage-collected + # when it goes out of scope, and that memory is gone. 
+ # https://github.com/numpy/numpy/issues/18294 + concat_ids[name] = SHMArray(shm_arr, shm=shm) + else: + # Use disk + concat_ids = {} + assert cache_dir is not None + cache_dir.mkdir(parents=True, exist_ok=True) + def write_ids_to_disk(example, filename): + with open(filename, 'r+b') as f: + mm = mmap.mmap(f.fileno(), 0) + start_idx = example['len_offset'] - len(example['input_ids']) + array_len = len(example['input_ids']) + arr = np.ndarray((array_len,), dtype=dtype, buffer=mm, + offset=np.dtype(dtype).itemsize * start_idx) + arr[:] = example['input_ids'] + mm.flush() + for name, ds in tokenized_datasets.items(): + tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len'])) + array_len = tokenized_datasets[name][-1]['len_offset'] + filename = cache_dir / f'{name}.bin' + # Need to create the file with this specific size first + # https://ostechnix.com/create-files-certain-size-linux/ + subprocess.run(['truncate', '-s', str(array_len * np.dtype(dtype).itemsize), + str(filename)], check=True) + tokenized_datasets[name].map( + write_ids_to_disk, + fn_kwargs={'filename': filename}, + batched=False, + num_proc=max(self.num_workers, 1), + desc="Concatenating examples", + ) + concat_ids[name] = np.memmap(filename, dtype=dtype, mode='r', shape=(array_len,)) + + if cache_dir is not None: + self._save_to_cache(concat_ids, tokenizer, cache_dir) + if not self.use_shmem: + for name in concat_ids: + Path(cache_dir / f'{name}.bin').unlink() + return concat_ids, tokenizer + + def _save_to_cache(self, concat_ids, tokenizer, cache_dir): + cache_dir.mkdir(parents=True, exist_ok=True) + logger.info(f'Saving to cache at {str(cache_dir)}') + for k, v in concat_ids.items(): + np.save(cache_dir / f'{k}.npy', v) + with open(cache_dir / 'tokenizer.pkl', 'wb') as f: + pickle.dump(tokenizer, f) + + def _load_from_cache(self, cache_dir): + assert cache_dir.is_dir() + logger.info(f'Load from cache at {str(cache_dir)}') + concat_ids = {split: np.load(cache_dir / f'{split}.npy', mmap_mode='r') + for split in ['train', 'validation', 'test']} + with open(cache_dir / 'tokenizer.pkl', 'rb') as f: + tokenizer = pickle.load(f) + return concat_ids, tokenizer + + @property + def _cache_dir_name(self): + return f'tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-{self.detokenize}' + + def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: + """ The train dataloader """ + if self.shuffle and self.fault_tolerant: + shuffle = False + # TD [2022-12-26]: We need the distributed_sampler_kwargs in case of model parallel: + # In that case the number of replicas and the data parallel rank are more complicated. 
+ distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs + sampler = (FaultTolerantDistributedSampler(self.dataset_train, + **self.trainer.distributed_sampler_kwargs) + if self.ddp else RandomFaultTolerantSampler(self.dataset_train)) + # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now + # We assume that it's being resumed with the same number of GPUs + if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None: + sampler.load_state_dict({ + 'epoch': self.fast_forward_epochs, + 'counter': self.fast_forward_batches * self.batch_size + }) + else: + shuffle = self.shuffle + sampler = None + return self._data_loader(self.dataset_train, batch_size=self.batch_size, + shuffle=shuffle, sampler=sampler) + + def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: + """ The val dataloader """ + return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval) + + def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: + """ The test dataloader """ + return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval) + + def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False, + sampler=None) -> DataLoader: + return DataLoader( + dataset, + batch_size=batch_size, + num_workers=1, # Data is already in memory, we don't need many workers + shuffle=shuffle, + sampler=sampler, + drop_last=self.drop_last, + pin_memory=self.pin_memory, + # persistent_workers=True + ) + + def load_state_dict(self, checkpoint): + if self.fault_tolerant: + self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed'] + # TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration + # behind, so we're using the optimizer's progress. This is set correctly in seq.py. + self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] + # At this point the train loader hasn't been constructed yet + +class LMDataModuleOWT(LMDataModuleWT103): + _name_ = "owt" + +class LMDataModulePile(LMDataModuleWT103): + _name_ = "the_pile" \ No newline at end of file diff --git a/src/clm/src/dataloaders/lm.py b/src/clm/src/dataloaders/lm.py new file mode 100644 index 00000000..9f1ce486 --- /dev/null +++ b/src/clm/src/dataloaders/lm.py @@ -0,0 +1,507 @@ +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
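The LMOrderedIterator defined below folds one long token stream into batch_size parallel rows and walks across them, pairing each chunk of inputs with the same chunk shifted right by one token. A toy sketch of that layout; the 13-token stream and the sizes are made-up values:

import torch

tokens = torch.arange(13)                 # stand-in for an encoded corpus
batch_size, l_max = 2, 3

n_step = tokens.size(0) // batch_size
rows = tokens[: n_step * batch_size].view(batch_size, -1)   # (2, 6); the leftover token is dropped

for i in range(0, rows.size(-1) - 1, l_max):
    data = rows[:, i : i + l_max]                 # inputs
    target = rows[:, i + 1 : i + 1 + l_max]       # same positions shifted by one
# first step: data = [[0, 1, 2], [6, 7, 8]], target = [[1, 2, 3], [7, 8, 9]]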
+ +import logging +import os +import re +import subprocess +from pathlib import Path + +from typing import Optional, List, Tuple +import math +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import functools +from omegaconf import DictConfig +from pytorch_lightning import LightningDataModule + + +from clm.src.utils import distributed +import clm.src.utils.train +log = clm.src.utils.train.get_logger(__name__) + + +from clm.src.dataloaders.base import SequenceDataset, default_data_path +from clm.src.dataloaders.utils.vocabulary import OpenAIVocab, Vocab +import clm.src.utils as utils + +project_root = Path(__file__).parent.parent.absolute() +data_path = Path(__file__).absolute().parent / 'data' + +import sys + +sys.path.insert(0, str(project_root)) + +class LMOrderedIterator: + def __init__( + self, + data, + batch_size, + l_max, + batch_first=True, + n_context=1, + n_epoch_double=0, + pad_last=False, + roll_seed=None, # roll data based on seed + limit_tokens=1.0, # reduce tokens; useful for debugging last batch edge cases + ): + """ + data -- LongTensor -- the LongTensor is strictly ordered + pad_last: whether to pad the last sequence in the batch so that all sequences + have the same length (l_max). + """ + self.raw_data = data + self.batch_size = batch_size + self.l_max = l_max + self.batch_first = batch_first + self.pad_last = pad_last + self.roll_seed = roll_seed + self.n_context = n_context + self.n_epoch_double = n_epoch_double + + self.epoch = -1 + + # DDP + self.world_size = distributed.get_world_size() + self.rank = distributed.get_rank() + + if limit_tokens is not None and 0.0 < limit_tokens < 1.0: + l_data = int(math.floor(data.size(-1) * limit_tokens)) + self.raw_data = self.raw_data[:l_data] + + self.process() + + def process(self): + """ Process the data. All logic involving sequence length and batch size should go here """ + assert self.l_max % self.n_context == 0 + self.l_inc = self.l_max // self.n_context + + global_batch_size = self.world_size * self.batch_size + + # Work out how cleanly we can divide the dataset into batch_size parts. + n_step = self.raw_data.size(-1) // global_batch_size + + # Trim off any extra elements that wouldn't cleanly fit (remainders). + self.data = self.raw_data[: n_step * global_batch_size] + + # Evenly divide the data across the batches. 
+ self.data = self.data.view(global_batch_size, -1).contiguous().pin_memory() # (global_batch_size, length) + + # Partition data for DistributedDataParallel + self.data = self.data.chunk(self.world_size, dim=0)[self.rank] + + # Number of mini-batches + # Need to subtract 1 because target is data shifted by 1 + self.n_batch = (self.data.size(-1) - 1 + self.l_inc - 1) // self.l_inc + + def roll(self, seed): + rng = torch.Generator() + rng.manual_seed(seed) + for i in range(self.data.size(0)): + row = self.data[i, :] + shift = torch.randint(0, self.data.size(-1), (1,), generator=rng) + row = torch.cat((row[shift:], row[:shift])) + self.data[i, :] = row + + def get_batch(self, i): + """ Get batch starting at token index i """ + + end_idx = min(i + self.l_inc, self.data.size(-1)-1) + beg_idx = max(0, i + self.l_inc - self.l_max) + seq_len = end_idx - i + + data = self.data[..., beg_idx:end_idx] + target = self.data[..., i+1 : end_idx+1] + + if self.pad_last and seq_len < self.l_inc: + data = F.pad(data, (0, self.l_inc - seq_len)) # (batch_size, l_inc) + target = F.pad(target, (0, self.l_inc - seq_len)) + seq_len = self.l_inc + + if not self.batch_first: + data = data.transpose(0, 1).contiguous() # (n_batch, l_sequence) + target = target.transpose(0, 1).contiguous() + + return data, target, {"l_output": seq_len} # Return length of desired output + + def get_fixlen_iter(self, start=0): + if start != 0: + start += self.l_max + for i in range(start, self.data.size(-1) - 1, self.l_inc): + self.last_iter = i + yield self.get_batch(i) + + def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): # NOTE: NOT TESTED + l_max = self.l_max + max_deviation * std + i = start + while True: + l_max = self.l_max if np.random.random() < 0.95 else self.l_max / 2.0 + l_max = min(l_max, max(min_len, int(np.random.normal(l_max, std)))) + data, target, seq_len = self.get_batch(i, l_max) # AG: this doesn't appear to work... 
+ i += seq_len + yield data, target, seq_len + if i >= self.data.size(-1) - 2: + break + + def __iter__(self): + self.epoch += 1 + if (n := self.n_epoch_double) > 0 and self.epoch > 0 and self.epoch % n == 0: + if self.batch_size > 1: + log.info(f"LM Iterator doubling length from {self.l_max} to {self.l_max*2}") + self.l_max *= 2 + self.batch_size //= 2 + self.process() + + if self.roll_seed is not None: + self.roll(self.roll_seed + self.epoch) + return self.get_fixlen_iter() + + def __len__(self): + return self.n_batch + + +class LMShuffledIterator(object): + # NOTE: Not tested + def __init__( + self, data, batch_size, l_max, device="cpu", ext_len=None, shuffle=False + ): + """ + data -- list[LongTensor] -- there is no order among the LongTensors + """ + self.data = data + + self.batch_size = batch_size + self.l_max = l_max + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + self.shuffle = shuffle + + def get_sent_stream(self): + # index iterator + epoch_indices = ( + np.random.permutation(len(self.data)) + if self.shuffle + else np.array(range(len(self.data))) + ) + + # sentence iterator + for idx in epoch_indices: + yield self.data[idx] + + def stream_iterator(self, sent_stream): + # streams for each data in the batch + streams = [None] * self.batch_size + + data = torch.LongTensor(self.l_max, self.batch_size) + target = torch.LongTensor(self.l_max, self.batch_size) + + n_retain = 0 + + while True: + # data : [n_retain+l_max x batch_size] + # target : [l_max x batch_size] + data[n_retain:].fill_(-1) + target.fill_(-1) + + valid_batch = True + + for i in range(self.batch_size): + n_filled = 0 + try: + while n_filled < self.l_max: + if streams[i] is None or len(streams[i]) <= 1: + streams[i] = next(sent_stream) + # number of new tokens to fill in + n_new = min(len(streams[i]) - 1, self.l_max - n_filled) + # first n_retain tokens are retained from last batch + data[ + n_retain + n_filled : n_retain + n_filled + n_new, + i, + ] = streams[i][:n_new] + target[n_filled : n_filled + n_new, i] = streams[i][ + 1 : n_new + 1 + ] + streams[i] = streams[i][n_new:] + n_filled += n_new + except StopIteration: + valid_batch = False + break + + if not valid_batch: + return + + data = data.to(self.device) + target = target.to(self.device) + + yield data, target, self.l_max + + n_retain = min(data.size(0), self.ext_len) + if n_retain > 0: + data[:n_retain] = data[-n_retain:] + data.resize_(n_retain + self.l_max, data.size(1)) + + def __iter__(self): + # sent_stream is an iterator + sent_stream = self.get_sent_stream() + + for batch in self.stream_iterator(sent_stream): + yield batch + + +class LMMultiFileIterator(LMShuffledIterator): + # NOTE: Not tested + def __init__( + self, + paths, + vocab, + batch_size, + l_max, + device="cpu", + ext_len=None, + shuffle=False, + ): + + self.paths = paths + self.vocab = vocab + + self.batch_size = batch_size + self.l_max = l_max + self.ext_len = ext_len if ext_len is not None else 0 + + self.device = device + self.shuffle = shuffle + + def get_sent_stream(self, path): + sents = self.vocab.encode_file(path, add_double_eos=True) + if self.shuffle: + np.random.shuffle(sents) + sent_stream = iter(sents) + + return sent_stream + + def __iter__(self): + if self.shuffle: + np.random.shuffle(self.paths) + + for path in self.paths: + # sent_stream is an iterator + sent_stream = self.get_sent_stream(path) + for batch in self.stream_iterator(sent_stream): + yield batch + + +class WikiText2(SequenceDataset): + _name_ = "wt2" + + # Vocab arguments + 
vocab_kwargs = {"special": [""], "lower_case": False} + encode_kwargs = {"ordered": True} + + init_defaults = { + # Dataset arguments + 'l_max': 512, + 'bpe': False, + 'roll_seed': 42, + 'test_split': True, + } + + @property + def n_tokens(self): + return len(self.vocab) + + def prepare_data(self): + # [21-09-23] probably broken + if not self.data_dir.exists(): + subprocess.run( + [ + str(project_root / "data" / "getdata.sh"), + self._name_, + str(self.data_dir.parent.absolute()), + ], + check=True, + ) + + def setup(self, stage=None): # [21-09-10 AG]: TODO shouldn't this tokenization happen in the prepare_data? since we're caching it it doesn't really matter, but still + if self.data_dir is None: self.data_dir = default_data_path / self._name_ + if self.bpe: + self.vocab = OpenAIVocab() + else: + self.vocab = Vocab(**self.vocab_kwargs) + + # Loader arguments + if not self._load_from_cache(): + logging.info(f"Producing dataset {self._name_}...") + self._vocab_count() + self.vocab.build_vocab() + self.train = self.vocab.encode_file( + str(self.data_dir / "train.txt"), **self.encode_kwargs + ) + self.valid = self.vocab.encode_file( + str(self.data_dir / "valid.txt"), **self.encode_kwargs + ) + self.test = self.vocab.encode_file( + str(self.data_dir / "test.txt"), **self.encode_kwargs + ) + self._save_to_cache() + + # No test set if specified + if not self.test_split: + self.test = None + + # Define task + print("Vocab size:", len(self.vocab)) + + def _vocab_count(self): + self.vocab.count_file(self.data_dir / "train.txt") + self.vocab.count_file(self.data_dir / "valid.txt") + self.vocab.count_file(self.data_dir / "test.txt") + + def _save_to_cache(self): + cache_path = self.data_dir / f"cache.pt" # TODO name could include vocab_kwargs to disambiguate + with distributed.sync_workers() as rank: + if rank == 0: + try: + torch.save( + (self.vocab, self.train, self.valid, self.test), + cache_path, + ) + logging.info(f"Saved dataset to {cache_path}...") + except: + pass + + def _load_from_cache(self): + cache_path = self.data_dir / f"cache.pt" + if cache_path.exists(): + logging.info("Loading cached dataset...") + self.vocab, self.train, self.valid, self.test = torch.load( + cache_path + ) + return True + else: + return False + + def train_dataloader(self, eval=None, **kwargs): + # TODO kwargs absorbs num_workers + return LMOrderedIterator( + self.train, + roll_seed=self.roll_seed, + **kwargs, + ) + + # def val_dataloader(self, batch_size, **kwargs): + def _eval_dataloader(self, dataset, eval=None, **loader_args): + if dataset is None: return None + # Make eval a list of dictionaries + if eval is None: eval = {} + if not utils.is_list(eval): + eval = [eval] + # Each eval setting overrides the train setting + for eval_args in eval: + for k in loader_args: + if eval_args.get(k, None) is None: + eval_args[k] = loader_args[k] + print("eval loader:", eval_args) + loaders = [LMOrderedIterator(dataset, **eval_args) for eval_args in eval] + if len(loaders) == 1: return loaders[0] + return loaders + + def val_dataloader(self, **kwargs): + return self._eval_dataloader(self.valid, **kwargs) + + def test_dataloader(self, **kwargs): + return self._eval_dataloader(self.test, **kwargs) + + +class WikiText103(WikiText2): + _name_ = "wt103" + + def _vocab_count(self): + print(self.data_dir) + self.vocab.count_file(self.data_dir / "train.txt") + + +class PennTreeBank(WikiText2): + + _name_ = "ptb" + vocab_kwargs = {"special": [""], "lower_case": True} + +class EnWik8(WikiText2): + _name_ = "enwik8" + + vocab_kwargs 
= {} + encode_kwargs = {"ordered": True, "add_eos": False} + + +class Text8(EnWik8): + + _name_ = "text8" + + +class LM1B(WikiText2): + # [21-09-08 AG]: this looks very out of date, the __init__ function should be inherited + + _name_ = "lm1b" + vocab_kwargs = {"special": [], "lower_case": False} + cutoffs = [59997, 99997, 639997] + tie_projs = [False] + [False] * len(cutoffs) + + def __init__(self, data_dir, bpe=False, *args, **kwargs): + LightningDataModule.__init__(self) + self.data_dir = Path(data_dir) + # self.vocab_type = vocab + if bpe: + self.vocab = OpenAIVocab() + else: + self.vocab = Vocab( + vocab_file=self.data_dir / "1b_word_vocab.txt", + **self.vocab_kwargs, + ) + + def setup(self, stage=None): + if not self._load_from_cache(): + logging.info(f"Producing dataset {self._name_}...") + # the vocab will load from file when build_vocab() is called + self.vocab.build_vocab() + train_paths = list( + ( + self.data_dir + / "1-billion-word-language-modeling-benchmark-r13output" + / "training-monolingual.tokenized.shuffled" + ).glob("news.en-*") + ) + self.train = train_paths + self.valid = self.vocab.encode_file( + str(self.data_dir / "valid.txt"), + ordered=False, + add_double_eos=True, + ) + self.test = self.vocab.encode_file( + str(self.data_dir / "test.txt"), + ordered=False, + add_double_eos=True, + ) + self._save_to_cache() + + def train_dataloader(self, *args, **kwargs): + kwargs["shuffle"] = True + return LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) + + def val_dataloader(self, *args, **kwargs): + return LMShuffledIterator(self.valid, *args, **kwargs) + + def test_dataloader(self, *args, **kwargs): + return LMShuffledIterator(self.test, *args, **kwargs) diff --git a/src/clm/src/dataloaders/lra.py b/src/clm/src/dataloaders/lra.py new file mode 100644 index 00000000..624129f1 --- /dev/null +++ b/src/clm/src/dataloaders/lra.py @@ -0,0 +1,689 @@ +"""Long Range Arena datasets""" +import io +import logging +import os +import pickle +from pathlib import Path + +import torch +from torch import nn +import torch.nn.functional as F +import torchtext +import torchvision +from einops.layers.torch import Rearrange, Reduce +from PIL import Image # Only used for Pathfinder +from datasets import DatasetDict, Value, load_dataset + +from clm.src.dataloaders.base import default_data_path, SequenceDataset, ImageResolutionSequenceDataset + + +class IMDB(SequenceDataset): + _name_ = "imdb" + d_output = 2 + l_output = 0 + + @property + def init_defaults(self): + return { + "l_max": 4096, + "level": "char", + "min_freq": 15, + "seed": 42, + "val_split": 0.0, + "append_bos": False, + "append_eos": True, + # 'max_vocab': 135, + "n_workers": 4, # Only used for tokenizing dataset before caching + } + + @property + def n_tokens(self): + return len(self.vocab) + + def prepare_data(self): + if self.cache_dir is None: # Just download the dataset + load_dataset(self._name_, cache_dir=self.data_dir) + else: # Process the dataset and save it + self.process_dataset() + + def setup(self, stage=None): + """If cache_dir is not None, we'll cache the processed dataset there.""" + self.data_dir = self.data_dir or default_data_path / self._name_ + self.cache_dir = self.data_dir / "cache" + assert self.level in [ + "word", + "char", + ], f"level {self.level} not supported" + + if stage == "test" and hasattr(self, "dataset_test"): + return + dataset, self.tokenizer, self.vocab = self.process_dataset() + print( + f"IMDB {self.level} level | min_freq {self.min_freq} | vocab size {len(self.vocab)}" + ) + 
dataset.set_format(type="torch", columns=["input_ids", "label"]) + + # Create all splits + dataset_train, self.dataset_test = dataset["train"], dataset["test"] + if self.val_split == 0.0: + # Use test set as val set, as done in the LRA paper + self.dataset_train, self.dataset_val = dataset_train, None + else: + train_val = dataset_train.train_test_split( + test_size=self.val_split, seed=self.seed + ) + self.dataset_train, self.dataset_val = ( + train_val["train"], + train_val["test"], + ) + + def _collate_fn(self, batch): + xs, ys = zip(*[(data["input_ids"], data["label"]) for data in batch]) + lengths = torch.tensor([len(x) for x in xs]) + xs = nn.utils.rnn.pad_sequence( + xs, padding_value=self.vocab[""], batch_first=True + ) + ys = torch.tensor(ys) + return xs, ys, {"lengths": lengths} + + # self._collate_fn = collate_batch + + def process_dataset(self): + cache_dir = ( + None if self.cache_dir is None else self.cache_dir / self._cache_dir_name + ) + if cache_dir is not None: + if cache_dir.is_dir(): + return self._load_from_cache(cache_dir) + + dataset = load_dataset(self._name_, cache_dir=self.data_dir) + dataset = DatasetDict(train=dataset["train"], test=dataset["test"]) + if self.level == "word": + tokenizer = torchtext.data.utils.get_tokenizer( + "spacy", language="en_core_web_sm" + ) + else: # self.level == 'char' + tokenizer = list # Just convert a string to a list of chars + # Account for and tokens + l_max = self.l_max - int(self.append_bos) - int(self.append_eos) + tokenize = lambda example: {"tokens": tokenizer(example["text"])[:l_max]} + dataset = dataset.map( + tokenize, + remove_columns=["text"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + vocab = torchtext.vocab.build_vocab_from_iterator( + dataset["train"]["tokens"], + min_freq=self.min_freq, + specials=( + ["", ""] + + ([""] if self.append_bos else []) + + ([""] if self.append_eos else []) + ), + ) + vocab.set_default_index(vocab[""]) + + numericalize = lambda example: { + "input_ids": vocab( + ([""] if self.append_bos else []) + + example["tokens"] + + ([""] if self.append_eos else []) + ) + } + dataset = dataset.map( + numericalize, + remove_columns=["tokens"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + + if cache_dir is not None: + self._save_to_cache(dataset, tokenizer, vocab, cache_dir) + return dataset, tokenizer, vocab + + def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir): + cache_dir = self.cache_dir / self._cache_dir_name + logger = logging.getLogger(__name__) + logger.info(f"Saving to cache at {str(cache_dir)}") + dataset.save_to_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "wb") as f: + pickle.dump(tokenizer, f) + with open(cache_dir / "vocab.pkl", "wb") as f: + pickle.dump(vocab, f) + + def _load_from_cache(self, cache_dir): + assert cache_dir.is_dir() + logger = logging.getLogger(__name__) + logger.info(f"Load from cache at {str(cache_dir)}") + dataset = DatasetDict.load_from_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "rb") as f: + tokenizer = pickle.load(f) + with open(cache_dir / "vocab.pkl", "rb") as f: + vocab = pickle.load(f) + return dataset, tokenizer, vocab + + @property + def _cache_dir_name(self): + return f"l_max-{self.l_max}-level-{self.level}-min_freq-{self.min_freq}-append_bos-{self.append_bos}-append_eos-{self.append_eos}" + +class TabularDataset(torch.utils.data.Dataset): + def __init__( + self, + path, + format, + col_idx=None, + 
skip_header=False, + csv_reader_params=None, + ): + """ + col_idx: the indices of the columns. + """ + if csv_reader_params is None: + csv_reader_params = {} + format = format.lower() + assert format in ["tsv", "csv"] + with io.open(os.path.expanduser(path), encoding="utf8") as f: + if format == "csv": + reader = torchtext.utils.unicode_csv_reader(f, **csv_reader_params) + elif format == "tsv": + reader = torchtext.utils.unicode_csv_reader( + f, delimiter="\t", **csv_reader_params + ) + else: + reader = f + if skip_header: + next(reader) + self._data = [ + line if col_idx is None else [line[c] for c in col_idx] + for line in reader + ] + + def __len__(self): + return len(self._data) + + def __getitem__(self, idx): + return self._data[idx] + + +# LRA tokenizer renames ']' to 'X' and delete parentheses as their tokenizer removes +# non-alphanumeric characters. +# https://github.com/google-research/long-range-arena/blob/264227cbf9591e39dd596d2dc935297a2070bdfe/lra_benchmarks/listops/input_pipeline.py#L46 +def listops_tokenizer(s): + return s.translate({ord("]"): ord("X"), ord("("): None, ord(")"): None}).split() + + +class ListOps(SequenceDataset): + _name_ = "listops" + d_output = 10 + l_output = 0 + + @property + def init_defaults(self): + return { + "l_max": 2048, + "append_bos": False, + "append_eos": True, + # 'max_vocab': 20, # Actual size 18 + "n_workers": 4, # Only used for tokenizing dataset + } + + @property + def n_tokens(self): + return len(self.vocab) + + @property + def _cache_dir_name(self): + return f"l_max-{self.l_max}-append_bos-{self.append_bos}-append_eos-{self.append_eos}" + + def init(self): + if self.data_dir is None: + self.data_dir = default_data_path / self._name_ + self.cache_dir = self.data_dir / self._cache_dir_name + + def prepare_data(self): + if self.cache_dir is None: + for split in ["train", "val", "test"]: + split_path = self.data_dir / f"basic_{split}.tsv" + if not split_path.is_file(): + raise FileNotFoundError( + f""" + File {str(split_path)} not found. + To get the dataset, download lra_release.gz from + https://github.com/google-research/long-range-arena, + then unzip it with tar -xvf lra_release.gz. + Then point data_dir to the listops-1000 directory. 
+ """ + ) + else: # Process the dataset and save it + self.process_dataset() + + def setup(self, stage=None): + if stage == "test" and hasattr(self, "dataset_test"): + return + dataset, self.tokenizer, self.vocab = self.process_dataset() + self.vocab_size = len(self.vocab) + dataset.set_format(type="torch", columns=["input_ids", "Target"]) + self.dataset_train, self.dataset_val, self.dataset_test = ( + dataset["train"], + dataset["val"], + dataset["test"], + ) + + def collate_batch(batch): + xs, ys = zip(*[(data["input_ids"], data["Target"]) for data in batch]) + lengths = torch.tensor([len(x) for x in xs]) + xs = nn.utils.rnn.pad_sequence( + xs, padding_value=self.vocab[""], batch_first=True + ) + ys = torch.tensor(ys) + return xs, ys, {"lengths": lengths} + + self._collate_fn = collate_batch + + def process_dataset(self): + cache_dir = ( + None if self.cache_dir is None else self.cache_dir / self._cache_dir_name + ) + if cache_dir is not None: + if cache_dir.is_dir(): + return self._load_from_cache(cache_dir) + + dataset = load_dataset( + "csv", + data_files={ + "train": str(self.data_dir / "basic_train.tsv"), + "val": str(self.data_dir / "basic_val.tsv"), + "test": str(self.data_dir / "basic_test.tsv"), + }, + delimiter="\t", + keep_in_memory=True, + ) + + tokenizer = listops_tokenizer + + # Account for and tokens + l_max = self.l_max - int(self.append_bos) - int(self.append_eos) + tokenize = lambda example: {"tokens": tokenizer(example["Source"])[:l_max]} + dataset = dataset.map( + tokenize, + remove_columns=["Source"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + vocab = torchtext.vocab.build_vocab_from_iterator( + dataset["train"]["tokens"], + specials=( + ["", ""] + + ([""] if self.append_bos else []) + + ([""] if self.append_eos else []) + ), + ) + vocab.set_default_index(vocab[""]) + + numericalize = lambda example: { + "input_ids": vocab( + ([""] if self.append_bos else []) + + example["tokens"] + + ([""] if self.append_eos else []) + ) + } + dataset = dataset.map( + numericalize, + remove_columns=["tokens"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + + if cache_dir is not None: + self._save_to_cache(dataset, tokenizer, vocab, cache_dir) + return dataset, tokenizer, vocab + + def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir): + cache_dir = self.cache_dir / self._cache_dir_name + logger = logging.getLogger(__name__) + logger.info(f"Saving to cache at {str(cache_dir)}") + dataset.save_to_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "wb") as f: + pickle.dump(tokenizer, f) + with open(cache_dir / "vocab.pkl", "wb") as f: + pickle.dump(vocab, f) + + def _load_from_cache(self, cache_dir): + assert cache_dir.is_dir() + logger = logging.getLogger(__name__) + logger.info(f"Load from cache at {str(cache_dir)}") + dataset = DatasetDict.load_from_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "rb") as f: + tokenizer = pickle.load(f) + with open(cache_dir / "vocab.pkl", "rb") as f: + vocab = pickle.load(f) + return dataset, tokenizer, vocab + +class PathFinderDataset(torch.utils.data.Dataset): + """Path Finder dataset.""" + + # There's an empty file in the dataset + blacklist = {"pathfinder32/curv_baseline/imgs/0/sample_172.png"} + + def __init__(self, data_dir, transform=None): + """ + Args: + data_dir (string): Directory with all the images. + transform (callable, optional): Optional transform to be applied + on a sample. 
+ """ + self.data_dir = Path(data_dir).expanduser() + assert self.data_dir.is_dir(), f"data_dir {str(self.data_dir)} does not exist" + self.transform = transform + samples = [] + # for diff_level in ['curv_baseline', 'curv_contour_length_9', 'curv_contour_length_14']: + for diff_level in ["curv_contour_length_14"]: + path_list = sorted( + list((self.data_dir / diff_level / "metadata").glob("*.npy")), + key=lambda path: int(path.stem), + ) + assert path_list, "No metadata found" + for metadata_file in path_list: + with open(metadata_file, "r") as f: + for metadata in f.read().splitlines(): + metadata = metadata.split() + image_path = Path(diff_level) / metadata[0] / metadata[1] + if ( + str(Path(self.data_dir.stem) / image_path) + not in self.blacklist + ): + label = int(metadata[3]) + samples.append((image_path, label)) + self.samples = samples + + def __len__(self): + return len(self.samples) + + def __getitem__(self, idx): + path, target = self.samples[idx] + # https://github.com/pytorch/vision/blob/9b29f3f22783112406d9c1a6db47165a297c3942/torchvision/datasets/folder.py#L247 + with open(self.data_dir / path, "rb") as f: + sample = Image.open(f).convert("L") # Open in grayscale + if self.transform is not None: + sample = self.transform(sample) + return sample, target + +class PathFinder(ImageResolutionSequenceDataset): + _name_ = "pathfinder" + d_input = 1 + d_output = 2 + l_output = 0 + + @property + def n_tokens(self): + if self.tokenize: + return 256 + + @property + def init_defaults(self): + return { + "resolution": 32, + "sequential": True, + "tokenize": False, + "center": True, + "pool": 1, + "val_split": 0.1, + "test_split": 0.1, + "seed": 42, # Controls the train/val/test split + } + + def default_transforms(self): + transform_list = [torchvision.transforms.ToTensor()] + if self.pool > 1: + transform_list.append( + Reduce( + "1 (h h2) (w w2) -> 1 h w", + "mean", + h2=self.pool, + w2=self.pool, + ) + ) + if self.tokenize: + transform_list.append( + torchvision.transforms.Lambda(lambda x: (x * 255).long()) + ) + else: + if self.center: + transform_list.append(torchvision.transforms.Normalize(mean=0.5, std=0.5)) + if self.sequential: + # If tokenize, it makes more sense to get rid of the channel dimension + transform_list.append( + Rearrange("1 h w -> (h w)") + if self.tokenize + else Rearrange("1 h w -> (h w) 1") + ) + else: + transform_list.append(Rearrange("1 h w -> h w 1")) + return torchvision.transforms.Compose(transform_list) + + def prepare_data(self): + if not self.data_dir.is_dir(): + raise FileNotFoundError( + f""" + Directory {str(self.data_dir)} not found. + To get the dataset, download lra_release.gz from + https://github.com/google-research/long-range-arena, + then unzip it with tar -xvf lra_release.gz. + Then point data_dir to the pathfinderX directory, where X is either 32, 64, 128, or 256. + """ + ) + + def setup(self, stage=None): + if self.data_dir is None: + self.data_dir = ( + default_data_path / self._name_ / f"pathfinder{self.resolution}" + ) + + if stage == "test" and hasattr(self, "dataset_test"): + return + # [2021-08-18] TD: I ran into RuntimeError: Too many open files. 
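With the defaults declared above (resolution 32, pool 1, sequential, center, no tokenization), default_transforms flattens each grayscale image into a length-1024 sequence with a single channel. A small shape check; the blank PIL image is a stand-in for a real Pathfinder sample:

import torchvision
from PIL import Image
from einops.layers.torch import Rearrange

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),                      # (1, 32, 32), values in [0, 1]
    torchvision.transforms.Normalize(mean=0.5, std=0.5),    # center to [-1, 1]
    Rearrange("1 h w -> (h w) 1"),                          # (1024, 1) sequence
])

x = transform(Image.new("L", (32, 32)))
assert x.shape == (1024, 1)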
+ # https://github.com/pytorch/pytorch/issues/11201 + # torch.multiprocessing.set_sharing_strategy("file_system") + dataset = PathFinderDataset(self.data_dir, transform=self.default_transforms()) + len_dataset = len(dataset) + val_len = int(self.val_split * len_dataset) + test_len = int(self.test_split * len_dataset) + train_len = len_dataset - val_len - test_len + ( + self.dataset_train, + self.dataset_val, + self.dataset_test, + ) = torch.utils.data.random_split( + dataset, + [train_len, val_len, test_len], + generator=torch.Generator().manual_seed(self.seed), + ) + +class AAN(SequenceDataset): + _name_ = "aan" + d_output = 2 # Use accuracy instead of binary_accuracy + l_output = 0 + + @property + def n_tokens(self): + return len(self.vocab) + + @property + def init_defaults(self): + return { + "l_max": 4000, + # 'max_vocab': 100, # Full size 98 + "append_bos": False, + "append_eos": True, + "n_workers": 4, # For tokenizing only + } + + @property + def _cache_dir_name(self): + return f"l_max-{self.l_max}-append_bos-{self.append_bos}-append_eos-{self.append_eos}" + + def init(self): + if self.data_dir is None: + self.data_dir = default_data_path / self._name_ + self.cache_dir = self.data_dir / self._cache_dir_name + + def prepare_data(self): + if self.cache_dir is None: + for split in ["train", "eval", "test"]: + split_path = self.data_dir / f"new_aan_pairs.{split}.tsv" + if not split_path.is_file(): + raise FileNotFoundError( + f""" + File {str(split_path)} not found. + To get the dataset, download lra_release.gz from + https://github.com/google-research/long-range-arena, + then unzip it with tar -xvf lra_release.gz. + Then point data_dir to the tsv_data directory. + """ + ) + else: # Process the dataset and save it + self.process_dataset() + + def setup(self, stage=None): + if stage == "test" and hasattr(self, "dataset_test"): + return + + # [2021-08-18] TD: I ran into RuntimeError: Too many open files. 
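The collate function defined below pads both documents of a retrieval pair to a common length and stacks them along the batch dimension, so a batch of B pairs reaches the encoder as 2B sequences. A toy sketch of that step; pad id 0 stands in for the vocabulary's padding index:

import torch
import torch.nn.functional as F

xs1 = torch.tensor([[5, 6, 7]])            # first document of one pair, length 3
xs2 = torch.tensor([[8, 9]])               # second document, length 2
pad_id = 0

L = max(xs1.size(1), xs2.size(1))
xs1 = F.pad(xs1, (0, L - xs1.size(1)), value=pad_id)
xs2 = F.pad(xs2, (0, L - xs2.size(1)), value=pad_id)

xs = torch.cat([xs1, xs2], dim=0)          # (2, 3): the pair becomes two rows
assert xs.tolist() == [[5, 6, 7], [8, 9, 0]]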
+ # https://github.com/pytorch/pytorch/issues/11201 + # torch.multiprocessing.set_sharing_strategy("file_system") + + dataset, self.tokenizer, self.vocab = self.process_dataset() + # self.vocab_size = len(self.vocab) + print("AAN vocab size:", len(self.vocab)) + + dataset.set_format(type="torch", columns=["input_ids1", "input_ids2", "label"]) + self.dataset_train, self.dataset_val, self.dataset_test = ( + dataset["train"], + dataset["val"], + dataset["test"], + ) + + def collate_batch(batch): + xs1, xs2, ys = zip( + *[ + (data["input_ids1"], data["input_ids2"], data["label"]) + for data in batch + ] + ) + lengths1 = torch.tensor([len(x) for x in xs1]) + lengths2 = torch.tensor([len(x) for x in xs2]) + xs1 = nn.utils.rnn.pad_sequence( + xs1, padding_value=self.vocab[""], batch_first=True + ) + xs2 = nn.utils.rnn.pad_sequence( + xs2, padding_value=self.vocab[""], batch_first=True + ) + # Pad both to same length + # Shape (batch, length) + L = max(xs1.size(1), xs2.size(1)) + xs1 = F.pad(xs1, (0, L-xs1.size(1)), value=self.vocab[""]) + xs2 = F.pad(xs2, (0, L-xs2.size(1)), value=self.vocab[""]) + ys = torch.tensor(ys) + # return xs1, xs2, ys, lengths1, lengths2 + + # Concatenate two batches + xs = torch.cat([xs1, xs2], dim=0) + lengths = torch.cat([lengths1, lengths2], dim=0) + return xs, ys, {"lengths": lengths} + + self._collate_fn = collate_batch + + def process_dataset(self): + cache_dir = ( + None if self.cache_dir is None else self.cache_dir / self._cache_dir_name + ) + if cache_dir is not None: + if cache_dir.is_dir(): + return self._load_from_cache(cache_dir) + + dataset = load_dataset( + "csv", + data_files={ + "train": str(self.data_dir / "new_aan_pairs.train.tsv"), + "val": str(self.data_dir / "new_aan_pairs.eval.tsv"), + "test": str(self.data_dir / "new_aan_pairs.test.tsv"), + }, + delimiter="\t", + column_names=["label", "input1_id", "input2_id", "text1", "text2"], + keep_in_memory=True, + ) # True) + dataset = dataset.remove_columns(["input1_id", "input2_id"]) + new_features = dataset["train"].features.copy() + new_features["label"] = Value("int32") + dataset = dataset.cast(new_features) + + tokenizer = list # Just convert a string to a list of chars + # Account for and tokens + l_max = self.l_max - int(self.append_bos) - int(self.append_eos) + tokenize = lambda example: { + "tokens1": tokenizer(example["text1"])[:l_max], + "tokens2": tokenizer(example["text2"])[:l_max], + } + dataset = dataset.map( + tokenize, + remove_columns=["text1", "text2"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + vocab = torchtext.vocab.build_vocab_from_iterator( + dataset["train"]["tokens1"] + dataset["train"]["tokens2"], + specials=( + ["", ""] + + ([""] if self.append_bos else []) + + ([""] if self.append_eos else []) + ), + ) + vocab.set_default_index(vocab[""]) + + encode = lambda text: vocab( + ([""] if self.append_bos else []) + + text + + ([""] if self.append_eos else []) + ) + numericalize = lambda example: { + "input_ids1": encode(example["tokens1"]), + "input_ids2": encode(example["tokens2"]), + } + dataset = dataset.map( + numericalize, + remove_columns=["tokens1", "tokens2"], + keep_in_memory=True, + load_from_cache_file=False, + num_proc=max(self.n_workers, 1), + ) + + if cache_dir is not None: + self._save_to_cache(dataset, tokenizer, vocab, cache_dir) + return dataset, tokenizer, vocab + + def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir): + cache_dir = self.cache_dir / self._cache_dir_name + logger = 
logging.getLogger(__name__) + logger.info(f"Saving to cache at {str(cache_dir)}") + dataset.save_to_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "wb") as f: + pickle.dump(tokenizer, f) + with open(cache_dir / "vocab.pkl", "wb") as f: + pickle.dump(vocab, f) + + def _load_from_cache(self, cache_dir): + assert cache_dir.is_dir() + logger = logging.getLogger(__name__) + logger.info(f"Load from cache at {str(cache_dir)}") + dataset = DatasetDict.load_from_disk(str(cache_dir)) + with open(cache_dir / "tokenizer.pkl", "rb") as f: + tokenizer = pickle.load(f) + with open(cache_dir / "vocab.pkl", "rb") as f: + vocab = pickle.load(f) + return dataset, tokenizer, vocab \ No newline at end of file diff --git a/src/clm/src/dataloaders/synthetics.py b/src/clm/src/dataloaders/synthetics.py new file mode 100644 index 00000000..de7e0f0d --- /dev/null +++ b/src/clm/src/dataloaders/synthetics.py @@ -0,0 +1,335 @@ +'''Synthetic datasets to test in-context learning ability.''' + +import os +import torch +from torch.utils.data import TensorDataset, Dataset, DataLoader +from typing import Dict +import numpy as np +from tqdm import tqdm +from collections import Counter + +from clm.src.dataloaders.base import SequenceDataset + +class Vocab: + """Custom vocab.""" + def __init__(self, vocab_size: int, special_vocabs: Dict): + # Special tokens hold copy_prefix and noop/pad token etc + assert "copy_prefix" in special_vocabs + self.special_vocabs = special_vocabs + vocab = [str(v) for v in list(range(vocab_size))] + self.non_special_vocab = sorted(list(vocab)) + self.vocab = sorted(list(set(vocab + list(self.special_vocabs.values())))) + self.v2id = {v:i for i,v in enumerate(self.vocab)} + self.vocab_size = len(vocab) + + def get_next_vocab(self, token: str): + """Gets next token excluding special_vocabs.""" + id = (self.get_id(token) + 1) % self.vocab_size + while self.get_vocab(id) in self.special_vocabs: + id = (id + 1) % self.vocab_size + return self.get_vocab(id) + + @property + def copy_prefix(self): + return self.special_vocabs["copy_prefix"] + + @property + def noop(self): + return self.special_vocabs["noop"] + + @property + def special_tokens(self): + return set(self.special_vocabs.values()) + + def get_id(self, token: str): + return self.v2id[token] + + def get_vocab(self, id: int): + return self.vocab[id] + + def __len__(self): + return len(self.vocab) + + +class Tokenizer: + """Custom Tokenizer for our own vocab.""" + def __init__(self, vocab: Vocab): + self.vocab = vocab + + def tokenize(self, text: str, return_tensor=False, mask_input=False): + input_ids = [self.vocab.get_id(t) for t in text.split()] + if self.vocab.get_id(self.vocab.copy_prefix) not in input_ids: + raise ValueError("Input text must contain copy_prefix token.") + copy_prefix_pos = input_ids.index(self.vocab.get_id(self.vocab.copy_prefix)) + labels = input_ids + if mask_input: + # Mask the input tokens for loss but do not mask the copied token + labels = [-100] * (copy_prefix_pos+1) + labels[copy_prefix_pos+1:] + if return_tensor: + input_ids = torch.LongTensor(input_ids) + labels = torch.LongTensor(labels) + return { + "input_ids": input_ids, + "labels": labels, + } + + def decode(self, ids: list): + return " ".join([self.vocab.get_vocab(id) for id in ids]) + +def generate_start_seq(vocab: Vocab, input_seq_len: int, rng: np.random.Generator): + """Generate token sequence up to and including the copy_prefix token.""" + vocab_seq = rng.choice( + vocab.vocab, + input_seq_len, + replace=True, + # Do not generate any special 
tokens + p=[1/(len(vocab)-len(vocab.special_tokens)) if p not in vocab.special_tokens else 0 for p in vocab.vocab]) + vocab_seq = np.append(vocab_seq, vocab.copy_prefix) + return vocab_seq.tolist() + +def generate_induction_head( + vocab: Vocab, + input_seq_len: int, + copy_prefix: str, + induction_len: int, + num_triggers: int, + rng: np.random.Generator, + valid_chars: list = None, +): + """Generate sequence where the copy prefix is inserted into the input + and then the character after the copy prefix is copied at the end. + """ + if valid_chars is not None: + raise NotImplementedError("Valid chars not implemented for induction heads.") + vocab_seq = generate_start_seq(vocab, input_seq_len, rng) + if rng.uniform() < 0.5: + num_triggers = 1 + pos = sorted(rng.integers( + input_seq_len - (1 + induction_len), size=num_triggers + )) + pos_filtered = [] + for i, p in enumerate(pos): + if i == 0: + pos_filtered.append(p) + elif p - pos_filtered[-1] > induction_len: + pos_filtered.append(p) + to_copy = [ + vocab_seq[pos_filtered[0]+1+i] + for i in range(induction_len) + ] + for pos in pos_filtered: + vocab_seq[pos] = copy_prefix + for i in range(induction_len): + vocab_seq[pos+1+i] = to_copy[i] + # if valid_chars is not None and to_copy not in valid_chars: + # vocab_seq[pos+1] = rng.choice(valid_chars) + # to_copy = vocab_seq[pos+1] + vocab_seq = vocab_seq + to_copy + return " ".join(vocab_seq) + +def generate_assoc_recall( + vocab: Vocab, + input_seq_len: int, + num_keys: int, + rng: np.random.Generator, + allow_dot: bool = True, + valid_chars: list = None, +): + """Generate sequence where the input has a sequence of key value pairs + and the copy prefix at the end, and then a key value pair is inserted + after the copy prefix.""" + non_special_vocab_size = len(vocab.non_special_vocab) + keys = vocab.non_special_vocab[:non_special_vocab_size // 2] + values = vocab.non_special_vocab[non_special_vocab_size // 2:] + keys_multi = [ [key] for key in keys ] + for i in range(num_keys-1): + keys_multi = [ key + [key2] for key in keys_multi for key2 in keys ] + kv_map = { + tuple(k): rng.choice(values) for k in keys_multi + } + + key_present = {} + vocab_seq = [] + for _ in range(input_seq_len // (num_keys + 1)): + k = tuple(rng.choice(list(kv_map.keys()))) + v = kv_map[k] + vocab_seq += list(k) + [v] + key_present[k] = True + # vocab_seq.append(v) + + + k = tuple(rng.choice(list(kv_map.keys()))) + if not allow_dot: + while k not in key_present: + k = tuple(rng.choice(list(key_present.keys()))) + to_copy = [vocab.copy_prefix] + list(k) + [ kv_map[k] if k in key_present else vocab.noop ] + vocab_seq = vocab_seq + to_copy + return " ".join(vocab_seq) + +class ICLDataModule(SequenceDataset): + _name_ = "icl_synthetics" + + def __init__( + self, + num_examples: int, + num_test_examples: int, + vocab_size: int, + input_seq_len: int, + copy_method: str, + number_duplicates_per_epoch: int = 0, + seed: int = 0, + batch_size: int = 32, + split_train_test: bool = False, + induction_len: int = 1, + induction_num_triggers: int = 1, + allow_dot: bool = False, + max_copy_len: int = 10, + test_seq_len: int = None, + num_keys: int = 1, # number of keys for associative recall, + data_dir: str = None, + *args, **kwargs + ): + self.num_examples = num_examples + self.num_test_examples = num_test_examples + self.input_seq_len = input_seq_len + self.vocab_size = vocab_size + self.copy_method = copy_method + assert copy_method in ["induction_head", "assoc_recall"] + self.number_duplicates_per_epoch = 
number_duplicates_per_epoch + self.seed = seed + self.batch_size = batch_size + self.split_train_test = split_train_test # let the same copy chars appear in train/test + self.induction_len = induction_len + self.induction_num_triggers = induction_num_triggers + self.allow_dot = allow_dot + self.max_copy_len = max_copy_len + self.data_dir = data_dir + + if test_seq_len is not None: + self.test_seq_len = test_seq_len + else: + self.test_seq_len = input_seq_len + self.num_keys = num_keys + + special_vocabs = { + "copy_prefix": "=>", + "noop": "." + } + self.special_vocabs = special_vocabs + self.vocab = Vocab(vocab_size-len(special_vocabs), special_vocabs=special_vocabs) + self.tokenizer = Tokenizer(self.vocab) + + self.num_extra_seq_len = 2 + + if self.copy_method == "induction_head": + self.copy_f = self.generate_induction_head + self.num_extra_seq_len = 1 + self.induction_len + elif self.copy_method == "assoc_recall": + self.copy_f = self.generate_assoc_recall + self.num_extra_seq_len = 1 + self.num_keys + else: + self.copy_f = None + + if self.number_duplicates_per_epoch > 0: + self.duplicate_ex = self.generate_example() + self.duplicate_index = max(int(self.num_examples / self.number_duplicates_per_epoch), 1) + else: + self.duplicate_ex = None + self.duplicate_index = -1 + + self.total_seq_len = self.input_seq_len + self.num_extra_seq_len + + def generate_induction_head(self, seqlen=None, valid_chars=None): + return generate_induction_head(self.vocab, seqlen if seqlen is not None else self.input_seq_len, self.special_vocabs["copy_prefix"], self.induction_len, self.induction_num_triggers, self.rng, valid_chars=valid_chars) + + def generate_assoc_recall(self, seqlen=None, valid_chars=None): + return generate_assoc_recall(self.vocab, seqlen if seqlen is not None else self.input_seq_len, self.num_keys, self.rng, allow_dot = self.allow_dot, valid_chars=valid_chars) + + def generate_example(self, seqlen=None, valid_chars=None): + vocab_seq = self.copy_f(seqlen=seqlen, valid_chars=valid_chars) + return self.tokenizer.tokenize(vocab_seq, return_tensor=True) + + def setup(self, stage=None): + train_tensor = test_tensor = None + if self.data_dir is not None: + try: + train_tensor = torch.load(os.path.join(self.data_dir, + f"train_{self.copy_method}_{self.num_examples}_{self.vocab_size}_{self.input_seq_len}.pt")) + test_tensor = torch.load(os.path.join(self.data_dir, + f"test_{self.copy_method}_{self.num_examples}_{self.vocab_size}_{self.input_seq_len}.pt")) + except: + pass + + if train_tensor is None or test_tensor is None: + if hasattr(self, 'dataset'): + return + self.rng = np.random.default_rng(self.seed) + + if self.split_train_test: + all_vocab = self.vocab.non_special_vocab + train_vocab = set(self.rng.choice(all_vocab, size=len(all_vocab) // 2, replace=False)) + test_vocab = set(all_vocab) - train_vocab + train_vocab = list(train_vocab) + test_vocab = list(test_vocab) + else: + train_vocab = None + test_vocab = None + + all_examples = [] + for i, (example_count, valid_vocab) in enumerate(zip([self.num_examples, self.num_test_examples], [train_vocab, test_vocab])): + examples = torch.stack([self.generate_example( + seqlen=self.input_seq_len if i == 0 else self.test_seq_len, + valid_chars=valid_vocab + )['input_ids'] for _ in tqdm(range(example_count))]) + examples = torch.unique(examples, dim=0, sorted=False).tolist() + + while len(examples) < example_count: + new_example = self.generate_example( + seqlen=self.input_seq_len if i == 0 else self.test_seq_len, + valid_chars=valid_vocab + 
)['input_ids'].tolist() + if new_example not in examples: + examples.append(new_example) + + self.rng.shuffle(examples) + all_examples.append(torch.LongTensor(examples)) + + # all_examples = torch.concat(all_examples) + train_tensor = torch.stack([torch.stack([example[:-1], example[1:]]) for example in all_examples[0]]) + test_tensor = torch.stack([torch.stack([example[:-1], example[1:]]) for example in all_examples[1]]) + test_tensor[:, 1, :-1 * (self.num_extra_seq_len - 1)] = -100 + if self.copy_method in ["assoc_recall"]: + test_tensor[:, 1, :-1] = -100 + if self.copy_method in ["majority", "fom1"]: + train_tensor[:, 1, :-1 * (self.num_extra_seq_len - 1)] = -100 + + if self.data_dir is not None: + torch.save(train_tensor, os.path.join(self.data_dir, + f"train_{self.copy_method}_{self.num_examples}_{self.vocab_size}_{self.input_seq_len}.pt") + ) + torch.save(test_tensor, os.path.join(self.data_dir, + f"test_{self.copy_method}_{self.num_examples}_{self.vocab_size}_{self.input_seq_len}.pt") + ) + + self.dataset = { + 'train': TensorDataset(train_tensor[:, 0, :], train_tensor[:, 1, :]), + 'test': TensorDataset(test_tensor[:, 0, :], test_tensor[:, 1, :]) + } + + def train_dataloader(self, *args, **kwargs): + return self._data_loader(self.dataset['train'], shuffle=True) + + def val_dataloader(self, *args, **kwargs): + return self._data_loader(self.dataset['test'], shuffle=False) + + def test_dataloader(self, *args, **kwargs): + return self._data_loader(self.dataset['test'], shuffle=False) + + def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: + return DataLoader( + dataset, + batch_size=self.batch_size, + num_workers=10, + shuffle=shuffle, + persistent_workers=True + ) \ No newline at end of file diff --git a/src/clm/src/dataloaders/utils/cifar_augmentations.py b/src/clm/src/dataloaders/utils/cifar_augmentations.py new file mode 100644 index 00000000..3c063edb --- /dev/null +++ b/src/clm/src/dataloaders/utils/cifar_augmentations.py @@ -0,0 +1,138 @@ +""" +Borrowed from https://github.com/hysts/pytorch_image_classification/tree/9ff4248905850c68aa9c09c17914307eb81769e7/pytorch_image_classification/transforms +""" +import torch +import numpy as np +import PIL +import PIL.Image +from PIL.Image import Image + + +class NpNormalize: + def __init__(self, mean: np.ndarray, std: np.ndarray): + self.mean = np.array(mean) + self.std = np.array(std) + + def __call__(self, image: PIL.Image.Image) -> np.ndarray: + image = np.asarray(image).astype(np.float32) / 255. + image = (image - self.mean) / self.std + return image + + +class Cutout(object): + """Randomly mask out one or more patches from an image. + Args: + n_holes (int): Number of patches to cut out of each image. + length (int): The length (in pixels) of each square patch. + """ + + def __init__(self, n_holes, length): + self.n_holes = n_holes + self.length = length + + def __call__(self, img): + """ + Args: + img (Tensor): Tensor image of size (C, H, W). + Returns: + Tensor: Image with n_holes of dimension length x length cut out of it. + """ + h = img.size(1) + w = img.size(2) + + mask = np.ones((h, w), np.float32) + + for n in range(self.n_holes): + y = np.random.randint(h) + x = np.random.randint(w) + + y1 = np.clip(y - self.length // 2, 0, h) + y2 = np.clip(y + self.length // 2, 0, h) + x1 = np.clip(x - self.length // 2, 0, w) + x2 = np.clip(x + self.length // 2, 0, w) + + mask[y1: y2, x1: x2] = 0. 
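+        # After the loop, `mask` is 1.0 everywhere except the zeroed square patches;
+        # the np.clip calls above truncate patches that extend past the image border.
+        # `expand_as` below broadcasts the single (H, W) mask across all image channels.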
+ + mask = torch.from_numpy(mask) + mask = mask.expand_as(img) + img = img * mask + + return img + + +# +# class Cutout: +# def __init__(self, p=1.0, mask_size=16, cutout_inside=False, mask_color=0): +# # https://github.com/hysts/pytorch_image_classification/blob/9ff4248905850c68aa9c09c17914307eb81769e7/configs/augmentations/cifar/cutout.yaml +# self.p = p +# self.mask_size = mask_size +# self.cutout_inside = cutout_inside +# self.mask_color = mask_color +# +# self.mask_size_half = self.mask_size // 2 +# self.offset = 1 if self.mask_size % 2 == 0 else 0 +# +# def __call__(self, image: np.ndarray) -> np.ndarray: +# image = np.asarray(image).copy() +# +# if np.random.random() > self.p: +# return image +# +# h, w = image.shape[:2] +# +# if self.cutout_inside: +# cxmin = self.mask_size_half +# cxmax = w + self.offset - self.mask_size_half +# cymin = self.mask_size_half +# cymax = h + self.offset - self.mask_size_half +# else: +# cxmin, cxmax = 0, w + self.offset +# cymin, cymax = 0, h + self.offset +# +# cx = np.random.randint(cxmin, cxmax) +# cy = np.random.randint(cymin, cymax) +# xmin = cx - self.mask_size_half +# ymin = cy - self.mask_size_half +# xmax = xmin + self.mask_size +# ymax = ymin + self.mask_size +# xmin = max(0, xmin) +# ymin = max(0, ymin) +# xmax = min(w, xmax) +# ymax = min(h, ymax) +# image[ymin:ymax, xmin:xmax] = self.mask_color +# return image + + +class RandomErasing: + def __init__(self, p=0.5, max_attempt=20, sl=0.02, sh=0.4, rl=0.3, rh=1. / 0.3): + # https://github.com/hysts/pytorch_image_classification/blob/9ff4248905850c68aa9c09c17914307eb81769e7/configs/augmentations/cifar/random_erasing.yaml + self.p = 0.5 + self.max_attempt = 20 + self.sl, self.sh = 0.02, 0.4 + self.rl = 0.3 + self.rh = 1. / 0.3 + + def __call__(self, image: np.ndarray) -> np.ndarray: + image = np.asarray(image).copy() + + if np.random.random() > self.p: + return image + + h, w = image.shape[:2] + image_area = h * w + + for _ in range(self.max_attempt): + mask_area = np.random.uniform(self.sl, self.sh) * image_area + aspect_ratio = np.random.uniform(self.rl, self.rh) + mask_h = int(np.sqrt(mask_area * aspect_ratio)) + mask_w = int(np.sqrt(mask_area / aspect_ratio)) + + if mask_w < w and mask_h < h: + x0 = np.random.randint(0, w - mask_w) + y0 = np.random.randint(0, h - mask_h) + x1 = x0 + mask_w + y1 = y0 + mask_h + image[y0:y1, x0:x1] = np.random.uniform(0, 1) + break + + return image diff --git a/src/clm/src/dataloaders/utils/timm_mixup.py b/src/clm/src/dataloaders/utils/timm_mixup.py new file mode 100644 index 00000000..333a9c65 --- /dev/null +++ b/src/clm/src/dataloaders/utils/timm_mixup.py @@ -0,0 +1,22 @@ +import torch + +from timm.data import Mixup +from timm.data.mixup import mixup_target + + +class TimmMixup(Mixup): + """ Wrap timm.data.Mixup that avoids the assert that batch size must be even. 
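+    Only two things change relative to timm: the even-batch-size assert is applied
+    only in 'pair' mode (the one place it is required), and mixup_target is built on
+    target.device so the mixed labels end up on the same device as the batch.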
+ """ + def __call__(self, x, target, *args): + if self.mode == 'elem': + lam = self._mix_elem(x) + elif self.mode == 'pair': + # We move the assert from the beginning of the function to here + assert len(x) % 2 == 0, 'Batch size should be even when using this' + lam = self._mix_pair(x) + else: + lam = self._mix_batch(x) + # Another change is to set the right device here + target = mixup_target(target, self.num_classes, lam, self.label_smoothing, + device=target.device) + return x, target, *args \ No newline at end of file diff --git a/src/clm/src/dataloaders/utils/vocabulary.py b/src/clm/src/dataloaders/utils/vocabulary.py new file mode 100644 index 00000000..bdb98936 --- /dev/null +++ b/src/clm/src/dataloaders/utils/vocabulary.py @@ -0,0 +1,237 @@ +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import os +from collections import Counter +from collections import OrderedDict + +import torch + +import clm.src.utils as utils + + +class Vocab(object): + def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, + delimiter=None, vocab_file=None): + self.counter = Counter() + self.special = special + self.min_freq = min_freq + self.max_size = max_size + self.lower_case = lower_case + self.delimiter = delimiter + self.vocab_file = vocab_file + + def tokenize(self, line, add_eos=False, add_double_eos=False): + line = line.strip() + # convert to lower case + if self.lower_case: + line = line.lower() + + # empty delimiter '' will evaluate False + if self.delimiter == '': + symbols = line + else: + symbols = line.split(self.delimiter) + + if add_double_eos: # lm1b + return [''] + symbols + [''] + elif add_eos: + return symbols + [''] + else: + return symbols + + def count_file(self, path, verbose=False, add_eos=False): + if verbose: + print('counting file {} ...'.format(path)) + assert os.path.exists(path) + + sents = [] + with open(path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + symbols = self.tokenize(line, add_eos=add_eos) + self.counter.update(symbols) + sents.append(symbols) + + return sents + + def count_sents(self, sents, verbose=False): + """ + sents : a list of sentences, each a list of tokenized symbols + """ + if verbose: + print('counting {} sents ...'.format(len(sents))) + for idx, symbols in enumerate(sents): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + self.counter.update(symbols) + + def _build_from_file(self, vocab_file): + self.idx2sym = [] + self.sym2idx = OrderedDict() + + with open(vocab_file, 'r', encoding='utf-8') as f: + for line in f: + symb = line.strip().split()[0] + self.add_symbol(symb) + self.unk_idx = self.sym2idx[''] + + def build_vocab(self): + if self.vocab_file: + print('building vocab from {}'.format(self.vocab_file)) + self._build_from_file(self.vocab_file) + print('final vocab size {}'.format(len(self))) + else: + 
print('building vocab with min_freq={}, max_size={}'.format( + self.min_freq, self.max_size)) + self.idx2sym = [] + self.sym2idx = OrderedDict() + + for sym in self.special: + self.add_special(sym) + + for sym, cnt in self.counter.most_common(self.max_size): + if cnt < self.min_freq: + break + self.add_symbol(sym) + + print('final vocab size {} from {} unique tokens'.format( + len(self), len(self.counter))) + + def encode_file(self, path, ordered=False, verbose=False, add_eos=True, + add_double_eos=False): + if verbose: + print('encoding file {} ...'.format(path)) + assert os.path.exists(path) + encoded = [] + with open(path, 'r', encoding='utf-8') as f: + for idx, line in enumerate(f): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + symbols = self.tokenize(line, add_eos=add_eos, + add_double_eos=add_double_eos) + encoded.append(self.convert_to_tensor(symbols)) + + if ordered: + encoded = torch.cat(encoded) + + return encoded + + def encode_sents(self, sents, ordered=False, verbose=False): + if verbose: + print('encoding {} sents ...'.format(len(sents))) + encoded = [] + for idx, symbols in enumerate(sents): + if verbose and idx > 0 and idx % 500000 == 0: + print(' line {}'.format(idx)) + encoded.append(self.convert_to_tensor(symbols)) + + if ordered: + encoded = torch.cat(encoded) + + return encoded + + def add_special(self, sym): + if sym not in self.sym2idx: + self.idx2sym.append(sym) + self.sym2idx[sym] = len(self.idx2sym) - 1 + setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) + + def add_symbol(self, sym): + if sym not in self.sym2idx: + self.idx2sym.append(sym) + self.sym2idx[sym] = len(self.idx2sym) - 1 + + def get_sym(self, idx): + assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) + return self.idx2sym[idx] + + def get_idx(self, sym): + if sym in self.sym2idx: + return self.sym2idx[sym] + else: + # print('encounter unk {}'.format(sym)) + assert '' not in sym + assert hasattr(self, 'unk_idx') + return self.sym2idx.get(sym, self.unk_idx) + + def get_symbols(self, indices): + return [self.get_sym(idx) for idx in indices] + + def get_indices(self, symbols): + return [self.get_idx(sym) for sym in symbols] + + def convert_to_tensor(self, symbols): + return torch.LongTensor(self.get_indices(symbols)) + + def convert_to_sent(self, indices, exclude=None): + if exclude is None: + return ' '.join([self.get_sym(idx) for idx in indices]) + else: + return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) + + def __len__(self): + return len(self.idx2sym) + + +# Class OpenAIVocab has been adapted from +# https://github.com/cybertronai/transformer-xl/blob/master/utils/vocabulary.py +class OpenAIVocab(Vocab): + def __init__(self, max_size=None, vocab_file=None): + from transformers import GPT2Tokenizer + self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2') + self.EOT = self.tokenizer.encoder['<|endoftext|>'] + self.max_size = max_size + self.vocab_file = vocab_file + + pad = 8 + vocab_size = len(self.tokenizer) + padded_vocab_size = (vocab_size + pad - 1) // pad * pad + for i in range(0, padded_vocab_size - vocab_size): + token = f'madeupword{i:09d}' + self.tokenizer.add_tokens([token]) + + def __len__(self): + return len(self.tokenizer) + + def count_file(self, path, verbose=False, add_eos=False): + # TODO: train from scratch, respect self.max_size + pass + + def build_vocab(self): + pass + + def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False) -> torch.LongTensor: + 
cached = path + '.bpe' + if os.path.exists(cached): + return torch.load(cached) + print(f'encoding file {path} ...') + assert os.path.exists(path), f"{path} doesn't exist" + + with open(path, encoding='utf-8') as f: + # Suppress warnings about length. + with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull): + out = torch.LongTensor(self.tokenizer.encode(f.read()) + [self.EOT]) + with utils.distributed.sync_workers() as rank: + if rank == 0: + torch.save(out, cached) + return out + + def tokenize(self, line, add_eos=False, add_double_eos=False): + return self.tokenizer.encode(line) + + def convert_to_tensor(self, symbols): + return torch.LongTensor(symbols) diff --git a/src/clm/src/dataloaders/vision.py b/src/clm/src/dataloaders/vision.py new file mode 100644 index 00000000..ac44763b --- /dev/null +++ b/src/clm/src/dataloaders/vision.py @@ -0,0 +1,279 @@ +"""Miscellaneous vision datasets.""" + +import os + +import torch +from torch import nn +from torch.nn import functional as F +import torchvision + +from clm.src.dataloaders.base import default_data_path, SequenceDataset + + +class ImageNet(SequenceDataset): + """ + .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/ + Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png + :width: 400 + :alt: Imagenet + Specs: + - 1000 classes + - Each image is (3 x varies x varies) (here we default to 3 x 224 x 224) + Imagenet train, val and test dataloaders. + The train set is the imagenet train. + The val split is taken from train if a val_split % is provided, or will be the same as test otherwise + The test set is the official imagenet validation set. + + """ + + _name_ = "imagenet" + d_input = 3 + d_output = 1000 + l_output = 0 + + init_defaults = { + "data_dir": None, + "cache_dir": None, + "image_size": 224, + "val_split": None, # currently not implemented + "train_transforms": None, + "val_transforms": None, + "test_transforms": None, + "mixup": None, # augmentation + "num_aug_repeats": 0, + "num_gpus": 1, + "shuffle": True, # for train + "loader_fft": False, + } + + @property + def num_classes(self) -> int: + """ + Return: + 1000 + """ + return 1000 + + def _verify_splits(self, data_dir: str, split: str) -> None: + dirs = os.listdir(data_dir) + + if split not in dirs: + raise FileNotFoundError( + f"a {split} Imagenet split was not found in {data_dir}," + f" make sure the folder contains a subfolder named {split}" + ) + + def prepare_data(self) -> None: + """This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin. + .. warning:: Please download imagenet on your own first. 
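+        When not reading from an archive, ``data_dir`` is expected to contain ``train``
+        and ``val`` subfolders; otherwise ``data_dir`` should point at the archive file itself.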
+ """ + if not self.use_archive_dataset: + self._verify_splits(self.data_dir, "train") + self._verify_splits(self.data_dir, "val") + else: + if not self.data_dir.is_file(): + raise FileNotFoundError(f"""Archive file {str(self.data_dir)} not found.""") + + def setup(self, stage=None): + """Creates train, val, and test dataset.""" + + from typing import Any, Callable, List, Optional, Union + + import hydra # for mixup + from pl_bolts.transforms.dataset_normalizations import \ + imagenet_normalization + from torch.utils.data import Dataset + from torch.utils.data.dataloader import default_collate + from torchvision.datasets import ImageFolder + + # for access in other methods + self.imagenet_normalization = imagenet_normalization + self.default_collate = default_collate + self.hydra = hydra + self.ImageFolder = ImageFolder + + if self.mixup is not None: + self.mixup_fn = hydra.utils.instantiate(self.mixup) + else: + self.mixup_fn = None + + self.dir_path = self.data_dir or default_data_path / self._name_ + + if stage == "fit" or stage is None: + self.set_phase([self.image_size]) + + if stage == "test" or stage is None: + test_transforms = (self.val_transform() if self.test_transforms is None + else hydra.utils.instantiate(self.test_transforms)) + + self.dataset_test = ImageFolder(os.path.join(self.dir_path, 'val'), transform=test_transforms) + + # # modded, override (for debugging) + # self.dataset_test = self.dataset_val + + def set_phase(self, stage_params=[224], val_upsample=False, test_upsample=False): + """ + For progresive learning. + Will modify train transform parameters during training, just image size for now, + and create a new train dataset, which the train_dataloader will load every + n epochs (in config). + + Later, will be possible to change magnitude of RandAug here too, and mixup alpha + + stage_params: list, list of values to change. 
single [image_size] for now + """ + + img_size = int(stage_params[0]) + + if val_upsample: + self.val_transforms["input_size"] = img_size + + train_transforms = (self.train_transform() if self.train_transforms is None + else self.hydra.utils.instantiate(self.train_transforms)) + val_transforms = (self.val_transform() if self.val_transforms is None + else self.hydra.utils.instantiate(self.val_transforms)) + + if self.loader_fft: + train_transforms = torchvision.transforms.Compose( + train_transforms.transforms + [ + torchvision.transforms.Lambda(lambda x: torch.fft.rfftn(x, s=tuple([2*l for l in x.shape[1:]]))) + ] + ) + val_transforms = torchvision.transforms.Compose( + val_transforms.transforms + [ + torchvision.transforms.Lambda(lambda x: torch.fft.rfftn(x, s=tuple([2*l for l in x.shape[1:]]))) + ] + ) + + self.dataset_train = self.ImageFolder(self.dir_path / 'train', + transform=train_transforms) + + if self.val_split > 0.: + # this will create the val split + self.split_train_val(self.val_split) + # will use the test split as val by default + else: + self.dataset_val = self.ImageFolder(self.dir_path / 'val', transform=val_transforms) + + # # modded, override (for debugging) + # self.dataset_train = self.dataset_val + + # not sure if normally you upsample test also + if test_upsample: + self.test_transforms["input_size"] = img_size + test_transforms = (self.val_transform() if self.test_transforms is None + else self.hydra.utils.instantiate(self.test_transforms)) + self.dataset_test = self.ImageFolder(os.path.join(self.dir_path, 'val'), transform=test_transforms) + ## modded, override (for debugging) + # self.dataset_test = self.dataset_val + + # could modify mixup by reinstantiating self.mixup_fn (later maybe) + + def train_transform(self): + """The standard imagenet transforms. + .. code-block:: python + transforms.Compose([ + transforms.RandomResizedCrop(self.image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + """ + preprocessing = torchvision.transforms.Compose( + [ + torchvision.transforms.RandomResizedCrop(self.image_size), + torchvision.transforms.RandomHorizontalFlip(), + torchvision.transforms.ToTensor(), + self.imagenet_normalization(), + ] + ) + + return preprocessing + + def val_transform(self): + """The standard imagenet transforms for validation. + .. 
code-block:: python + transforms.Compose([ + transforms.Resize(self.image_size + 32), + transforms.CenterCrop(self.image_size), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225] + ), + ]) + """ + + preprocessing = torchvision.transforms.Compose( + [ + torchvision.transforms.Resize(self.image_size + 32), + torchvision.transforms.CenterCrop(self.image_size), + torchvision.transforms.ToTensor(), + self.imagenet_normalization(), + ] + ) + return preprocessing + + def train_dataloader(self, **kwargs): + """ The train dataloader """ + if self.num_aug_repeats == 0 or self.num_gpus == 1: + shuffle = self.shuffle + sampler = None + else: + shuffle = False + from timm.data.distributed_sampler import RepeatAugSampler + sampler = RepeatAugSampler(self.dataset_train, num_repeats=self.num_aug_repeats) + + # calculate resolution + resolution = self.image_size / self.train_transforms['input_size'] # usually 1.0 + + return (self._data_loader(self.dataset_train, shuffle=shuffle, mixup=self.mixup_fn, sampler=sampler, resolution=resolution, **kwargs)) + + def val_dataloader(self, **kwargs): + """ The val dataloader """ + kwargs['drop_last'] = False + + # update batch_size for eval if provided + batch_size = kwargs.get("batch_size_eval", None) or kwargs.get("batch_size") + kwargs["batch_size"] = batch_size + + # calculate resolution + resolution = self.image_size / self.val_transforms['input_size'] # usually 1.0 or 0.583 + + return (self._data_loader(self.dataset_val, resolution=resolution, **kwargs)) + + def test_dataloader(self, **kwargs): + """ The test dataloader """ + kwargs['drop_last'] = False + + # update batch_size for test if provided + batch_size = kwargs.get("batch_size_test", None) or kwargs.get("batch_size_eval", None) or kwargs.get("batch_size") + kwargs["batch_size"] = batch_size + + # calculate resolution + resolution = self.image_size / self.test_transforms.get("input_size", self.val_transforms['input_size']) + + return (self._data_loader(self.dataset_test, resolution=resolution, **kwargs)) + + def _data_loader(self, dataset, resolution, shuffle=False, mixup=None, sampler=None, **kwargs): + # collate_fn = (lambda batch: mixup(*self.default_collate(batch))) if mixup is not None else self.default_collate + collate_fn = (lambda batch: mixup(*self.collate_with_resolution(batch, resolution))) if mixup is not None else lambda batch: self.collate_with_resolution(batch, resolution) + + # hacked - can't pass this this arg to dataloader, but used to update the batch_size val / test + kwargs.pop('batch_size_eval', None) + kwargs.pop('batch_size_test', None) + + return torch.utils.data.DataLoader( + dataset, + collate_fn=collate_fn, + shuffle=shuffle, + sampler=sampler, + **kwargs, + ) + + def collate_with_resolution(self, batch, resolution): + stuff = self.default_collate(batch) + return *stuff, {"resolution": resolution} diff --git a/src/clm/src/models/__init__.py b/src/clm/src/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/clm/src/models/baselines/vit_all.py b/src/clm/src/models/baselines/vit_all.py new file mode 100644 index 00000000..d2a18b6d --- /dev/null +++ b/src/clm/src/models/baselines/vit_all.py @@ -0,0 +1,433 @@ +""" +The original Vision Transformer (ViT) from timm, copyright belongs to / Copyright 2020 Ross Wightman +""" +import math +import logging + +from functools import partial +from collections import OrderedDict +from copy import deepcopy + +import torch +import torch.nn as nn +import 
torch.nn.functional as F + +from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg +from timm.models.layers import PatchEmbed, Mlp, trunc_normal_, lecun_normal_ + +from clm.src.models.sequence.base import SequenceModule +from clm.src.models.nn.components import Normalization +from clm.src.models.sequence.block import SequenceResidualBlock +from clm.src.utils.config import to_list, to_dict + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, + 'input_size': (3, 224, 224), + 'pool_size': None, + 'classifier': 'head', + **kwargs, + } + + +default_cfgs = { + # patch models (my experiments) + 'vit_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', + ), + + # patch models (weights ported from official Google JAX impl) + 'vit_base_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + ), +} + + +class VisionTransformer(SequenceModule): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + d_model=768, + depth=12, + expand=4, + representation_size=None, + distilled=False, + dropout=0., + drop_path_rate=0., + embed_layer=PatchEmbed, + norm='layer', + weight_init='', + layer=None, + transposed=False, + layer_reps=1, + use_pos_embed=False, + use_cls_token=False, + track_norms=False, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + d_model (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + distilled (bool): model includes a distillation token and head as in DeiT models + dropout (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme + """ + super().__init__() + self.num_classes = num_classes + self.num_features = self.d_model = d_model # num_features for consistency with other models + self.num_tokens = 2 if distilled else 1 + self.use_pos_embed = use_pos_embed + self.use_cls_token = use_cls_token + + self.track_norms = track_norms + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=d_model, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = None + self.dist_token = None + if use_cls_token: + self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) if distilled else None + else: + assert not distilled, 'Distillation token 
not supported without class token' + + self.pos_embed = None + if use_pos_embed: + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, d_model)) + self.pos_drop = nn.Dropout(p=dropout) + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + self.transposed = transposed + + layer = to_list(layer, recursive=False) * layer_reps + + # Some special arguments are passed into each layer + for _layer in layer: + # If layers don't specify dropout, add it + if _layer.get('dropout', None) is None: + _layer['dropout'] = dropout + # Ensure all layers are shaped the same way + _layer['transposed'] = transposed + + # Config for the inverted bottleneck + ff_cfg = { + '_name_': 'ff', + 'expand': int(expand), + 'transposed': self.transposed, + 'activation': 'gelu', + 'initializer': None, + 'dropout': dropout, + } + + blocks = [] + for i in range(depth): + for _layer in layer: + blocks.append( + SequenceResidualBlock( + d_input=d_model, + i_layer=i, + prenorm=True, + dropout=dropout, + layer=_layer, + residual='R', + norm=norm, + pool=None, + drop_path=dpr[i], + ) + ) + if expand > 0: + blocks.append( + SequenceResidualBlock( + d_input=d_model, + i_layer=i, + prenorm=True, + dropout=dropout, + layer=ff_cfg, + residual='R', + norm=norm, + pool=None, + drop_path=dpr[i], + ) + ) + self.blocks = nn.Sequential(*blocks) + + if norm is None: + self.norm = None + elif isinstance(norm, str): + self.norm = Normalization(d_model, transposed=self.transposed, _name_=norm) + else: + self.norm = Normalization(d_model, transposed=self.transposed, **norm) + + # Representation layer: generally defaults to nn.Identity() + if representation_size and not distilled: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(d_model, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + # Classifier head(s): TODO: move to decoder + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = None + if distilled: + self.head_dist = nn.Linear(self.d_model, self.num_classes) if num_classes > 0 else nn.Identity() + + # Weight init + assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') + head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. 
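+        # 'nlhb' ("negative log head bias") sets head_bias = -log(num_classes); the
+        # named-module init path below then zeroes the classifier head weights and fills
+        # its bias with head_bias, so initial predictions are roughly uniform and the
+        # starting cross-entropy is about log(num_classes) (~6.9 for 1000 classes).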
+ if self.pos_embed is not None: + trunc_normal_(self.pos_embed, std=.02) + if self.dist_token is not None: + trunc_normal_(self.dist_token, std=.02) + if weight_init.startswith('jax'): + # leave cls token as zeros to match jax impl + for n, m in self.named_modules(): + _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) + else: + if self.cls_token is not None: + trunc_normal_(self.cls_token, std=.02) + self.apply(_init_vit_weights) + + def _init_weights(self, m): + # this fn left here for compat with downstream users + _init_vit_weights(m) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token', 'dist_token'} + + def forward_features(self, x): + # TODO: move to encoder + x = self.patch_embed(x) + + if self.use_cls_token: + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + + if self.use_pos_embed: + x = self.pos_drop(x + self.pos_embed) + + if self.track_norms: output_norms = [torch.mean(x.detach() ** 2)] + + for block in self.blocks: + x, _ = block(x) + if self.track_norms: output_norms.append(torch.mean(x.detach() ** 2)) + x = self.norm(x) + + if self.track_norms: + metrics = to_dict(output_norms, recursive=False) + self.metrics = {f'norm/{i}': v for i, v in metrics.items()} + + if self.dist_token is None: + if self.use_cls_token: + return self.pre_logits(x[:, 0]) + else: + # pooling: TODO move to decoder + return self.pre_logits(x.mean(1)) + else: + return x[:, 0], x[:, 1] + + def forward(self, x, rate=1.0, resolution=None, state=None): + x = self.forward_features(x) + if self.head_dist is not None: + x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 + else: + x = self.head(x) + return x, None + + +def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False): + """ ViT weight initialization + * When called without n, head_bias, jax_impl args it will behave exactly the same + as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). + * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl + """ + if isinstance(m, (nn.Linear)): + if n.startswith('head'): + nn.init.zeros_(m.weight) + nn.init.constant_(m.bias, head_bias) + elif n.startswith('pre_logits'): + lecun_normal_(m.weight) + nn.init.zeros_(m.bias) + else: + if jax_impl: + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + if 'mlp' in n: + nn.init.normal_(m.bias, std=1e-6) + else: + nn.init.zeros_(m.bias) + else: + if m.bias is not None: + nn.init.zeros_(m.bias) + dense_init_fn_ = partial(trunc_normal_, std=.02) + if isinstance(m, nn.Linear): + dense_init_fn_(m.weight) + + elif jax_impl and isinstance(m, nn.Conv2d): + # NOTE conv was left to pytorch default in my original init + lecun_normal_(m.weight) + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.LayerNorm): + nn.init.zeros_(m.bias) + nn.init.ones_(m.weight) + + +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): + # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + return posemb + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # For old models that I trained prior to conv based patchification + O, I, H, W = model.patch_embed.proj.weight.shape + v = v.reshape(O, -1, H, W) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # To resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1), + model.patch_embed.grid_size) + out_dict[k] = v + return out_dict + + +def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) + overlay_external_default_cfg(default_cfg, kwargs) + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-2:] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != default_num_classes: + # Remove representation layer if fine-tuning. This may not always be the desired action, + # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? + _logger.warning("Removing representation layer for fine-tuning.") + repr_size = None + + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + VisionTransformer, + variant, + pretrained, + default_cfg=default_cfg, + img_size=img_size, + num_classes=num_classes, + representation_size=repr_size, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + + return model + + +def vit_small_patch16_224(pretrained=False, **kwargs): + """ Tri's custom 'small' ViT model. d_model=768, depth=8, num_heads=8, mlp_ratio=3. 
+ NOTE: + * this differs from the DeiT based 'small' definitions with d_model=384, depth=12, num_heads=6 + * this model does not have a bias for QKV (unlike the official ViT and DeiT models) + """ + print(kwargs) + model_kwargs = dict( + patch_size=16, + d_model=768, + depth=8, + expand=3, + norm='layer', + ) + model_kwargs = { + **model_kwargs, + **kwargs, + } + if pretrained: + # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model + model_kwargs.setdefault('qk_scale', 768 ** -0.5) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +def vit_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). + ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. + """ + model_kwargs = dict( + patch_size=16, + d_model=768, + depth=12, + # num_heads=12, + ) + model_kwargs = { + **model_kwargs, + **kwargs, + } + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model diff --git a/src/clm/src/models/nn/__init__.py b/src/clm/src/models/nn/__init__.py new file mode 100644 index 00000000..aee8113e --- /dev/null +++ b/src/clm/src/models/nn/__init__.py @@ -0,0 +1 @@ +from .components import LinearActivation, Activation, Normalization, DropoutNd diff --git a/src/clm/src/models/nn/adaptive_softmax.py b/src/clm/src/models/nn/adaptive_softmax.py new file mode 100644 index 00000000..4ac9e2f0 --- /dev/null +++ b/src/clm/src/models/nn/adaptive_softmax.py @@ -0,0 +1,404 @@ +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
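+
+# Adaptive softmax / adaptive embedding adapted from the Transformer-XL codebase. The
+# vocabulary is split at `cutoffs` into a frequent-token head and progressively rarer
+# tail clusters; with div_val > 1 each tail cluster i uses a smaller embedding width
+# (d_embed // div_val**i), which cuts softmax compute and parameter count for rare tokens.
+#
+# Illustrative usage sketch (the numbers below are made up, not taken from this repo's
+# configs):
+#   crit = ProjectedAdaptiveLogSoftmax(n_token=267735, d_embed=512, d_proj=512,
+#                                      cutoffs=[20000, 40000, 200000], div_val=4)
+#   loss = crit(hidden, target)  # hidden: [len*bsz, d_proj], target: [len*bsz]; returns mean NLL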
+ +from typing import List, Optional +import functools + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class OptionalParameterList(nn.ParameterList): + def extra_repr(self): + child_lines = [] + for k, p in self._parameters.items(): + if p is not None: + size_str = 'x'.join(str(size) for size in p.size()) + device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device()) + parastr = 'Parameter containing: [{} of size {}{}]'.format( + torch.typename(p), size_str, device_str) + child_lines.append(' (' + str(k) + '): ' + parastr) + tmpstr = '\n'.join(child_lines) + return tmpstr + + +class ProjectedAdaptiveLogSoftmax(nn.Module): + def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, + tie_projs=None, out_layers_weights=None, out_projs=None, + keep_order=False, + bias_scale=0.0, + dropout=0.0, + ): + super().__init__() + + self.n_token = n_token + self.d_embed = d_embed + self.d_proj = d_proj + + self.cutoffs = list(cutoffs) + [n_token] + self.cutoff_ends = [0] + self.cutoffs + self.div_val = div_val + + self.shortlist_size = self.cutoffs[0] + self.n_clusters = len(self.cutoffs) - 1 + self.head_size = self.shortlist_size + self.n_clusters + + # bake the first False into the definition, just as [0] is built into the cutoffs + if tie_projs is None: tie_projs = [] + elif isinstance(tie_projs, bool): tie_projs = [tie_projs] * len(cutoffs) + else: tie_projs = list(tie_projs) + tie_projs = [False] + tie_projs + self.tie_projs = tie_projs + + if self.n_clusters > 0: + self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) + self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) + + if not out_layers_weights: + self.out_layers_weights = nn.ParameterList() + else: + self.out_layers_weights = out_layers_weights + + self.out_layers_biases = nn.ParameterList() + + self.shared_out_projs = out_projs + self.out_projs = OptionalParameterList() + + self.dropout = dropout + self.drop = nn.Dropout(dropout) + + if div_val == 1: + if d_proj != d_embed: + for i in range(len(self.cutoffs)): + if tie_projs[i]: + self.out_projs.append(None) + else: + self.out_projs.append( + nn.Parameter(torch.zeros(d_proj, d_embed)) + ) + else: + self.out_projs.append(None) + + self.out_layers_biases.append( + nn.Parameter(torch.zeros(n_token)) + ) + + if not out_layers_weights: + self.out_layers_weights.append( + nn.Parameter(torch.zeros(n_token, d_embed)) + ) + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] + d_emb_i = d_embed // (div_val ** i) + + if tie_projs[i]: + self.out_projs.append(None) + else: + self.out_projs.append( + nn.Parameter(torch.zeros(d_proj, d_emb_i)) + ) + + self.out_layers_biases.append( + nn.Parameter(torch.zeros(r_idx - l_idx)) + ) + if not out_layers_weights: + self.out_layers_weights.append( + nn.Parameter(torch.zeros(r_idx - l_idx, d_emb_i)) + ) + for bias in self.out_layers_biases: + bound = bias_scale * d_proj ** -.5 + nn.init.uniform_(bias, -bound, bound) + + + self.keep_order = keep_order + + def _compute_logit(self, hidden, weight, bias, proj): + if proj is None: + logit = F.linear(hidden, weight, bias=bias) + else: + if self.dropout > 0.0: + logit = hidden @ proj + logit = self.drop(logit) + logit = logit @ weight.t() + else: + logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) + if bias is not None: + logit = logit + bias + return logit + + def get_out_proj(self, i): + if self.tie_projs[i]: + if len(self.shared_out_projs) == 0: + return None + elif 
len(self.shared_out_projs) == 1: + return self.shared_out_projs[0] + else: + return self.shared_out_projs[i] + else: + return self.out_projs[i] + + def forward(self, hidden, target, keep_order=False, key_padding_mask=None, *args, **kwargs): + # [21-09-15 AG]: TODO may need to handle key_padding_mask + ''' + hidden :: [len*bsz x d_proj] + target :: [len*bsz] + ''' + + hidden = hidden.reshape(-1, hidden.size(-1)) + target = target.reshape(-1) + if hidden.size(0) != target.size(0): + print(hidden.shape, target.shape) + raise RuntimeError('Input and target should have the same size ' + 'in the batch dimension.') + + if self.n_clusters == 0: + logit = self._compute_logit(hidden, self.out_layers_weights[0], + self.out_layers_biases[0], self.get_out_proj(0)) + nll = -F.log_softmax(logit, dim=-1) \ + .gather(1, target.unsqueeze(1)).squeeze(1) + else: + # construct weights and biases + weights, biases = [], [] + for i in range(len(self.cutoffs)): + if self.div_val == 1: + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + weight_i = self.out_layers_weights[0][l_idx:r_idx] + bias_i = self.out_layers_biases[0][l_idx:r_idx] + else: + weight_i = self.out_layers_weights[i] + bias_i = self.out_layers_biases[i] + + if i == 0: + weight_i = torch.cat( + [weight_i, self.cluster_weight], dim=0) + bias_i = torch.cat( + [bias_i, self.cluster_bias], dim=0) + + weights.append(weight_i) + biases.append(bias_i) + + head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0) + + head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) + head_logprob = F.log_softmax(head_logit, dim=1) + + nll = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device) + + offset = 0 + cutoff_values = [0] + self.cutoffs + for i in range(len(cutoff_values) - 1): + l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] + + mask_i = (target >= l_idx) & (target < r_idx) + indices_i = mask_i.nonzero(as_tuple=False).squeeze() + + if indices_i.numel() == 0: + continue + + target_i = target.index_select(0, indices_i) - l_idx + head_logprob_i = head_logprob.index_select(0, indices_i) + + if i == 0: + logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) + else: + weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i) + + hidden_i = hidden.index_select(0, indices_i) + + tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) + tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) + + # First term accounts for cluster probabilities + logprob_i = head_logprob_i[:, -i] \ + + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1) + + if self.keep_order or keep_order: + nll.index_copy_(0, indices_i, -logprob_i) + else: + nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) + + offset += logprob_i.size(0) # TODO This should be a bug in the original implementation; it should go into the continue case above as well + + return nll.mean() # TODO maybe cases for length or padding_mask + + def compute_logits(self, hidden): + """Compute full vector of logits + + Adapted from https://github.com/kimiyoung/transformer-xl/issues/88 + """ + hidden = hidden.reshape(-1, hidden.size(-1)) + + if self.n_clusters == 0: + logits = self._compute_logit(hidden, self.out_layers_weights[0], + self.out_layers_biases[0], self.get_out_proj(0)) + return logits + else: + # construct weights and biases + weights, biases = [], [] + for i in range(len(self.cutoffs)): + if self.div_val == 1: + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + weight_i = 
self.out_layers_weights[0][l_idx:r_idx] + bias_i = self.out_layers_biases[0][l_idx:r_idx] + else: + weight_i = self.out_layers_weights[i] + bias_i = self.out_layers_biases[i] + + if i == 0: + weight_i = torch.cat( + [weight_i, self.cluster_weight], dim=0) + bias_i = torch.cat( + [bias_i, self.cluster_bias], dim=0) + + weights.append(weight_i) + biases.append(bias_i) + + head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0) + + head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) + head_logprob = F.log_softmax(head_logit, dim=1) + + out_full_logps = [head_logprob[:, :self.cutoffs[0]]] + offset = 0 + cutoff_values = [0] + self.cutoffs + + for i in range(1, len(cutoff_values) - 1): + l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] + head_logprob_i = head_logprob # .index_select(0, indices_i) + + if i == 0: + logprob_i = head_logprob_i + else: + weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i) + + hidden_i = hidden # .index_select(0, indices_i) + + tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) + tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) + logprob_i = head_logprob_i[:, -i].view(-1, 1) + tail_logprob_i + + offset += logprob_i.size(0) + out_full_logps.append(logprob_i) + out_full_logps = torch.cat(out_full_logps, dim = 1) + # print(torch.sum(out_full_ps), out_full_ps.shape) + return out_full_logps + + +class AdaptiveEmbedding(nn.Module): + """ Copy of transformers.AdaptiveEmbedding that works with fp16 by replacing the index_put_ operation + + Initialization has been fixed for the case when d_proj = d_embed + """ + def __init__(self, n_token, d_embed, d_proj, cutoffs : List[int], div_val=1, init_scale=1.0, sample_softmax=False, dropout=0.0): + super().__init__() + + self.n_token = n_token + self.d_embed = d_embed + + self.cutoffs = list(cutoffs) + [n_token] + self.div_val = div_val + self.d_proj = d_proj + self.drop = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + + self.emb_scale = d_proj ** 0.5 + + self.cutoff_ends = [0] + self.cutoffs + + self.emb_layers = nn.ModuleList() + self.emb_projs = nn.ParameterList() + if div_val == 1: + self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)) + _init_embed(self.emb_layers[-1].weight, d_embed, init_scale) + # torch.nn.init.normal_(self.emb_layers[-1].weight, mean=0, std=init_scale * d_embed ** -.5) + if d_proj != d_embed: # TODO + # self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed))) + self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed))) + # torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale) + _init_proj(self.emb_projs[-1], d_proj, init_scale) + else: + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + d_emb_i = d_embed // (div_val ** i) + self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i)) + # torch.nn.init.normal_(self.emb_layers[-1].weight, mean=0, std=init_scale * d_emb_i ** -.5) + _init_embed(self.emb_layers[-1].weight, d_emb_i, init_scale) + self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i))) + # torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale) + _init_proj(self.emb_projs[-1], d_proj, init_scale) + + def forward(self, inp): + if self.div_val == 1: + embed = self.emb_layers[0](inp) + embed = self.drop(embed) + if self.d_proj != self.d_embed: + embed = F.linear(embed, self.emb_projs[0]) + else: + param = 
next(self.parameters()) + inp_flat = inp.reshape(-1) + + # Changes from original impl + # emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device) + embeddings = [] + indices = torch.zeros_like(inp_flat) # empty should work as long as cutoffs[-1] > max token + _total_tokens = 0 + + # emb_flat = inp.new_zeros(inp_flat.size(0), self.d_proj) + for i in range(len(self.cutoffs)): + l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] + + mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) + indices_i = mask_i.nonzero().squeeze(-1) # shape (_tokens,) + + _tokens = indices_i.numel() + if _tokens == 0: + continue + + inp_i = inp_flat.index_select(0, indices_i) - l_idx + emb_i = self.emb_layers[i](inp_i) + emb_i = self.drop(emb_i) + emb_i = F.linear(emb_i, self.emb_projs[i]) + + # Changes + embeddings.append(emb_i) + indices.index_put_( + (indices_i,), + torch.arange(_tokens, device=inp.device) + _total_tokens + ) + _total_tokens += _tokens + + # emb_flat.index_copy_(0, indices_i, emb_i) + embeddings = torch.cat(embeddings, dim=0) + emb_flat = embeddings[indices] + + embed_shape = inp.size() + (self.d_proj,) + embed = emb_flat.view(embed_shape) + + embed.mul_(self.emb_scale) + # embed.div_(self.emb_scale) + + return embed + + +def _init_weight(weight, d : int, init_scale : Optional[float], default=None): + assert init_scale or default + if init_scale is None: + std = default + else: + std = init_scale * (d ** -0.5) + nn.init.normal_(weight, mean=0, std=std) + +_init_embed = functools.partial(_init_weight, default=0.02) +_init_proj = functools.partial(_init_weight, default=0.01) diff --git a/src/clm/src/models/nn/components.py b/src/clm/src/models/nn/components.py new file mode 100644 index 00000000..b47e951e --- /dev/null +++ b/src/clm/src/models/nn/components.py @@ -0,0 +1,389 @@ +""" Utility nn components, in particular handling activations, initializations, and normalization layers """ + +from functools import partial +import math +from typing import ForwardRef +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from opt_einsum import contract + + +def stochastic_depth(input: torch.tensor, p: float, mode: str, training: bool = True): + """ + Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth" + `_ used for randomly dropping residual + branches of residual architectures. + + Args: + input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one + being its batch i.e. a batch with ``N`` rows. + p (float): probability of the input to be zeroed. + mode (str): ``"batch"`` or ``"row"``. + ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes + randomly selected rows from the batch. + training: apply stochastic depth if is ``True``. Default: ``True`` + + Returns: + Tensor[N, ...]: The randomly zeroed tensor. + """ + if p < 0.0 or p > 1.0: + raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p)) + if mode not in ["batch", "row"]: + raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode)) + if not training or p == 0.0: + return input + + survival_rate = 1.0 - p + if mode == "row": + size = [input.shape[0]] + [1] * (input.ndim - 1) + else: + size = [1] * input.ndim + noise = torch.empty(size, dtype=input.dtype, device=input.device) + noise = noise.bernoulli_(survival_rate).div_(survival_rate) + return input * noise + +class StochasticDepth(nn.Module): + """ + See :func:`stochastic_depth`. 
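+
+    Illustrative usage sketch (not exercised anywhere in this patch; shapes and p are arbitrary):
+
+        sd = StochasticDepth(p=0.2, mode="row")
+        x = torch.randn(8, 64, 128)   # (batch, dim, length)
+        y = sd(x)                     # training: each row is kept with prob 0.8 and scaled by 1/0.8
+                                      # eval (sd.eval()): identity, y == x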
+ """ + def __init__(self, p: float, mode: str) -> None: + # TODO(karan): need to upgrade to torchvision==0.11.0 to use StochasticDepth directly + # from torchvision.ops import StochasticDepth + super().__init__() + self.p = p + self.mode = mode + + def forward(self, input): + return stochastic_depth(input, self.p, self.mode, self.training) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + '(' + tmpstr += 'p=' + str(self.p) + tmpstr += ', mode=' + str(self.mode) + tmpstr += ')' + return tmpstr + +class DropoutNd(nn.Module): + def __init__(self, p: float = 0.5, tie=True, transposed=True): + """ + tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) + """ + super().__init__() + if p < 0 or p >= 1: + raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) + self.p = p + self.tie = tie + self.transposed = transposed + self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) + + def forward(self, X): + """ X: (batch, dim, lengths...) """ + if self.training: + if not self.transposed: X = rearrange(X, 'b d ... -> b ... d') + # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow + mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape + # mask = self.binomial.sample(mask_shape) + mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p + X = X * mask * (1.0/(1-self.p)) + if not self.transposed: X = rearrange(X, 'b ... d -> b d ...') + return X + return X + + +def Activation(activation=None, size=None, dim=-1): + if activation in [ None, 'id', 'identity', 'linear' ]: + return nn.Identity() + elif activation == 'tanh': + return nn.Tanh() + elif activation == 'relu': + return nn.ReLU() + elif activation == 'gelu': + return nn.GELU() + elif activation in ['swish', 'silu']: + return nn.SiLU() + elif activation == 'glu': + return nn.GLU(dim=dim) + elif activation == 'sigmoid': + return nn.Sigmoid() + elif activation == 'softplus': + return nn.Softplus() + elif activation in ['sqrelu', 'relu2']: + return SquaredReLU() + elif activation == 'laplace': + return Laplace() + elif activation == 'ln': + return TransposedLN(dim) + else: + raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) + +def get_initializer(name, activation=None): + if activation in [ None, 'id', 'identity', 'linear' ]: + nonlinearity = 'linear' + elif activation in ['relu', 'tanh', 'sigmoid']: + nonlinearity = activation + elif activation in ['gelu', 'swish', 'silu']: + nonlinearity = 'relu' # Close to ReLU so approximate with ReLU's gain + else: + raise NotImplementedError(f"get_initializer: activation {activation} not supported") + + if name == 'uniform': + initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity) + elif name == 'normal': + initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity) + elif name == 'xavier': + initializer = torch.nn.init.xavier_normal_ + elif name == 'zero': + initializer = partial(torch.nn.init.constant_, val=0) + elif name == 'one': + initializer = partial(torch.nn.init.constant_, val=1) + else: + raise NotImplementedError(f"get_initializer: initializer type {name} not supported") + + return initializer + +def LinearActivation( + d_input, d_output, bias=True, + zero_bias_init=False, + transposed=False, + initializer=None, + activation=None, + activate=False, # Apply activation as part of this module + weight_norm=False, + **kwargs, + ): + """ Returns a linear nn.Module with control over axes order, 
initialization, and activation """ + + # Construct core module + # linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear + linear_cls = TransposedLinear if transposed else nn.Linear + if activation == 'glu': d_output *= 2 + linear = linear_cls(d_input, d_output, bias=bias, **kwargs) + + # Initialize weight + if initializer is not None: + get_initializer(initializer, activation)(linear.weight) + + # Initialize bias + if bias and zero_bias_init: + nn.init.zeros_(linear.bias) + + # Weight norm + if weight_norm: + linear = nn.utils.weight_norm(linear) + + if activate and activation is not None: + activation = Activation(activation, d_output, dim=1 if transposed else -1) + linear = nn.Sequential(linear, activation) + return linear + +class SquaredReLU(nn.Module): + def forward(self, x): + # return F.relu(x)**2 + return torch.square(F.relu(x)) # Could this be faster? + +def laplace(x, mu=0.707107, sigma=0.282095): + x = (x - mu).div(sigma * math.sqrt(2.0)) + return 0.5 * (1.0 + torch.erf(x)) + +class Laplace(nn.Module): + def __init__(self, mu=0.707107, sigma=0.282095): + super().__init__() + self.mu = mu + self.sigma = sigma + + def forward(self, x): + return laplace(x, mu=self.mu, sigma=self.sigma) + + +class TransposedLinear(nn.Module): + """ Linear module on the second-to-last dimension + Assumes shape (B, D, L), where L can be 1 or more axis + """ + + def __init__(self, d_input, d_output, bias=True): + super().__init__() + + self.weight = nn.Parameter(torch.empty(d_output, d_input)) + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # nn.Linear default init + # nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent + + if bias: + self.bias = nn.Parameter(torch.empty(d_output)) + bound = 1 / math.sqrt(d_input) + nn.init.uniform_(self.bias, -bound, bound) + setattr(self.bias, "_optim", {"weight_decay": 0.0}) + else: + self.bias = 0.0 + + def forward(self, x): + num_axis = len(x.shape[2:]) # num_axis in L, for broadcasting bias + y = contract('b u ..., v u -> b v ...', x, self.weight) + self.bias.view(-1, *[1]*num_axis) + return y + + +class TransposedLN(nn.Module): + """ LayerNorm module over second dimension + Assumes shape (B, D, L), where L can be 1 or more axis + + This is slow and a dedicated CUDA/Triton implementation shuld provide substantial end-to-end speedup + """ + def __init__(self, d, scalar=True): + super().__init__() + self.scalar = scalar + if self.scalar: + self.m = nn.Parameter(torch.zeros(1)) + self.s = nn.Parameter(torch.ones(1)) + setattr(self.m, "_optim", {"weight_decay": 0.0}) + setattr(self.s, "_optim", {"weight_decay": 0.0}) + else: + self.ln = nn.LayerNorm(d) + + def forward(self, x): + if self.scalar: + # calc. stats over D dim / channels + s, m = torch.std_mean(x, dim=1, unbiased=False, keepdim=True) + y = (self.s/s) * (x-m+self.m) + else: + # move channel to last axis, apply layer_norm, then move channel back to second axis + _x = self.ln(rearrange(x, 'b d ... -> b ... d')) + y = rearrange(_x, 'b ... 
d -> b d ...') + return y + +class Normalization(nn.Module): + def __init__( + self, + d, + transposed=False, # Length dimension is -1 or -2 + _name_='layer', + **kwargs + ): + super().__init__() + self.transposed = transposed + self._name_ = _name_ + + if _name_ == 'layer': + self.channel = True # Normalize over channel dimension + if self.transposed: + self.norm = TransposedLN(d, **kwargs) + else: + self.norm = nn.LayerNorm(d, **kwargs) + elif _name_ == 'instance': + self.channel = False + norm_args = {'affine': False, 'track_running_stats': False} + norm_args.update(kwargs) + self.norm = nn.InstanceNorm1d(d, **norm_args) # (True, True) performs very poorly + elif _name_ == 'batch': + self.channel = False + norm_args = {'affine': True, 'track_running_stats': True} + norm_args.update(kwargs) + self.norm = nn.BatchNorm1d(d, **norm_args) + elif _name_ == 'group': + self.channel = False + self.norm = nn.GroupNorm(1, d, *kwargs) + elif _name_ == 'none': + self.channel = True + self.norm = nn.Identity() + else: raise NotImplementedError + + def forward(self, x): + # Handle higher dimension logic + shape = x.shape + if self.transposed: + x = rearrange(x, 'b d ... -> b d (...)') + else: + x = rearrange(x, 'b ... d -> b (...)d ') + + # The cases of LayerNorm / no normalization are automatically handled in all cases + # Instance/Batch Norm work automatically with transposed axes + if self.channel or self.transposed: + x = self.norm(x) + else: + x = x.transpose(-1, -2) + x = self.norm(x) + x = x.transpose(-1, -2) + + x = x.view(shape) + return x + + def step(self, x, **kwargs): + assert self._name_ in ["layer", "none"] + if self.transposed: x = x.unsqueeze(-1) + x = self.forward(x) + if self.transposed: x = x.squeeze(-1) + return x + +class TSNormalization(nn.Module): + + def __init__(self, method, horizon): + super().__init__() + + self.method = method + self.horizon = horizon + + + def forward(self, x): + # x must be BLD + if self.method == 'mean': + self.scale = x.abs()[:, :-self.horizon].mean(dim=1)[:, None, :] + return x / self.scale + elif self.method == 'last': + self.scale = x.abs()[:, -self.horizon-1][:, None, :] + return x / self.scale + return x + +class TSInverseNormalization(nn.Module): + + def __init__(self, method, normalizer): + super().__init__() + + self.method = method + self.normalizer = normalizer + + def forward(self, x): + if self.method == 'mean' or self.method == 'last': + return x * self.normalizer.scale + return x + +class ReversibleInstanceNorm1dInput(nn.Module): + def __init__(self, d, transposed=False): + super().__init__() + # BLD if transpoed is False, otherwise BDL + self.transposed = transposed + self.norm = nn.InstanceNorm1d(d, affine=True, track_running_stats=False) + + def forward(self, x): + # Means, stds + if not self.transposed: + x = x.transpose(-1, -2) + + self.s, self.m = torch.std_mean(x, dim=-1, unbiased=False, keepdim=True) + self.s += 1e-4 + + x = (x - self.m) / self.s + # x = self.norm.weight.unsqueeze(-1) * x + self.norm.bias.unsqueeze(-1) + + if not self.transposed: + return x.transpose(-1, -2) + return x + +class ReversibleInstanceNorm1dOutput(nn.Module): + + def __init__(self, norm_input): + super().__init__() + self.transposed = norm_input.transposed + self.weight = norm_input.norm.weight + self.bias = norm_input.norm.bias + self.norm_input = norm_input + + def forward(self, x): + if not self.transposed: + x = x.transpose(-1, -2) + + # x = (x - self.bias.unsqueeze(-1))/self.weight.unsqueeze(-1) + x = x * self.norm_input.s + self.norm_input.m + + 
if not self.transposed: + return x.transpose(-1, -2) + return x diff --git a/src/clm/src/models/nn/dxt.py b/src/clm/src/models/nn/dxt.py new file mode 100644 index 00000000..a9813bc5 --- /dev/null +++ b/src/clm/src/models/nn/dxt.py @@ -0,0 +1,196 @@ +"""Implementations of several types of Discrete Sin/Cosine Transforms with various reductions to FFT. + +Currently not used by S4 +""" + +import torch +import torch.nn as nn +import numpy as np +import scipy.fft +from einops import rearrange, repeat + +class DCT(nn.Module): + """ Reductions adapted from https://dsp.stackexchange.com/questions/2807/fast-cosine-transform-via-fft """ + + def __init__(self, N, norm='backward'): + super().__init__() + + self.N = N + + # Materialize DCT matrix + P = scipy.fft.dct(np.eye(N), norm=norm, type=2).T + P = torch.tensor(P, dtype=torch.float) + self.register_buffer('P', P) + + # TODO take care of normalization + Q = np.exp(-1j * np.pi / (2 * self.N) * np.arange(self.N)) + Q = torch.tensor(Q, dtype=torch.cfloat) + self.register_buffer('Q', Q) # half shift + + def forward(self, x, mode=2): + if mode == 0: + return self.forward_dense(x) + elif mode == 1: + return self.forward_n(x) + elif mode == 2: + return self.forward_2n(x) + elif mode == 4: + return self.forward_4n(x) + + def forward_dense(self, x): + """ Baseline DCT type II - matmul by DCT matrix """ + y = (self.P.to(x) @ x.unsqueeze(-1)).squeeze(-1) + return y + + def forward_4n(self, x): + """ DCT type II - reduction to FFT size 4N """ + assert self.N == x.shape[-1] + x = torch.cat([x, x.flip(-1)], dim=-1) + z = torch.zeros_like(x) + x = torch.stack([z, x], dim=-1) + x = x.view(x.shape[:-2] + (-1,)) + y = torch.fft.fft(x) + y = y[..., :self.N] + if torch.is_complex(x): + return y + else: + return torch.real(y) + + def forward_2n(self, x): + """ DCT type II - reduction to FFT size 2N mirrored + + The reduction from the DSP forum is not quite correct in the complex input case. 
+ halfshift(FFT[a, b, c, d, d, c, b, a]) -> [A, B, C, D, 0, -D, -C, -B] + In the case of real input, the intermediate step after FFT has form [A, B, C, D, 0, D*, C*, B*] + """ + assert self.N == x.shape[-1] + x = torch.cat([x, x.flip(-1)], dim=-1) + y = torch.fft.fft(x)[..., :self.N] + y = y * self.Q + if torch.is_complex(x): + return y + else: + return torch.real(y) + + def forward_n(self, x): + """ DCT type II - reduction to size N """ + assert self.N == x.shape[-1] + x = torch.cat([x[..., 0::2], x[..., 1::2].flip(-1)], dim=-1) + y = torch.fft.fft(x) + y = y * 2 * self.Q + if torch.is_complex(x): + y = torch.cat([y[..., :1], (y[..., 1:] + 1j * y[..., 1:].flip(-1)) / 2], dim=-1) # TODO in-place sum + else: + y = torch.real(y) + return y + +class IDCT(nn.Module): + def __init__(self, N, norm='backward'): + super().__init__() + + self.N = N + + # Materialize DCT matrix + P = np.linalg.inv(scipy.fft.dct(np.eye(N), norm=norm, type=2).T) + P = torch.tensor(P, dtype=torch.float) + self.register_buffer('P', P) + + # TODO take care of normalization + Q = np.exp(-1j * np.pi / (2 * self.N) * np.arange(2*self.N)) + Q = torch.tensor(Q, dtype=torch.cfloat) + self.register_buffer('Q', Q) # half shift + + def forward(self, x, mode=2): + if mode == 0: + return self.forward_dense(x) + elif mode == 1: + return self.forward_n(x) + elif mode == 2: + return self.forward_2n(x) + elif mode == 4: + return self.forward_4n(x) + + def forward_dense(self, x): + """ Baseline DCT type II - matmul by DCT matrix """ + y = (self.P.to(x) @ x.unsqueeze(-1)).squeeze(-1) + return y + + def forward_4n(self, x): + """ DCT type II - reduction to FFT size 4N """ + assert self.N == x.shape[-1] + z = x.new_zeros(x.shape[:-1] + (1,)) + x = torch.cat([x, z, -x.flip(-1), -x[..., 1:], z, x[..., 1:].flip(-1)], dim=-1) + y = torch.fft.ifft(x) + y = y[..., 1:2*self.N:2] + if torch.is_complex(x): + return y + else: + return torch.real(y) + + def forward_2n(self, x): + """ DCT type II - reduction to FFT size 2N mirrored """ + assert self.N == x.shape[-1] + z = x.new_zeros(x.shape[:-1] + (1,)) + x = torch.cat([x, z, -x[..., 1:].flip(-1)], dim=-1) + x = x / self.Q + y = torch.fft.ifft(x)[..., :self.N] + if torch.is_complex(x): + return y + else: + return torch.real(y) + + def forward_n(self, x): + """ DCT type II - reduction to size N """ + assert self.N == x.shape[-1] + raise NotImplementedError # Straightforward by inverting operations of DCT-II reduction + +def test_dct_ii(): + N = 8 + dct = DCT(N) + + baseline = dct.forward_dense + methods = [dct.forward_4n, dct.forward_2n, dct.forward_n] + + # Real case + print("DCT-II Real input") + x = torch.randn(1, N) + y = baseline(x) + print(y) + for fn in methods: + y_ = fn(x) + print("err", torch.norm(y-y_)) + + # Complex case + print("DCT-II Complex input") + x = torch.randn(N) + 1j * torch.randn(N) + y = baseline(x) + print(y) + for fn in methods: + y_ = fn(x) + print("err", torch.norm(y-y_)) + +def test_dct_iii(): + N = 8 + dct = IDCT(N) + + baseline = dct.forward_dense + methods = [dct.forward_4n, dct.forward_2n] + + # Real case + print("DCT-III Real input") + x = torch.randn(1, N) + y = baseline(x) + print(y) + for fn in methods: + y_ = fn(x) + print("err", torch.norm(y-y_)) + + # Complex case + print("DCT-III Complex input") + # x = torch.randn(N) + 1j * torch.randn(N) + x = 1j * torch.ones(N) + y = baseline(x) + print(y) + for fn in methods: + y_ = fn(x) + print("err", torch.norm(y-y_)) diff --git a/src/clm/src/models/nn/gate.py b/src/clm/src/models/nn/gate.py new file mode 100644 index 
00000000..d0a531f7 --- /dev/null +++ b/src/clm/src/models/nn/gate.py @@ -0,0 +1,128 @@ +""" Defines flexible gating mechanisms based on ideas from LSSL paper and UR-LSTM paper https://arxiv.org/abs/1910.09890 """ + +import torch +import torch.nn as nn + +class Gate(nn.Module): + """ Implements gating mechanisms. TODO update this with more detailed description with reference to LSSL paper when it's on arxiv + + Mechanisms: + N - No gate + G - Standard sigmoid gate + UR - Uniform refine gates + R - Refine gate + + FS - Forward discretization, Sigmoid activation [equivalent to G] + BE - Backward discretization, Exp activation [equivalent to G] + BR - Backward discretization, Relu activation + TE - Trapezoid discretization, Exp activation + TR - Trapezoid discretization, Relu activation + TS - Trapezoid discretization, Sigmoid activation (0 to 2) + """ + def __init__(self, size, preact_ctor, preact_args, mechanism='N'): + super().__init__() + self.size = size + self.mechanism = mechanism + + if self.mechanism == 'N': + pass + elif self.mechanism in ['G', 'FS', 'BE', 'BR', 'TE', 'TR', 'TS', 'ZE', 'ZR', 'ZS']: + self.W_g = preact_ctor(*preact_args) + elif self.mechanism in ['U', 'UT']: + self.W_g = preact_ctor(*preact_args) + b_g_unif = torch.empty(size) + torch.nn.init.uniform_(b_g_unif, 1./self.size, 1.-1./self.size) + self.b_g = nn.Parameter(torch.log(1./b_g_unif-1.).detach(), requires_grad=False) + elif self.mechanism == 'UR': + self.W_g = preact_ctor(*preact_args) + self.W_r = preact_ctor(*preact_args) + + b_g_unif = torch.empty(size) + torch.nn.init.uniform_(b_g_unif, 1./self.size, 1.-1./self.size) + self.b_g = nn.Parameter(torch.log(1./b_g_unif-1.).detach(), requires_grad=False) + elif self.mechanism == 'R': + self.W_g = preact_ctor(*preact_args) + self.W_r = preact_ctor(*preact_args) + elif self.mechanism in ['GT']: + self.W_g = preact_ctor(*preact_args) + else: + assert False, f'Gating type {self.mechanism} is not supported.' 
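+
+        # Illustrative construction sketch (hypothetical sizes; nothing in this module
+        # depends on it). The pre-activation is supplied as a constructor plus its
+        # positional args, e.g. a plain linear map from a 128-dim input to `size`
+        # gate pre-activations:
+        #
+        #     gate = Gate(size=64, preact_ctor=nn.Linear, preact_args=(128, 64), mechanism='UR')
+        #     g = gate(x)   # x: (batch, 128) -> g: (batch, 64), gate values in (0, 1)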
+ + def forward(self, *inputs): + if self.mechanism == 'N': + return 1.0 + + if self.mechanism == 'G': + g_preact = self.W_g(*inputs) + g = torch.sigmoid(g_preact) + if self.mechanism == 'U': + g_preact = self.W_g(*inputs) + self.b_g + g = torch.sigmoid(g_preact) + elif self.mechanism == 'UR': + g_preact = self.W_g(*inputs) + self.b_g + g = torch.sigmoid(g_preact) + r = torch.sigmoid(self.W_r(*inputs)) + g = (1-2*r)*g**2 + 2*r*g + elif self.mechanism == 'R': + g_preact = self.W_g(*inputs) + g = torch.sigmoid(g_preact) + r = torch.sigmoid(self.W_r(*inputs)) + g = (1-2*r)*g**2 + 2*r*g + elif self.mechanism == 'UT': + g_preact = self.W_g(*inputs) + self.b_g + g = torch.sigmoid(g_preact) + r = g + g = (1-2*r)*g**2 + 2*r*g + elif self.mechanism == 'GT': + g_preact = self.W_g(*inputs) + g = torch.sigmoid(g_preact) + r = g + g = (1-2*r)*g**2 + 2*r*g + else: + g_preact = self.W_g(*inputs) + # if self.mechanism[1] == 'S': + # g = torch.sigmoid(g_preact) + # elif self.mechanism[1] == 'E': + # g = torch.exp(g_preact) + # elif self.mechanism[1] == 'R': + # g = torch.relu(g_preact) + if self.mechanism == 'FS': + g = torch.sigmoid(g_preact) + g = self.forward_diff(g) + elif self.mechanism == 'BE': + g = torch.exp(g_preact) + g = self.backward_diff(g) + elif self.mechanism == 'BR': + g = torch.relu(g_preact) + g = self.backward_diff(g) + elif self.mechanism == 'TS': + g = 2 * torch.sigmoid(g_preact) + g = self.trapezoid(g) + elif self.mechanism == 'TE': + g = torch.exp(g_preact) + g = self.trapezoid(g) + elif self.mechanism == 'TR': + g = torch.relu(g_preact) + g = self.trapezoid(g) + elif self.mechanism == 'ZE': + g = torch.exp(g_preact) + g = self.zoh(g) + elif self.mechanism == 'ZR': + g = torch.relu(g_preact) + g = self.zoh(g) + elif self.mechanism == 'ZS': + g = torch.sigmoid(g_preact) + g = self.zoh(g) + return g + + def forward_diff(self, x): + return x + + def backward_diff(self, x): + return x / (1+x) + + def trapezoid(self, x): + return x / (1 + x/2) + + def zoh(self, x): + return 1 - torch.exp(-x) diff --git a/src/clm/src/models/nn/residual.py b/src/clm/src/models/nn/residual.py new file mode 100644 index 00000000..360697e2 --- /dev/null +++ b/src/clm/src/models/nn/residual.py @@ -0,0 +1,108 @@ +""" Implementations of different types of residual functions. """ + +import torch +from torch import nn + +class Residual(nn.Module): + """ Residual connection with constant affine weights. Can simulate standard residual, no residual, and "constant gates". 
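+
+    Concretely, the forward pass computes (constant-coefficient sketch):
+
+        out = alpha * x + beta * y    # alpha=1, beta=1 -> standard residual
+                                      # alpha=0         -> no residual (feedforward)
+
+    where x is the block input and y is the inner layer's output.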
""" + + def __init__(self, i_layer, d_input, d_model, alpha=1.0, beta=1.0): + # print("ConstantResidual extra kwargs", kwargs) + super().__init__() + assert (d_input == d_model) or alpha == 0.0 + self.i_layer = i_layer + self.d_input = d_input + self.d_model = d_model + self.alpha = alpha + self.beta = beta + + @property + def d_output(self): + return self.d_model + + def forward(self, x, y, transposed): # TODO documentation of transposed + y = self.beta*y if self.beta != 1.0 else y + return self.alpha * x + y if self.alpha else y + +class Affine(Residual): + """ Residual connection with learnable scalar multipliers on the main branch + scalar: Single scalar multiplier, or one per dimension + scale, power: Initialize to scale * layer_num**(-power) + """ + + def __init__(self, *args, scalar=True, gamma=0.0, **kwargs): + # print("ConstantResidual extra kwargs", kwargs) + super().__init__(*args, **kwargs) + self.scalar = scalar + self.gamma = gamma + + c = self.beta * self.i_layer ** (-self.gamma) + d = 1 if self.scalar else self.d_input + self.affine = nn.Parameter(c * torch.ones(d)) + + def forward(self, x, y, transposed): # TODO documentation of transposed + c = self.affine + if transposed: c = c.unsqueeze(-1) + return self.alpha * x + c * y + + +class Feedforward(Residual): + def __init__(self, *args): + # print("Feedforward extra kwargs", kwargs) + super().__init__(*args, alpha=0.0, beta=1.0) + + +class Highway(Residual): + def __init__(self, *args, scaling_correction=False, elemwise=False): + super().__init__(*args) + self.scaling_correction = 1.732 if scaling_correction else 1.0 # TODO + self.elemwise = elemwise + self.Wx = nn.Linear(self.d_input, self.d_input) + if self.elemwise: + self.Wy = nn.Parameter(torch.randn(self.d_input)) + else: + self.Wy = nn.Linear(self.d_input, self.d_input) + + def forward(self, x, y, transposed=False): # TODO handle this case + if self.elemwise: + y = self.Wy * y + else: + y = self.Wy(y) + r = torch.sigmoid(self.Wx(x) + y) + z = self.scaling_correction * (1.-r) * x + r * y + return z + + +class DecayResidual(Residual): + """ Residual connection that can decay the linear combination depending on depth. """ + + def __init__(self, *args, power=0.5, l2=True): + # print("DecayResidual extra kwargs", kwargs) + super().__init__(*args) + self.power = power + self.l2 = l2 + + def forward(self, x, y, transposed): + beta = self.i_layer ** (-self.power) + if self.l2: + alpha = (1. - beta**2)**0.5 + else: + alpha = 1. 
- beta + + return alpha * x + beta * y + +registry = { + 'F': Feedforward, + 'N': Feedforward, + 'R': Residual, + 'H': Highway, + 'D': DecayResidual, + 'A': Affine, + 'none': Feedforward, + 'ff': Feedforward, + 'feedforward': Feedforward, + 'residual': Residual, + 'highway': Highway, + 'decay': DecayResidual, + 'affine': Affine, +} diff --git a/src/clm/src/models/nn/utils.py b/src/clm/src/models/nn/utils.py new file mode 100644 index 00000000..2c4d18d9 --- /dev/null +++ b/src/clm/src/models/nn/utils.py @@ -0,0 +1,125 @@ +""" Utility wrappers around modules to let them handle Args and extra arguments """ + +import inspect +from functools import wraps +import torch +from torch import nn + +def wrap_kwargs(f): + """ + Given a callable f that can consume some named arguments, + wrap it with a kwargs that passes back any unused args + + EXAMPLES + -------- + + Basic usage: + def foo(x, y=None): + return x + + wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) + + -------- + + The wrapped function can return its own argument dictionary, + which gets merged with the new kwargs. + def foo(x, y=None): + return x, {} + wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) + + def foo(x, y=None): + return x, {"y": y, "z": None} + wrap_kwargs(foo)(0, y=1, z=2) == (0, {'y': 1, 'z': 2}) + + -------- + + The wrapped function can have its own kwargs parameter: + def foo(x, y=None, **kw_args): + return x, {} + wrap_kwargs(foo)(0, y=1, z=2) == (0, {}) + + -------- + + Partial functions and modules work automatically: + class Module: + def forward(self, x, y=0): + return x, {"y": y+1} + + m = Module() + + wrap_kwargs(m.forward)(0, y=1, z=2) == (0, {'y': 2, 'z': 2}) + + """ + sig = inspect.signature(f) + # Check if f already has kwargs + has_kwargs = any([ + param.kind == inspect.Parameter.VAR_KEYWORD + for param in sig.parameters.values() + ]) + if has_kwargs: + @wraps(f) + def f_kwargs(*args, **kwargs): + y = f(*args, **kwargs) + if isinstance(y, tuple) and isinstance(y[-1], dict): + return y + else: + return y, {} + else: + param_kwargs = inspect.Parameter("kwargs", kind=inspect.Parameter.VAR_KEYWORD) + sig_kwargs = inspect.Signature(parameters=list(sig.parameters.values())+[param_kwargs]) + @wraps(f) + def f_kwargs(*args, **kwargs): + bound = sig_kwargs.bind(*args, **kwargs) + if "kwargs" in bound.arguments: + kwargs = bound.arguments.pop("kwargs") + else: + kwargs = {} + y = f(**bound.arguments) + if isinstance(y, tuple) and isinstance(y[-1], dict): + return *y[:-1], {**y[-1], **kwargs} + else: + return y, kwargs + return f_kwargs + +def discard_kwargs(f): + if f is None: return None + f_kwargs = wrap_kwargs(f) + @wraps(f) + def f_(*args, **kwargs): + return f_kwargs(*args, **kwargs)[0] + return f_ + +def PassthroughSequential(*modules): + """Special Sequential module that chains kwargs. 
+ + Semantics are the same as nn.Sequential, with extra convenience features: + - Discard None modules + - Flatten inner Sequential modules + - In case with 0 or 1 Module, rename the class for ease of inspection + """ + def flatten(module): + if isinstance(module, nn.Sequential): + return sum([flatten(m) for m in module], []) + else: + return [module] + + modules = flatten(nn.Sequential(*modules)) + modules = [module for module in modules if module if not None] + + class Sequential(nn.Sequential): + def forward(self, x, **kwargs): + for layer in self: + x, kwargs = wrap_kwargs(layer.forward)(x, **kwargs) + return x, kwargs + + def step(self, x, **kwargs): + for layer in self: + fn = getattr(layer, "step", layer.forward) + x, kwargs = wrap_kwargs(fn)(x, **kwargs) + return x, kwargs + + if len(modules) == 0: + Sequential.__name__ = "Identity" + elif len(modules) == 1: + Sequential.__name__ = type(modules[0]).__name__ + return Sequential(*modules) diff --git a/src/clm/src/models/sequence/__init__.py b/src/clm/src/models/sequence/__init__.py new file mode 100644 index 00000000..38669fb6 --- /dev/null +++ b/src/clm/src/models/sequence/__init__.py @@ -0,0 +1,3 @@ +from .base import SequenceModule, TransposedModule +from .model import SequenceModel +from .ff import FF diff --git a/src/clm/src/models/sequence/base.py b/src/clm/src/models/sequence/base.py new file mode 100644 index 00000000..4f8a4ffa --- /dev/null +++ b/src/clm/src/models/sequence/base.py @@ -0,0 +1,131 @@ +from torch import nn +import functools + +class SequenceModule(nn.Module): + """Abstract sequence model class. All models must adhere to this interface + + A SequenceModule is generally a model that transforms an input of shape + (n_batch, l_sequence, d_model) to (n_batch, l_sequence, d_output) + + REQUIRED methods and attributes + forward, d_model, d_output: controls standard forward pass, a sequence-to-sequence transformation + __init__ should also satisfy the following interface; see SequenceIdentity for an example + def __init__(self, d_model, transposed=False, **kwargs) + + OPTIONAL methods + default_state, step: allows stepping the model recurrently with a hidden state + state_to_tensor, d_state: allows decoding from hidden state + """ + + @property + def d_model(self): + """Model dimension (generally same as input dimension). + + This attribute is required for all SequenceModule instantiations. + It is used by the rest of the pipeline (e.g. model backbone, encoder) to track the internal shapes of the full model. + """ + if getattr(self, "_d_model", None) is None: + raise NotImplementedError("SequenceModule instantiation must set d_model") + return self._d_model + + @d_model.setter + def d_model(self, d): + self._d_model = d + + @property + def d_output(self): + """Output dimension of model. + + This attribute is required for all SequenceModule instantiations. + It is used by the rest of the pipeline (e.g. model backbone, decoder) to track the internal shapes of the full model. + """ + if getattr(self, "_d_output", None) is None: + raise NotImplementedError("SequenceModule instantiation must specify d_output for decoder") + return self._d_output + + @d_output.setter + def d_output(self, d): + self._d_output = d + + def forward(self, x, state=None, **kwargs): + """Forward pass of sequence model, a sequence-to-sequence transformation with an optional state. 
+ + Generally, this should map a tensor of shape (batch, length, self.d_model) to (batch, length, self.d_output) + + Additionally, it returns a "state" which can be any additional information + For example, RNN and SSM layers may return their hidden state, + while some types of transformer layers (e.g. Transformer-XL) may want to pass a state as well + """ + return x, None + + @property + def state_to_tensor(self): + """Returns a function mapping a state to a single tensor. + + This method should be implemented if one wants to use the hidden state instead of the output sequence for final prediction. + Currently only used with the StateDecoder. + """ + return lambda _: None + + @property + def d_state(self): + """ Returns dimension of output of self.state_to_tensor """ + return None + + + def default_state(self, *batch_shape, device=None): + """Create initial state for a batch of inputs.""" + + return None + + def step(self, x, state=None, **kwargs): + """Step the model recurrently for one step of the input sequence. + + For example, this should correspond to unrolling an RNN for one step. + If the forward pass has signature (B, L, H1) -> (B, L, H2), + this method should generally have signature (B, H1) -> (B, H2) with an optional recurrent state. + """ + raise NotImplementedError + +def TransposedModule(module): + """Wrap a SequenceModule class to accept transposed parameter, handle state, absorb kwargs""" + # https://stackoverflow.com/a/65470430/1980685 + @functools.wraps(module, updated=()) + class TransposedModule(module): + def __init__(self, *args, transposed=False, **kwargs): + super().__init__(*args, **kwargs) + self.transposed = transposed + + def forward(self, x, state=None, **kwargs): + if self.transposed: x = x.transpose(-1, -2) + x, next_state = super().forward(x, state) # Don't use kwarg because nn.LSTM + next_state = None if state is None else next_state + if self.transposed: x = x.transpose(-1,-2) + return x, next_state + # https://stackoverflow.com/questions/5352781/how-to-set-class-names-dynamically + # TransposedModule.__name__ = module.__name__ # functools wraps is better solution + return TransposedModule + +@TransposedModule +class SequenceIdentity(SequenceModule): + """Simple SequenceModule for testing purposes""" + + def __init__(self, d_model, dropout=0.0, **kwargs): + """Default interface for SequenceModule + + d_model: input dimension (sometimes denoted H for hidden dimension) + transposed: if True, inputs have axis ordering (B, H, L) instead of (B, H, L) + """ + super().__init__() + self.d_model = d_model + self.d_output = d_model + + + def forward(self, x, state=None): + return x, state + + def default_state(self, *batch_shape, device=None): + return None + + def step(self, x, state=None, **kwargs): + return x, state diff --git a/src/clm/src/models/sequence/block.py b/src/clm/src/models/sequence/block.py new file mode 100644 index 00000000..f44ee109 --- /dev/null +++ b/src/clm/src/models/sequence/block.py @@ -0,0 +1,129 @@ +""" Implements a full residual block around a black box layer + +Configurable options include: +normalization position: prenorm or postnorm +normalization type: batchnorm, layernorm etc. +subsampling/pooling +residual options: feedforward, residual, affine scalars, depth-dependent scaling, etc. 
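+
+An illustrative configuration sketch (hypothetical values; assumes layer/residual configs
+are resolved by name through the registries, as in the upstream S4 codebase):
+
+    block = SequenceResidualBlock(
+        d_input=256, i_layer=1, prenorm=True, dropout=0.1,
+        layer={'_name_': 's4'},   # config for the black-box inner layer
+        residual='R',             # plain residual connection
+        norm='layer',             # LayerNorm; applied pre- or post- according to `prenorm`
+    )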
+""" + +from torch import nn + +from functools import partial +import clm.src.utils as utils +from clm.src.models.nn.components import Normalization, StochasticDepth, DropoutNd +from clm.src.models.sequence import SequenceModule +from clm.src.models.sequence.pool import registry as pool_registry +from clm.src.models.nn.residual import registry as residual_registry +import clm.src.utils.registry as registry + + +class SequenceResidualBlock(SequenceModule): + def __init__( + self, + d_input, + i_layer=None, # Only needs to be passed into certain residuals like Decay + prenorm=True, + dropout=0.0, + tie_dropout=False, + transposed=False, + layer=None, # Config for black box module + residual=None, # Config for residual function + norm=None, # Config for normalization layer + pool=None, + drop_path=0., + ): + super().__init__() + + self.i_layer = i_layer + self.d_input = d_input + self.layer = utils.instantiate(registry.layer, layer, d_input) + self.prenorm = prenorm + self.transposed = transposed + + # Residual + # d_residual is the output dimension after residual + if residual is None: + self.residual = None + self.d_residual = self.layer.d_output + else: + self.residual = utils.instantiate(residual_registry, residual, i_layer, d_input, self.layer.d_output) + self.d_residual = self.residual.d_output + + # Normalization + d_norm = d_input if self.prenorm else self.d_residual + # We don't use config to directly instantiate since Normalization has some special cases + if norm is None: + self.norm = None + elif isinstance(norm, str): + self.norm = Normalization(d_norm, transposed=self.transposed, _name_=norm) + else: + self.norm = Normalization(d_norm, transposed=self.transposed, **norm) + + # Pool + self.pool = utils.instantiate(pool_registry, pool, self.d_residual, transposed=self.transposed) + + # Dropout + dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout + self.drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() + + # Stochastic depth + self.drop_path = StochasticDepth(drop_path, mode='row') if drop_path > 0.0 else nn.Identity() + + + @property + def d_output(self): + return self.pool.d_output if self.pool is not None else self.d_residual + + @property + def d_state(self): + return self.layer.d_state + + @property + def state_to_tensor(self): + return self.layer.state_to_tensor + + def default_state(self, *args, **kwargs): + return self.layer.default_state(*args, **kwargs) + + def forward(self, x, state=None, **kwargs): + y = x + + # Pre-norm + if self.norm is not None and self.prenorm: y = self.norm(y) + + # Black box layer + y, state = self.layer(y, state=state, **kwargs) + + # Residual + if self.residual is not None: y = self.residual(x, self.drop_path(self.drop(y)), self.transposed) + + # Post-norm + if self.norm is not None and not self.prenorm: y = self.norm(y) + + # Pool + if self.pool is not None: y, _ = self.pool(y) + + return y, state + + def step(self, x, state, **kwargs): + y = x + + # Pre-norm + if self.norm is not None and self.prenorm: + y = self.norm.step(y) + + # Black box layer + y, state = self.layer.step(y, state, **kwargs) + + # Residual + if self.residual is not None: y = self.residual(x, y, transposed=False) # NOTE this would not work with concat residual function (catformer) + + # Post-norm + if self.norm is not None and not self.prenorm: + y = self.norm.step(y) + + # Pool + if self.pool is not None: y, _ = self.pool(y) + + return y, state diff --git a/src/clm/src/models/sequence/block_fft.py 
b/src/clm/src/models/sequence/block_fft.py new file mode 100644 index 00000000..c0a1c568 --- /dev/null +++ b/src/clm/src/models/sequence/block_fft.py @@ -0,0 +1,177 @@ +'''PyTorch version of the block FFT convolution as described in the H3 paper.''' + +import torch +from einops import rearrange +import math +from torch import nn +from clm.src.models.nn import Activation +from clm.src.utils.train import OptimModule + +def ref_dft_matrix(N, H=1): + """Compute the DFT matrix of size N x N. + + This is where we could add extra compute for free.""" + # n = torch.arange(N) + n = torch.arange(N).cuda() + k = n.view(-1, 1) + M = torch.exp(-2j * torch.pi * n * k / N) + return torch.view_as_real(M.repeat(H, 1, 1)) + +def compute_twiddle_factors(n, m): + """Compute the twiddle factors of size n x m""" + # n_a = torch.arange(n).view(-1, 1) + # m_a = torch.arange(m) + n_a = torch.arange(n).cuda().view(-1, 1) + m_a = torch.arange(m).cuda() + N = n * m + M = torch.exp(-2j * torch.pi * n_a * m_a / N) + return torch.view_as_real(M) + +def _cooley_tukey( + k, n, m, + dft_matrix=ref_dft_matrix, + max_m=16, + activation=None, +): + ''' + Compute the FFT using the general Cooley-Tukey algorithm: + * Reshape to (m, n) + * Do n m-length FFTs along the rows + * Transpose to (n, m), multiply by twiddle factors + * Do m n-length FFTs along the rows + + This function assumes that m <= 16 and recurses on n. + The base case is n <= 16 (we are simulating tensor cores of 16x16 mm). + The dft_matrix function is overwriteable + so that we can replace it with learnable parameters in a model. + ''' + assert m <= max_m + + if activation is not None: + act_fn = Activation(activation) + + k = rearrange(k, '... (m n) -> ... m n', m=m, n=n) # (m, n) + + # do n m-length FFTs + if activation is None: + mat = torch.view_as_complex(dft_matrix(m)) + k_f = torch.einsum('... m o, ... o n -> ... m n', mat, k) # (..., m, n) + else: + mat = torch.view_as_complex(dft_matrix(m)) + k_f = torch.view_as_complex(act_fn( + torch.view_as_real(torch.einsum('... m o, ... o n -> ... m n', mat, k)) + )) # (..., m, n) + + # multiply by twiddle factors + twi = torch.view_as_complex(compute_twiddle_factors(n, m)) # (n, m) + k_f = torch.einsum('n m, ... m n -> ... n m', twi, k_f) # (..., n, m) + + if n <= max_m: + # do m n-length FFTs + if activation is None: + mat = torch.view_as_complex(dft_matrix(n)) + k_f = torch.einsum('... n o, ... o m -> ... n m', mat, k_f) # (.., n, m) + else: + mat = torch.view_as_complex(dft_matrix(n)) + k_f = torch.view_as_complex(act_fn( + torch.view_as_real(torch.einsum('... n o, ... o m -> ... n m', mat, k_f)) + )) # (.., n, m) + else: + # recurse + k_f = rearrange(k_f, '... h n m -> ... m h n') + k_f = _cooley_tukey(k_f, n // max_m, max_m, dft_matrix, max_m, activation) + k_f = rearrange(k_f, '... m h n -> ... h n m') + + # reshape for the output + k_f = rearrange(k_f, '... n m -> ... (n m)') # (..., n*m) + + return k_f + +def block_fft( + k, N, + dft_matrix=ref_dft_matrix, + max_m=16, + **kwargs, +): + ''' + Compute the FFT of size N of the vector k, using _block_fft_recurse. + + The dft_matrix function is overwriteable + so that we can replace it with learnable parameters in a model. + ''' + if not math.log(N, 2).is_integer(): + N = int(2 ** math.ceil(math.log(N, 2))) + # pad k with zeros if necessary (e.g. for causality) + if k.shape[-1] != N: + k = nn.ConstantPad1d((0, N - k.shape[-1]), 0)(k) + + if N <= max_m: + mat = torch.view_as_complex(dft_matrix(m)) + return torch.einsum('... n o, ... o -> ... 
n', mat, k) # (.., n, m) + n = N // max_m + m = max_m + return _cooley_tukey(k, n, m, dft_matrix, max_m, **kwargs) + +class BlockFFT(OptimModule): + ''' + Learnable Block FFT module. + + Args: + learn_dft_matrix (bool): If True, learn a different DFT matrix for lengths 2, 4, 8, and 16. If False, this module computes a normal FFT. + ''' + def __init__(self, learn_dft_matrices=True, H=1, max_m=16, dft_lr=0.001, dropout=0, learn_additive=False, **block_fft_args): + super().__init__() + self.learn_dft_matrices = learn_dft_matrices + self.block_fft_args = block_fft_args + self.max_m=max_m + self.drop = torch.nn.Dropout(p=dropout) + self.learn_additive=learn_additive + # get the powers of 2 up to max_m + assert math.log(max_m, 2).is_integer(), 'max_m must be a power of 2' + + self.powers = [ 2 ** (i + 1) for i in range(int(math.log(max_m, 2))) ] + + if learn_dft_matrices: + assert dft_lr>0,"If learn_dft_matrices=True dft_lr must be positive" + self.dft_matrices = nn.ParameterList() + for n in self.powers: + setattr(self,f"mat_{n}",nn.Parameter( + 0.01 * torch.randn(H, n, n, 2) if self.learn_additive + else ref_dft_matrix(n, H=H), + requires_grad=True)) + self.register(f"mat_{n}",getattr(self,f"mat_{n}"),dft_lr) + self.dft_matrices.append(getattr(self,"mat_{}".format(n))) + + def compute_dft_matrix(self, n): + if not self.learn_dft_matrices: + return ref_dft_matrix(n) + else: + assert n in self.powers + if self.learn_additive: + mat = ref_dft_matrix(n) + return mat + self.drop(self.dft_matrices[int(math.log(n, 2) - 1)]) + else: + return self.drop(self.dft_matrices[int(math.log(n, 2) - 1)]) + + def forward(self, x, N,forward=True): + '''Compute an FFT (forward=True) or iFFT (forward=False) of length N over x.''' + if forward: + return block_fft(x, N, dft_matrix=self.compute_dft_matrix, **self.block_fft_args) + else: + return (1/(N))*torch.conj(block_fft(torch.conj(x), N, dft_matrix=self.compute_dft_matrix, **self.block_fft_args)) + + +if __name__ == "__main__": + B = 128 + H = 29 + N = 8192 + n = 2 + m = 8 + k = torch.randn(B, H, N).to(torch.complex64) + + print(f'(B, H, N) = ({B}, {H}, {N})') + + # test FFT + k_f = block_fft(k, N) + k_f_ref = torch.fft.fft(k, N) + print('L-inf error in FFT: ', torch.max(torch.abs(k_f - k_f_ref)).item()) \ No newline at end of file diff --git a/src/clm/src/models/sequence/ff.py b/src/clm/src/models/sequence/ff.py new file mode 100644 index 00000000..804408dd --- /dev/null +++ b/src/clm/src/models/sequence/ff.py @@ -0,0 +1,50 @@ +""" Implementation of FFN block in the style of Transformers """ + +from functools import partial +from torch import nn +from clm.src.models.sequence.base import SequenceModule +from clm.src.models.nn import LinearActivation, DropoutNd + +class FF(SequenceModule): + def __init__(self, d_input, expand=2, d_output=None, transposed=False, activation='gelu', initializer=None, dropout=0.0, tie_dropout=False): + super().__init__() + self.d_output = d_input if d_output is None else d_output + self.transposed = transposed + d_inner = expand * d_input + + linear1 = LinearActivation( + d_input, d_inner, + transposed=transposed, + activation=activation, + initializer=initializer, + activate=True, + ) + dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout + # dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout + drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() + + linear2 = LinearActivation( + d_inner, self.d_output, + transposed=transposed, + activation=None, + initializer=initializer, + 
activate=False, + ) + + self.ff = nn.Sequential( + linear1, + drop, + linear2, + ) + + def forward(self, x, *args, **kwargs): + return self.ff(x), None + + def step(self, x, state, **kwargs): + # x: [batch, d_input] + if self.transposed: + # expects: [batch, d_input, seq_len] + return self.ff(x.unsqueeze(-1)).squeeze(-1), state + else: + return self.ff(x), state + diff --git a/src/clm/src/models/sequence/h3.py b/src/clm/src/models/sequence/h3.py new file mode 100644 index 00000000..07dc4c89 --- /dev/null +++ b/src/clm/src/models/sequence/h3.py @@ -0,0 +1,206 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +from clm.src.models.sequence.ssm.ss_kernel import SSKernel + +try: + from clm.src.ops.fftconv import fftconv_func +except ImportError: + fftconv_func = None + + +@torch.jit.script +def mul_sum(q, y): + return (q * y).sum(dim=1) + + +class H3(nn.Module): + + def __init__( + self, + d_model, + d_state=64, + l_max=None, + head_dim=1, + use_fast_fftconv=False, + dropout=0.0, # Just to absorb the kwarg + layer_idx=None, + device=None, dtype=None, + # SSM Kernel arguments + **kernel_args, + ): + """ + d_state: the dimension of the state, also denoted by N + l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel + + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.d_model = d_model + self.head_dim = head_dim + assert d_model % head_dim == 0 + self.H = d_model // head_dim + self.N = d_state + self.L = l_max + self.layer_idx = layer_idx + self.use_fast_fftconv = use_fast_fftconv + if self.use_fast_fftconv: + assert fftconv_func is not None, 'Need to install fftconv' + + self.q_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) + self.k_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) + self.v_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) + + # TODO: SSKernel doesn't take device argument yet + self.ssm_k_kernel = SSKernel(self.d_model, N=d_state, L=self.L, mode='shift', + lr=kernel_args.get('lr', None)) + self.ssm_k_D = nn.Parameter(torch.randn(self.d_model)) + # S4D Kernel + self.kernel = SSKernel(self.H, N=self.N, L=self.L, channels=1, **kernel_args) + self.D = nn.Parameter(torch.randn(self.H, **factory_kwargs)) + + # Pointwise + # position-wise output transform to mix features + # Don't use FusedDense since the layout is H first + self.output_linear = nn.Linear(self.d_model, self.d_model) + + def forward(self, u, inference_params=None): + """ + u: (B L H) + + Returns: same shape as u + """ + L_og = u.size(-2) + if self.use_fast_fftconv and L_og % 2 != 0: + u = F.pad(u, (0, 0, 0, 1)) + L = u.size(-2) + + use_fast_fftconv = self.use_fast_fftconv and inference_params is None + + state_k, state = None, None + if inference_params is not None: + assert self.layer_idx is not None + if self.layer_idx not in inference_params.key_value_memory_dict: + batch_shape = (u.shape[0] * self.head_dim * self.head_dim,) + state_k = self.ssm_k_kernel.default_state(*batch_shape) + state = self.kernel.default_state(*batch_shape) + inference_params.key_value_memory_dict[self.layer_idx] = (state_k, state) + else: + state_k, state = 
inference_params.key_value_memory_dict[self.layer_idx] + if inference_params.sequence_len_offset == 0: + self.ssm_k_kernel._setup_step() + self.kernel._setup_step() + + if inference_params is not None and inference_params.sequence_len_offset > 0: + y, next_state_k, next_state = self.step(u, state_k, state) + inference_params.key_value_memory_dict[self.layer_idx][0].copy_(next_state_k) + inference_params.key_value_memory_dict[self.layer_idx][1].copy_(next_state) + return y + + # Compute SS Kernel + L_kernel = L if self.L is None else min(L, self.L ) + ssm_kernel, k_state = self.kernel(L=L_kernel, state=state, rate=1.0) # (C H L) (B C H L) + ssm_kernel = rearrange(ssm_kernel, '1 h l -> h l') + + u = rearrange(u, 'b l h -> (b l) h') + dtype = (self.q_proj.weight.dtype if not torch.is_autocast_enabled() + else torch.get_autocast_gpu_dtype()) + q = self.q_proj.weight @ u.T + self.q_proj.bias.to(dtype).unsqueeze(-1) + k = self.k_proj.weight @ u.T + self.k_proj.bias.to(dtype).unsqueeze(-1) + v = self.v_proj.weight @ u.T + self.v_proj.bias.to(dtype).unsqueeze(-1) + q, k, v = [rearrange(x, 'h (b l) -> b h l', l=L) for x in [q, k, v]] + + k_og = k + ssm_k_kernel, _ = self.ssm_k_kernel(L=L_kernel, state=state_k, rate=1.0) # (C H L) (B C H L) + ssm_k_kernel = rearrange(ssm_k_kernel, '1 h l -> h l') + if not use_fast_fftconv: + fft_size = L_kernel + L + ssm_k_kernel_f = torch.fft.rfft(ssm_k_kernel, n=fft_size) # (H 2L) + k_f = torch.fft.rfft(k.to(ssm_kernel.dtype), n=fft_size) # (B H 2L) + shift_k_out = torch.fft.irfft(ssm_k_kernel_f * k_f, n=fft_size)[..., :L] + k = shift_k_out + rearrange(self.ssm_k_D, 'h -> h 1') * k + else: + dropout_mask = None + # No GeLU after the SSM + # We want output_hbl=True so that k has the same layout as q and v for the next + # fftconv + k = fftconv_func(k, ssm_k_kernel, self.ssm_k_D, dropout_mask, False, False, True) + # This line below looks like it doesn't do anything, but it gets the stride right + # for the case batch_size=1. In that case k has stride (L, L, 1), but q and v has + # stride (H * L, L, 1). The two strides are equivalent because batch_size=1, but + # the C++ code doesn't like that. + k = rearrange(rearrange(k, 'b h l -> h b l'), 'h b l -> b h l') + + if not use_fast_fftconv: + fft_size = L_kernel + L + # kv = k * v + kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim) + * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=self.head_dim)) # b d1 d2 h l + kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size + ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 + y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :L] # b d1 d2 h l + y = y + kv * self.D.unsqueeze(-1) # b d1 d2 h l + q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim) + # einsum is way slower than multiply and then sum. + if self.head_dim > 1: + y = mul_sum(y, q) + y = rearrange(y, 'b d h l -> b (d h) l') + else: + y = rearrange(y * q, 'b 1 1 h l -> b h l') + else: + dropout_mask = None + # No GeLU after the SSM + # Set output_hbl_layout=True since we'll be doing a matmul right after + y = fftconv_func(k, ssm_kernel, self.D, + dropout_mask, False, torch.is_autocast_enabled(), True, + v, self.head_dim, q) + + y = rearrange(y, 'b h l -> b l h') + + if state is not None: + assert inference_params is not None + # TODO: This doesn't ever happen? 
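+            # (If it did, the disabled code below would add k_state, the term returned
+            #  earlier by self.kernel alongside the SSM kernel, back into the output y.)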
+ # if inference_params.sequence_len_offset > 0: + # y = y + k_state + inference_params.key_value_memory_dict[self.layer_idx][0].copy_( + self.ssm_k_kernel.forward_state(k_og, state_k) + ) + inference_params.key_value_memory_dict[self.layer_idx][1].copy_( + self.kernel.forward_state(rearrange(kv, 'b d1 d2 h l -> (b d1 d2) h l'), state) + ) + + # y could be in fp32 because of the SSMs + if not torch.is_autocast_enabled(): + y = y.to(dtype=self.output_linear.weight.dtype) + y = self.output_linear(y) + if L_og < L: + y = y[:, :L_og, :] + + return y + + def step(self, u, state_k, state): + q, k, v = self.q_proj(u), self.k_proj(u), self.v_proj(u) + shift_k, next_state_k = self.ssm_k_kernel.step(rearrange(k, 'b 1 h -> b h'), state_k) + k = shift_k + k * self.ssm_k_D + # kv = k * v + kv = (rearrange(k, 'b 1 (h d1) -> b d1 1 h', d1=self.head_dim) + * rearrange(v, 'b 1 (h d2) -> b 1 d2 h', d2=self.head_dim)) # b d1 d2 h + y, next_state = self.kernel.step(rearrange(kv, 'b d1 d2 h -> (b d1 d2) h'), state) + y = (rearrange(y, '(b d1 d2) 1 h -> b d1 d2 h', d1=self.head_dim, d2=self.head_dim) + + kv * self.D) + q = rearrange(q, 'b 1 (h d1) -> b d1 1 h', d1=self.head_dim) + if self.head_dim > 1: + y = mul_sum(y, q) + y = rearrange(y, 'b d h l -> b (d h) l') + else: + y = rearrange(y * q, 'b 1 1 h -> b 1 h') + # y could be in fp32 because of the SSMs + if not torch.is_autocast_enabled(): + y = y.to(dtype=self.output_linear.weight.dtype) + return self.output_linear(y), next_state_k, next_state diff --git a/src/clm/src/models/sequence/h3_conv.py b/src/clm/src/models/sequence/h3_conv.py new file mode 100644 index 00000000..f2a7f7c0 --- /dev/null +++ b/src/clm/src/models/sequence/h3_conv.py @@ -0,0 +1,150 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +from clm.src.models.sequence.long_conv_kernel import LongConvKernel + +try: + from clm.src.ops.fftconv import fftconv_func +except ImportError: + fftconv_func = None + + +@torch.jit.script +def mul_sum(q, y): + return (q * y).sum(dim=1) + + +class H3Conv(nn.Module): + + def __init__( + self, + d_model, + l_max=None, + head_dim=1, + use_fast_fftconv=False, + dropout=0.0, # Just to absorb the kwarg + layer_idx=None, + device=None, dtype=None, + # SSM Kernel arguments + **kernel_args, + ): + """ + d_state: the dimension of the state, also denoted by N + l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel + + See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. 
Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" + + Other options are all experimental and should not need to be configured + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.d_model = d_model + self.head_dim = head_dim + assert d_model % head_dim == 0 + self.H = d_model // head_dim + self.L = l_max + self.layer_idx = layer_idx + self.use_fast_fftconv = use_fast_fftconv + if self.use_fast_fftconv: + assert fftconv_func is not None, 'Need to install fftconv' + + self.q_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) + self.k_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) + self.v_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) + self.k_kernel = LongConvKernel( + self.d_model, L=self.L, + **kernel_args) + self.k_D = nn.Parameter(torch.randn(self.d_model)) + self.kernel = LongConvKernel( + self.d_model, L=self.L, + **kernel_args) + self.D = nn.Parameter(torch.randn(self.H, **factory_kwargs)) + + # Pointwise + # position-wise output transform to mix features + # Don't use FusedDense since the layout is H first + self.output_linear = nn.Linear(self.d_model, self.d_model) + + def forward(self, u, inference_params=None): + """ + u: (B L H) + + Returns: same shape as u + """ + L_og = u.size(-2) + if self.use_fast_fftconv and L_og % 2 != 0: + u = F.pad(u, (0, 0, 0, 1)) + L = u.size(-2) + + use_fast_fftconv = self.use_fast_fftconv + + # Compute SS Kernel + ssm_kernel, _ = self.kernel() # (C H L) (B C H L) + ssm_kernel = rearrange(ssm_kernel, '1 h l -> h l') + + u = rearrange(u, 'b l h -> (b l) h') + dtype = (self.q_proj.weight.dtype if not torch.is_autocast_enabled() + else torch.get_autocast_gpu_dtype()) + q = self.q_proj.weight @ u.T + self.q_proj.bias.to(dtype).unsqueeze(-1) + k = self.k_proj.weight @ u.T + self.k_proj.bias.to(dtype).unsqueeze(-1) + v = self.v_proj.weight @ u.T + self.v_proj.bias.to(dtype).unsqueeze(-1) + q, k, v = [rearrange(x, 'h (b l) -> b h l', l=L) for x in [q, k, v]] + + k_og = k + k_kernel, _ = self.k_kernel() # (C H L) (B C H L) + k_kernel = rearrange(k_kernel, '1 h l -> h l') + if not use_fast_fftconv: + fft_size = 2 * L + k_kernel_f = torch.fft.rfft(k_kernel, n=fft_size) # (H 2L) + k_f = torch.fft.rfft(k.to(ssm_kernel.dtype), n=fft_size) # (B H 2L) + shift_k_out = torch.fft.irfft(k_kernel_f * k_f, n=fft_size)[..., :L] + k = shift_k_out + rearrange(self.k_D, 'h -> h 1') * k + else: + dropout_mask = None + # No GeLU after the SSM + # We want output_hbl=True so that k has the same layout as q and v for the next + # fftconv + k = fftconv_func(k, k_kernel, self.k_D, dropout_mask, False, False, True) + # This line below looks like it doesn't do anything, but it gets the stride right + # for the case batch_size=1. In that case k has stride (L, L, 1), but q and v has + # stride (H * L, L, 1). The two strides are equivalent because batch_size=1, but + # the C++ code doesn't like that. 
+ k = rearrange(rearrange(k, 'b h l -> h b l'), 'h b l -> b h l') + + if not use_fast_fftconv: + fft_size = 2 * L + # kv = k * v + kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim) + * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=self.head_dim)) # b d1 d2 h l + kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size + ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 + y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :L] # b d1 d2 h l + y = y + kv * self.D.unsqueeze(-1) # b d1 d2 h l + q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim) + # einsum is way slower than multiply and then sum. + if self.head_dim > 1: + y = mul_sum(y, q) + y = rearrange(y, 'b d h l -> b (d h) l') + else: + y = rearrange(y * q, 'b 1 1 h l -> b h l') + else: + dropout_mask = None + # No GeLU after the SSM + # Set output_hbl_layout=True since we'll be doing a matmul right after + y = fftconv_func(k, ssm_kernel, self.D, + dropout_mask, False, torch.is_autocast_enabled(), True, + v, self.head_dim, q) + + y = rearrange(y, 'b h l -> b l h') + + # y could be in fp32 because of the SSMs + if not torch.is_autocast_enabled(): + y = y.to(dtype=self.output_linear.weight.dtype) + y = self.output_linear(y) + if L_og < L: + y = y[:, :L_og, :] + + return y diff --git a/src/clm/src/models/sequence/hyena.py b/src/clm/src/models/sequence/hyena.py new file mode 100644 index 00000000..3089c549 --- /dev/null +++ b/src/clm/src/models/sequence/hyena.py @@ -0,0 +1,359 @@ +import math + +from re import U +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import partial + +from einops import rearrange, repeat + +try: + from clm.src.ops.fftconv import fftconv_ref, fftconv_func +except ImportError: + fftconv_func = None + +try: + from flash_attn.ops.fused_dense import FusedDense +except ImportError: + FusedDense = None + +import clm.src.utils.registry as registry +from clm.src.utils.train import OptimModule +from clm.src.utils.config import instantiate, auto_assign_attrs +from clm.src.models.nn import Activation + + +# reference convolution with residual connection +def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): + seqlen = u.shape[-1] + fft_size = 2 * seqlen + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + if k_rev is not None: + k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size + k_f = k_f + k_rev_f.conj() + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + + if len(u.shape) > 3: k_f = k_f.unsqueeze(1) + + y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] + + out = y + u * D.unsqueeze(-1) + if gelu: + out = F.gelu(out) + if dropout_mask is not None: + return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) + else: + return out.to(dtype=u.dtype) + + +@torch.jit.script +def mul_sum(q, y): + return (q * y).sum(dim=1) + + +class Sin(nn.Module): + def __init__(self, dim, w=10, train_freq=True): + super().__init__() + self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim) + + def forward(self, x): + return torch.sin(self.freq * x) + + +class PositionalEmbedding(OptimModule): + def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float=1e-5, **kwargs): + """Complex exponential positional embeddings for Hyena filters.""" + super().__init__() + + self.seq_len = seq_len + # The time embedding fed to the filteres is normalized so that t_f = 1 + t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 + + if emb_dim > 1: 
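+            # With emb_dim odd, the construction below concatenates the linear time channel t
+            # with (emb_dim - 1) // 2 complex-exponential bands split into real and imaginary
+            # parts, giving emb_dim channels in total (e.g. emb_dim=3 -> t plus the cosine and
+            # sine of a single band).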
+ bands = (emb_dim - 1) // 2 + # To compute the right embeddings we use the "proper" linspace + t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] + w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 + + f = torch.linspace(1e-4, bands - 1, bands)[None, None] + z = torch.exp(-1j * f * w) + z = torch.cat([t, z.real, z.imag], dim=-1) + self.register("z", z, lr=lr_pos_emb) + self.register("t", t, lr=0.0) + + def forward(self, L): + return self.z[:, :L], self.t[:, :L] + + +class ExponentialModulation(OptimModule): + def __init__( + self, + d_model, + fast_decay_pct=0.3, + slow_decay_pct=1.5, + target=1e-2, + modulation_lr=0.0, + modulate: bool=True, + shift: float = 0.0, + **kwargs + ): + super().__init__() + self.modulate = modulate + self.shift = shift + max_decay = math.log(target) / fast_decay_pct + min_decay = math.log(target) / slow_decay_pct + deltas = torch.linspace(min_decay, max_decay, d_model)[None, None] + self.register("deltas", deltas, lr=modulation_lr) + + def forward(self, t, x): + if self.modulate: + decay = torch.exp(-t * self.deltas.abs()) + x = x * (decay + self.shift) + return x + + +class HyenaFilter(OptimModule): + def __init__( + self, + d_model, + emb_dim=3, # dim of input to MLP, augments with positional encoding + order=16, # width of the implicit MLP + fused_fft_conv=False, + seq_len=1024, + lr=1e-3, + lr_pos_emb=1e-5, + dropout=0.0, + w=1, # frequency of periodic activations + wd=0, # weight decay of kernel parameters + bias=True, + num_inner_mlps=2, + normalized=False, + **kwargs + ): + """ + Implicit long filter with modulation. + + Args: + d_model: number of channels in the input + emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands + order: width of the FFN + num_inner_mlps: number of inner linear layers inside filter MLP + + Note: + filter_dropout is not implemented + """ + super().__init__() + self.d_model = d_model + self.use_bias = bias + self.fused_fft_conv = fused_fft_conv + self.bias = nn.Parameter(torch.randn(self.d_model)) + self.dropout = nn.Dropout(dropout) + + act = Sin(dim=order, w=w) + self.emb_dim = emb_dim + assert emb_dim % 2 != 0 and emb_dim >= 3, "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)" + self.seq_len = seq_len + + self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb) + + # uses a variable number of inner linear layers + self.implicit_filter = nn.Sequential( + nn.Linear(emb_dim, order), + act, + ) + for i in range(num_inner_mlps): + self.implicit_filter.append(nn.Linear(order, order)) + self.implicit_filter.append(act) + # final linear layer + self.implicit_filter.append(nn.Linear(order, d_model, bias=False)) + + self.modulation = ExponentialModulation(d_model, **kwargs) + + self.normalized = normalized + for c in self.implicit_filter.children(): + for name, v in c.state_dict().items(): + optim = {"weight_decay": wd, "lr": lr} + setattr(getattr(c, name), "_optim", optim) + + def filter(self, L, *args, **kwargs): + z, t = self.pos_emb(L) + h = self.implicit_filter(z) + + h = self.modulation(t, h) + + if self.normalized: h = h / torch.norm(h, dim=-1, p=1, keepdim=True) + + return h + + def forward(self, x, L, k=None, bias=None, *args, **kwargs): + if k is None: k = self.filter(L) + + # Ensure compatibility with filters that return a tuple + k = k[0] if type(k) is tuple else k + if bias is None: bias = self.bias + bias = bias if self.use_bias else 0 * bias + + if self.fused_fft_conv: + bias = bias.to(dtype=torch.float32) + y = fftconv_func( + x, k, 
bias, dropout_mask=None, gelu=False, + force_fp16_output=torch.is_autocast_enabled() + ) + else: + y = fftconv_ref(x, k, bias, dropout_mask=None, gelu=False) + + return y + + +class HyenaOperator(nn.Module): + def __init__( + self, + d_model, + l_max, + order=2, + filter_order=64, + num_heads=1, + inner_factor=1, + num_blocks=1, + fused_bias_fc=False, + outer_mixing=False, + dropout=0.0, + filter_dropout=0.0, + filter_cls='hyena-filter', + post_order_ffn=False, + jit_filter=False, + short_filter_order=3, + activation="id", + return_state=False, + **filter_args, + ): + r""" + Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf + + Args: + d_model (int): Dimension of the input and output embeddings (width of the layer) + l_max: (int): Maximum input sequence length. Defaults to None + order: (int): Depth of the Hyena recurrence. Defaults to 2 + filter_order: (int): Width of the FFN parametrizing the implicit filter. Defaults to 64 + num_heads: (int): Number of heads. Defaults to 1 + inner_factor: (int): Width multiplier. Defaults to 1 + num_blocks: (int): Number of blocks in sequence length. Defaults to 1 + fused_bias_fc: (bool): Whether to use fused bias FC. Defaults to False + dropout: (float): Dropout probability. Defaults to 0.0 + filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0 + post_order_ffn: (bool): Apply a dense layer between steps of the recurrence. Defaults to False + jit_filter: (bool): Whether JIT the implicit filter function. Defaults to False + short_filter_order: (int): Length of the explicit input convolutional filter. Defaults to 3 + activation: (str): type of act between kernel output and FF (default identity) + return_state: (bool): whether to return a state + """ + super().__init__() + assert d_model % num_heads == 0, f'Model dimension {d_model} must be divisible by num heads {num_heads}' + assert l_max % num_blocks == 0, f'Maximum signal length {l_max} must be divisible by block dimension {num_blocks}' + block_dim = l_max // num_blocks + head_dim = d_model // num_heads + + auto_assign_attrs( + self, d_model=d_model, order=order, l_max=l_max, num_heads=num_heads, inner_factor=inner_factor, + block_dim=block_dim, head_dim=head_dim, filter_order=filter_order, post_order_ffn=post_order_ffn, + short_filter_order=short_filter_order, num_blocks = num_blocks, filter_dropout=filter_dropout, + jit_filter=jit_filter, outer_mixing=outer_mixing, activation=activation, return_state=return_state, + ) + self.activation = Activation(activation) + self.dropout = nn.Dropout(dropout) + self.setup_projections(fused_bias_fc, inner_factor) + self.setup_filters(filter_cls, filter_args) + + + def setup_projections(self, fused_bias_fc, inner_factor): + "Initializes input and output projections (over the width dimension)" + if fused_bias_fc and FusedDense is None: + raise ImportError('fused_dense is not installed') + linear_cls = nn.Linear if not fused_bias_fc else FusedDense + self.out_proj = linear_cls(self.d_model * inner_factor, self.d_model) + self.in_proj = linear_cls(self.d_model, (self.order + 1) * self.d_model) + if self.post_order_ffn: + self.ord_proj_w = nn.Parameter(torch.randn(self.order, self.num_heads, self.num_heads) / math.sqrt(self.head_dim)) + + + def setup_filters(self, filter_cls, filter_args): + "Initializes the explicit and implicit filters" + assert self.order >= 2, f'Order must be at least 2, (got {self.order})' + total_width = self.d_model * self.inner_factor * (self.order + 1) + + self.short_filter = nn.Conv1d( + 
in_channels=total_width, + out_channels=total_width, + kernel_size=self.short_filter_order, + groups=total_width, + padding=self.short_filter_order - 1 + ) + + filter_cls = instantiate(registry.layer, filter_cls, partial=True) + + self.filter_fn = filter_cls( + self.head_dim * self.inner_factor * (self.order - 1), + order=self.filter_order, + seq_len=self.l_max, + channels=1, + dropout=self.filter_dropout, + **filter_args + ) + if self.jit_filter: self.filter_fn = torch.jit.script(self.filter_fn, self.L) + + def recurrence(self, u , state): + "Fast inference mode via distilled recurrence" + raise NotImplementedError("Working on it!") + + def forward(self, u, *args, **kwargs): + l = u.size(-2) + l_filter = min(l, self.l_max) + u = self.in_proj(u) + u = rearrange(u, 'b l d -> b d l') + + uc = self.short_filter(u)[...,:l_filter] + + uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l', + z=self.num_blocks, + ho=self.num_heads, + v=self.head_dim * (self.order + 1) + ) + + *x, v = uc.split(self.d_model, dim=2) + k = self.filter_fn.filter(l_filter) + + # `c` is always 1 by default + k = rearrange(k, 'c l (v o) -> c o v l', v=self.head_dim, o=self.order - 1)[0] + + bias = rearrange(self.filter_fn.bias, '(v o) -> o v', v=self.head_dim, o=self.order - 1) + + for o, x_i in enumerate(reversed(x[1:])): + if self.outer_mixing: + v = rearrange(v, 'b h v z l -> b h 1 v z l') + v = self.dropout( + v * rearrange(x_i, 'b h v z l -> b h v 1 z l') + ) + v = v.sum(dim=2) + else: + v = self.dropout(v * x_i) + + # the bias term is broadcasted. Last dimension (l) is handled by fftconv + v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None]) + + if self.post_order_ffn: + w = self.ord_proj_w[o] + v = mul_sum( + rearrange(w, 'h1 h2 -> 1 h1 h2 1 1 1'), rearrange(v, 'b h v z l -> b h 1 v z l') + ) + + y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads)) + y = self.out_proj(y) + + if self.return_state: + return y, None + return y + + @property + def d_output(self): + return self.d_model \ No newline at end of file diff --git a/src/clm/src/models/sequence/hyena_components.py b/src/clm/src/models/sequence/hyena_components.py new file mode 100644 index 00000000..99a55cff --- /dev/null +++ b/src/clm/src/models/sequence/hyena_components.py @@ -0,0 +1,255 @@ +""" +Standalone Hyena components without registry dependencies. 
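+
+Example usage (an illustrative sketch; the sizes below are arbitrary assumptions made
+for demonstration and are not part of the original module):
+
+    import torch
+
+    layer = HyenaOperator(d_model=64, l_max=128)
+    u = torch.randn(2, 128, 64)   # (batch, length, d_model)
+    y = layer(u)                  # -> (2, 128, 64)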
+""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): + """Reference convolution with residual connection""" + seqlen = u.shape[-1] + fft_size = 2 * seqlen + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + if k_rev is not None: + k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size + k_f = k_f + k_rev_f.conj() + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + + if len(u.shape) > 3: + k_f = k_f.unsqueeze(1) + + y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] + + out = y + u * D.unsqueeze(-1) + if gelu: + out = F.gelu(out) + if dropout_mask is not None: + return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) + else: + return out.to(dtype=u.dtype) + + +@torch.jit.script +def mul_sum(q, y): + return (q * y).sum(dim=1) + + +class Sin(nn.Module): + """Sinusoidal activation function""" + def __init__(self, dim, w=10, train_freq=True): + super().__init__() + self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim) + + def forward(self, x): + return torch.sin(self.freq * x) + + +class PositionalEmbedding(nn.Module): + """Complex exponential positional embeddings for Hyena filters""" + def __init__(self, emb_dim: int, seq_len: int, **kwargs): + super().__init__() + + self.seq_len = seq_len + t = torch.linspace(0, 1, self.seq_len)[None, :, None] + + if emb_dim > 1: + bands = (emb_dim - 1) // 2 + + t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] + w = 2 * math.pi * t_rescaled / seq_len + + f = torch.linspace(1e-4, bands - 1, bands)[None, None] + z = torch.exp(-1j * f * w) + z = torch.cat([t, z.real, z.imag], dim=-1) + + self.register_buffer('z', z) + self.register_buffer('t', t) + + def forward(self, L): + return self.z[:, :L], self.t[:, :L] + + +class ExponentialModulation(nn.Module): + """Exponential modulation for implicit filters""" + def __init__( + self, + d_model, + fast_decay_pct=0.3, + slow_decay_pct=1.5, + target=1e-2, + modulate: bool = True, + shift: float = 0.0, + **kwargs + ): + super().__init__() + self.modulate = modulate + self.shift = shift + max_decay = math.log(target) / fast_decay_pct + min_decay = math.log(target) / slow_decay_pct + deltas = torch.linspace(min_decay, max_decay, d_model)[None, None] + self.register_buffer('deltas', deltas) + + def forward(self, t, x): + if self.modulate: + decay = torch.exp(-t * self.deltas.abs()) + x = x * (decay + self.shift) + return x + + +class HyenaFilter(nn.Module): + """Standalone Hyena filter without registry dependencies""" + def __init__( + self, + d_model, + emb_dim=3, + order=16, + seq_len=1024, + dropout=0.0, + w=1, + bias=True, + num_inner_mlps=2, + **kwargs + ): + super().__init__() + self.d_model = d_model + self.use_bias = bias + self.bias = nn.Parameter(torch.randn(self.d_model)) + self.dropout = nn.Dropout(dropout) + + act = Sin(dim=order, w=w) + self.emb_dim = emb_dim + self.seq_len = seq_len + + self.pos_emb = PositionalEmbedding(emb_dim, seq_len) + + # Build MLP + layers = [nn.Linear(emb_dim, order), act] + for i in range(num_inner_mlps): + layers.append(nn.Linear(order, order)) + layers.append(act) + layers.append(nn.Linear(order, d_model, bias=False)) + + self.implicit_filter = nn.Sequential(*layers) + self.modulation = ExponentialModulation(d_model, **kwargs) + + def filter(self, L): + z, t = self.pos_emb(L) + h = self.implicit_filter(z) + h = self.modulation(t, h) + return h + + def 
forward(self, x, L, k=None, bias=None): + if k is None: + k = self.filter(L) + + k = k[0] if type(k) is tuple else k + if bias is None: + bias = self.bias + bias = bias if self.use_bias else 0 * bias + + # Use reference implementation + y = fftconv_ref(x, k, bias, dropout_mask=None, gelu=False) + return y + + +class HyenaOperator(nn.Module): + """Standalone Hyena operator without registry dependencies""" + def __init__( + self, + d_model, + l_max, + order=2, + filter_order=64, + num_heads=1, + inner_factor=1, + num_blocks=1, + dropout=0.0, + filter_dropout=0.0, + short_filter_order=3, + **filter_args, + ): + super().__init__() + + assert d_model % num_heads == 0, f'Model dimension {d_model} must be divisible by num heads {num_heads}' + assert l_max % num_blocks == 0, f'Maximum signal length {l_max} must be divisible by block dimension {num_blocks}' + + self.d_model = d_model + self.order = order + self.l_max = l_max + self.num_heads = num_heads + self.inner_factor = inner_factor + self.num_blocks = num_blocks + self.filter_order = filter_order + self.short_filter_order = short_filter_order + self.filter_dropout = filter_dropout + + self.block_dim = l_max // num_blocks + self.head_dim = d_model // num_heads + + self.dropout = nn.Dropout(dropout) + + # Projections + self.out_proj = nn.Linear(self.d_model * inner_factor, self.d_model) + self.in_proj = nn.Linear(self.d_model, (self.order + 1) * self.d_model) + + # Short filter + total_width = self.d_model * self.inner_factor * (self.order + 1) + self.short_filter = nn.Conv1d( + in_channels=total_width, + out_channels=total_width, + kernel_size=self.short_filter_order, + groups=total_width, + padding=self.short_filter_order - 1 + ) + + # Long implicit filter + self.filter_fn = HyenaFilter( + self.head_dim * self.inner_factor * (self.order - 1), + order=self.filter_order, + seq_len=self.l_max, + dropout=self.filter_dropout, + **filter_args + ) + + def forward(self, u): + l = u.size(-2) + l_filter = min(l, self.l_max) + + u = self.in_proj(u) + u = rearrange(u, 'b l d -> b d l') + + uc = self.short_filter(u)[..., :l_filter] + + uc = rearrange( + uc, 'b (ho v) (z l) -> b ho v z l', + z=self.num_blocks, + ho=self.num_heads, + v=self.head_dim * (self.order + 1) + ) + + *x, v = uc.split(self.d_model, dim=2) + k = self.filter_fn.filter(l_filter) + + k = rearrange(k, 'c l (v o) -> c o v l', v=self.head_dim, o=self.order - 1)[0] + bias = rearrange(self.filter_fn.bias, '(v o) -> o v', v=self.head_dim, o=self.order - 1) + + for o, x_i in enumerate(reversed(x[1:])): + v = self.dropout(v * x_i) + v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None]) + + y = rearrange( + v * x[0], 'b h v z l -> b (z l) (h v)', + z=self.num_blocks, + h=self.num_heads + ) + y = self.out_proj(y) + + return y + + @property + def d_output(self): + return self.d_model \ No newline at end of file diff --git a/src/clm/src/models/sequence/long_conv.py b/src/clm/src/models/sequence/long_conv.py new file mode 100644 index 00000000..7b5a53c1 --- /dev/null +++ b/src/clm/src/models/sequence/long_conv.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +import opt_einsum as oe + +optimized = True + +if optimized: + contract = oe.contract +else: + contract = torch.einsum + +from clm.src.models.nn import LinearActivation, Activation, DropoutNd +from clm.src.models.sequence.block_fft import BlockFFT +from clm.src.models.sequence.long_conv_kernel import LongConvKernel + +class LongConv(nn.Module): + def 
__init__( + self, + d_model, + l_max=1024, + channels=1, + bidirectional=False, + # Arguments for position-wise feedforward components + activation='gelu', # activation between conv and FF + postact='glu', # activation after FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + dropout=0.0, tie_dropout=False, + transposed=True, # axis ordering (B, L, D) or (B, D, L) + verbose=False, + block_fft_conv=False, # replace the FFT conv with Monarch blocks + block_fft_conv_args={}, + + # SSM Kernel arguments + **kernel_args, + ): + """ + d_state: the dimension of the state, also denoted by N + l_max: the maximum kernel length, also denoted by L + channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models + bidirectional: if True, convolution kernel will be two-sided + + Position-wise feedforward components: + -------------------- + activation: activation in between SS and FF + postact: activation after FF ('id' for no activation, None to remove FF layer) + initializer: initializer on FF + weight_norm: weight normalization on FF + dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d + + Other arguments: + -------------------- + transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension] + """ + + super().__init__() + if verbose: + import clm.src.utils.train + log = clm.src.utils.train.get_logger(__name__) + log.info(f"Constructing Long Conv (H, L) = ({d_model}, {l_max})") + + self.d_model = d_model + self.H = d_model + self.L = l_max + self.bidirectional = bidirectional + self.channels = channels + self.transposed = transposed + self.block_fft_conv = block_fft_conv + self.block_fft_conv_args = block_fft_conv_args + + self.D = nn.Parameter(torch.randn(channels, self.H)) + + if self.bidirectional: + channels *= 2 + + # SSM Kernel + self.kernel = LongConvKernel(self.H, L=self.L, channels=channels, verbose=verbose, **kernel_args) + + if self.block_fft_conv: + self.block_fft_u = BlockFFT(**self.block_fft_conv_args) + self.block_fft_k = BlockFFT(**self.block_fft_conv_args) + + # Pointwise + self.activation = Activation(activation) + # dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout # Broken in torch==1.11 + dropout_fn = DropoutNd if tie_dropout else nn.Dropout + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + + # position-wise output transform to mix features + if postact is None: + self.output_linear = nn.Identity() + else: + self.output_linear = LinearActivation( + self.d_model * self.channels, + self.d_model, + # self.H*self.channels, + # self.d_model*(1 if self.gate is None else self.gate), + transposed=self.transposed, + initializer=initializer, + activation=postact, + activate=True, + weight_norm=weight_norm, + ) + + + + def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): # absorbs return_output and transformer src mask + """ + u: (B H L) if self.transposed else (B L H) + state: (H N) never needed, remnant from state spaces repo + + Returns: same shape as u + """ + if not self.transposed: u = u.transpose(-1, -2) + L = u.size(-1) + # Mask out padding tokens + # TODO handle option for mask - instead of lengths, which assumes suffix padding + if isinstance(lengths, int): + if lengths != L: + 
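+                # An integer `lengths` equal to L means there is no padding, so the mask is
+                # skipped (lengths is set to None in the branch below); otherwise it is promoted
+                # to a tensor and compared against torch.arange(L) to zero out suffix padding.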
lengths = torch.tensor(lengths, dtype=torch.long, device=u.device) + else: + lengths = None + if lengths is not None: + assert isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, u.size(0)] + mask = torch.where(torch.arange(L, device=lengths.device) < lengths[:, None, None], 1., 0.) + u = u * mask + + # Compute SS Kernel + L_kernel = L if self.L is None else min(L, round(self.L / rate)) + k, _ = self.kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L) + + # Convolution + if self.bidirectional: + k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2) + k = F.pad(k0, (0, L)) \ + + F.pad(k1.flip(-1), (L, 0)) + + if self.block_fft_conv: + k_f = self.block_fft_k(k.to(torch.complex64), N=L_kernel+L) # (C H L) + u_f = self.block_fft_u(u.to(torch.complex64), N=L_kernel+L) # (B H L) + y_f = contract('bhl,chl->bchl', u_f, k_f) + y = torch.fft.ifft(y_f, n=L_kernel+L, dim=-1).real[..., :L] # (B C H L) + else: + k_f = torch.fft.rfft(k, n=L_kernel+L) # (C H L) + u_f = torch.fft.rfft(u, n=L_kernel+L) # (B H L) + y_f = contract('bhl,chl->bchl', u_f, k_f) + y = torch.fft.irfft(y_f, n=L_kernel+L)[..., :L] # (B C H L) + + # Compute skip connection + y = y + contract('bhl,ch->bchl', u, self.D) + + # Reshape to flatten channels + y = rearrange(y, '... c h l -> ... (c h) l') + + if not self.transposed: y = y.transpose(-1, -2) + y = self.activation(y) + y = self.dropout(y) + y = self.output_linear(y) + + return y, None + + @property + def d_state(self): + return self.H + + @property + def d_output(self): + return self.d_model diff --git a/src/clm/src/models/sequence/long_conv_kernel.py b/src/clm/src/models/sequence/long_conv_kernel.py new file mode 100644 index 00000000..f54b8b88 --- /dev/null +++ b/src/clm/src/models/sequence/long_conv_kernel.py @@ -0,0 +1,82 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat + +from clm.src.utils.train import OptimModule + +class LongConvKernel(OptimModule): + def __init__( + self, + H, + L, + channels=1, + learning_rate=None, + lam=0.1, + causal=True, + kernel_dropout=0, + weight_init="random", + use_ma_smoothing = False, + ma_window_len = 7, + smooth_freq = False, + **kwargs + ): + super().__init__() + + self.drop = torch.nn.Dropout(p=kernel_dropout) + self.H = H + self.weight_init = weight_init + self.causal = causal + self.L = L*2 if not causal else L + + self.channels = channels + self.lam = lam + self.kernel = torch.nn.Parameter(self._parameter_initialization()) #(c,H,L) + + self.register("kernel", self.kernel, learning_rate) + + self.use_ma_smoothing=use_ma_smoothing + self.smooth_freq = smooth_freq + self.ma_window_len = ma_window_len + if self.use_ma_smoothing: + if smooth_freq: + weight = torch.arange(ma_window_len, dtype = self.kernel.dtype) + weight = torch.exp(-0.5 * torch.abs(weight - ma_window_len // 2) ** 2) + weight = repeat(weight, 'l -> h1 h2 l', h1 = self.H, h2 = 1) + weight = weight.type(torch.fft.rfft(self.kernel).dtype) + self.smooth_weight = weight + else: + self.ma_window_len = ma_window_len + assert self.ma_window_len%2!=0, "window size must be odd" + padding = (self.ma_window_len//2) + self.smooth = torch.nn.AvgPool1d(kernel_size=self.ma_window_len,stride=1,padding=padding) + + def _parameter_initialization(self): + if self.weight_init=="random": + return torch.randn(self.channels, self.H, self.L) * 0.002 + elif self.weight_init=="double_exp": + K = torch.randn(self.channels, self.H, self.L,dtype=torch.float32) * 0.02 + double_exp = 
torch.zeros((self.H,self.L),dtype=torch.float32) + for i in range(self.H): + for j in range(self.L): + double_exp[i,j] = torch.exp(-(j/self.L)*torch.pow(torch.tensor(int(self.H/2)),torch.tensor(i/self.H))) + K = torch.einsum("c h l, h l -> c h l",K,double_exp) + return K + else: raise NotImplementedError(f"{self.weight_init} is not valid") + + def forward(self, **kwargs): + k = self.kernel + if self.use_ma_smoothing: + if self.smooth_freq: + k_f = torch.fft.rfft(k, dim=-1) + k_f = F.conv1d(k_f, self.smooth_weight.to(k_f.device), padding='same', groups=self.H) + k = torch.fft.irfft(k_f, dim=-1) + else: + k = self.smooth(k) + k = F.relu(torch.abs(k)-self.lam)*torch.sign(k) + k = self.drop(k) + return k, None + + @property + def d_output(self): + return self.H \ No newline at end of file diff --git a/src/clm/src/models/sequence/long_conv_lm.py b/src/clm/src/models/sequence/long_conv_lm.py new file mode 100644 index 00000000..88222e5e --- /dev/null +++ b/src/clm/src/models/sequence/long_conv_lm.py @@ -0,0 +1,397 @@ +# Copyright (c) 2023, Tri Dao, Dan Fu. + +import copy +import math +import re +from functools import partial + +from collections import namedtuple, OrderedDict +from collections.abc import Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from transformers.models.gpt2.configuration_gpt2 import GPT2Config + +from einops import rearrange + +from flash_attn.modules.mha import MHA, ParallelMHA +from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP +from flash_attn.modules.block import Block +from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings +from flash_attn.utils.generation import GenerationMixin +from flash_attn.utils.distributed import sync_shared_params, all_gather_raw + +try: + from flash_attn.ops.fused_dense import ColumnParallelLinear +except ImportError: + ColumnParallelLinear = None + +try: + from flash_attn.ops.layer_norm import dropout_add_layer_norm +except ImportError: + dropout_add_layer_norm = None + +from clm.src.utils import instantiate +import clm.src.utils.registry as registry + +def create_mixer_cls(layer=None, process_group=None, + attn_layer_idx=None, attn_cfg=None, layer_idx=None, + sequence_parallel=True, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + parallel_kwargs = ({'process_group': process_group, 'sequence_parallel': sequence_parallel} + if process_group is not None else {}) + if attn_layer_idx is not None and layer_idx in attn_layer_idx: + causal = True if attn_cfg is None else attn_cfg.pop('causal', True) + fused_bias_fc = False if attn_cfg is None else attn_cfg.get('fused_bias_fc', False) + if not fused_bias_fc: + assert process_group is None, 'TensorParallel MHA requires fused_bias_fc' + mha_cls = MHA if process_group is None else ParallelMHA + # ParallelMHA doesn't take 'fused_bias_fc', it is assumed that we fuse matmul + bias + if process_group is not None: + attn_cfg = copy.deepcopy(attn_cfg) # Don't modify the original cfg + attn_cfg.pop('fused_bias_fc', None) + mixer_cls = partial(mha_cls, causal=causal, layer_idx=layer_idx, + **(attn_cfg if attn_cfg is not None else {}), + **parallel_kwargs, **factory_kwargs) + else: + fused_bias_fc = False if layer is None else layer.get('fused_bias_fc', False) + if process_group is not None: + assert fused_bias_fc, 'TensorParallel SSM requires fused_bias_fc' + mixer_cls = instantiate(registry.layer, layer, partial=True, layer_idx=layer_idx, **factory_kwargs, **parallel_kwargs) + # mixer_cls = 
partial(ssm_cls, layer_idx=layer_idx, + # **(ssm_cfg if ssm_cfg is not None else {}), + # **parallel_kwargs, **factory_kwargs) + return mixer_cls + + +def create_mlp_cls(d_model, d_inner=None, process_group=None, fused_mlp=False, + sequence_parallel=True, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + inner_dim = d_inner if d_inner is not None else 4 * d_model + if process_group is not None: + assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP' + if not fused_mlp: + mlp_cls = partial(Mlp, hidden_features=inner_dim, + activation=partial(F.gelu, approximate='tanh'), **factory_kwargs) + else: + mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP + parallel_kwargs = ({'process_group': process_group, 'sequence_parallel': sequence_parallel} + if process_group is not None else {}) + mlp_cls = partial(mlp_cls, hidden_features=inner_dim, **parallel_kwargs, **factory_kwargs) + return mlp_cls + + +def create_block(d_model, d_inner=None, process_group=None, + layer=None, attn_layer_idx=None, + attn_cfg=None, layer_norm_epsilon=1e-5, + resid_dropout1=0.0, resid_dropout2=0.0, residual_in_fp32=False, + fused_mlp=False, fused_dropout_add_ln=False, layer_idx=None, + sequence_parallel=True, + device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + mixer_cls = create_mixer_cls(layer=layer, process_group=process_group, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, layer_idx=layer_idx, + sequence_parallel=sequence_parallel, + **factory_kwargs) + mlp_cls = create_mlp_cls(d_model, d_inner=d_inner, process_group=process_group, + fused_mlp=fused_mlp, sequence_parallel=sequence_parallel, + **factory_kwargs) + norm_cls = partial(nn.LayerNorm, eps=layer_norm_epsilon, **factory_kwargs) + block = Block(d_model, mixer_cls, mlp_cls, norm_cls=norm_cls, + prenorm=True, resid_dropout1=resid_dropout1, resid_dropout2=resid_dropout2, + fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=residual_in_fp32, + sequence_parallel=sequence_parallel and process_group is not None, + mark_shared_params=process_group is not None) + block.layer_idx = layer_idx + return block + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True, + glu_act=False): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
+ # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) + # If using GLU activation for now, we scale the std by 2 + elif name in ["output_linear.0.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + if not glu_act: + nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) + else: + out_features = p.shape[0] + # Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5 + # on average. + nn.init.normal_(p[:out_features // 2], mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) * 2) + + +class LMBackbone(nn.Module): + + def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, + process_group=None, layer=None, + attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, + resid_dropout: float = 0.0, embed_dropout: float = 0.1, dropout_cls=nn.Dropout, + layer_norm_epsilon: float = 1e-5, initializer_cfg=None, + fused_mlp=False, fused_dropout_add_ln=False, residual_in_fp32=False, + sequence_parallel=True, + device=None, dtype=None, **kwargs) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.process_group = process_group + self.sequence_parallel = sequence_parallel + self.residual_in_fp32 = residual_in_fp32 + + if process_group is None: + self.embeddings = GPT2Embeddings(d_model, vocab_size, max_position_embeddings, + **factory_kwargs) + else: + self.embeddings = ParallelGPT2Embeddings( + d_model, vocab_size, max_position_embeddings, + process_group=process_group, sequence_parallel=self.sequence_parallel, + **factory_kwargs + ) + + # We change the order of dropout, residual and layer norm: + # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: + # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and + # the main branch (output of MLP). The model definition is unchanged, but the mapping of the + # nn.Dropout probabilities are changed. + # This is for performance reason: we can fuse dropout + add + layer_norm. + self.fused_dropout_add_ln = fused_dropout_add_ln + if self.fused_dropout_add_ln and dropout_add_layer_norm is None: + raise ImportError('dropout_add_layer_norm is not installed') + + self.layers = nn.ModuleList([create_block( + d_model, d_inner=d_inner, process_group=process_group, + layer=layer, attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, layer_norm_epsilon=layer_norm_epsilon, + resid_dropout1=embed_dropout if i == 0 else resid_dropout, + resid_dropout2=resid_dropout, residual_in_fp32=residual_in_fp32, + fused_mlp=fused_mlp, fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, + sequence_parallel=self.sequence_parallel, + **factory_kwargs, + ) for i in range(n_layer)]) + + self.drop_f = nn.Dropout(resid_dropout) + self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs) + + if process_group is not None: + for p in self.ln_f.parameters(): + # Mark the norm parameters as "shared_params" so that we sync their values at init. + p._shared_params = True + # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads. 
+ if self.sequence_parallel: + p._sequence_parallel = True + + self.apply(partial(_init_weights, n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}))) + self.tie_weights() + + def tie_weights(self): + if self.process_group is not None: + sync_shared_params(self, self.process_group) + + def forward(self, input_ids, position_ids=None, inference_params=None): + # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen + # dimensions so that we can split on it easily, in case of small batch size. + # Only the attention/SSM layers need to know the seqlen. + embedding_kwargs = ({'combine_batch_seqlen_dim': True} + if self.process_group is not None and self.sequence_parallel else {}) + hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs) + residual = None + mixer_kwargs = ({'seqlen': input_ids.shape[1]} + if self.process_group is not None and self.sequence_parallel else {}) + if inference_params is not None: + mixer_kwargs['inference_params'] = inference_params + for layer in self.layers: + hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs) + if not self.fused_dropout_add_ln: + dropped = self.drop_f(hidden_states) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) + else: + # Set prenorm=False here since we don't need the residual + hidden_states = dropout_add_layer_norm( + hidden_states, residual, self.ln_f.weight, self.ln_f.bias, + self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False, + residual_in_fp32=self.residual_in_fp32 + ) + return hidden_states + + +class ConvLMHeadModel(nn.Module, GenerationMixin): + + def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, + process_group=None, layer=None, + attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, + resid_dropout: float = 0.0, embed_dropout: float = 0.1, dropout_cls=nn.Dropout, + layer_norm_epsilon: float = 1e-5, initializer_cfg=None, + fused_mlp=False, fused_dropout_add_ln=False, residual_in_fp32=False, + pad_vocab_size_multiple: int = 1, sequence_parallel=True, + device=None, dtype=None, **kwargs) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.process_group = process_group + if vocab_size % pad_vocab_size_multiple != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + self.backbone = LMBackbone( + d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size, + process_group=process_group, + layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, + max_position_embeddings=max_position_embeddings, + resid_dropout=resid_dropout, embed_dropout=embed_dropout, + dropout_cls=dropout_cls, layer_norm_epsilon=layer_norm_epsilon, + initializer_cfg=initializer_cfg, fused_mlp=fused_mlp, + fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=residual_in_fp32, + sequence_parallel=sequence_parallel, + **factory_kwargs, **kwargs + ) + if process_group is None: + self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) + else: + if ColumnParallelLinear is None: + raise ImportError('fused_dense_lib is not installed') + self.lm_head = ColumnParallelLinear( + d_model, vocab_size, process_group, bias=False, + sequence_parallel=sequence_parallel, **factory_kwargs + ) + # Initialize weights and apply final processing + self.apply(partial(_init_weights, n_layer=n_layer, + 
**(initializer_cfg if initializer_cfg is not None else {}))) + self.tie_weights() + + def tie_weights(self): + self.lm_head.weight = self.backbone.embeddings.word_embeddings.weight + if self.process_group is not None: + sync_shared_params(self, self.process_group) + + def forward(self, input_ids, position_ids=None, inference_params=None, state=None): # state for the repo interface + hidden_states = self.backbone(input_ids, position_ids=position_ids, + inference_params=inference_params) + lm_logits = self.lm_head(hidden_states) + # During inference, we want the full logit for sampling + if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None: + lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group) + lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0]) + CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) + return CausalLMOutput(logits=lm_logits), None + + def load_state_dict(self, state_dict, strict=True): + # Remapping from our checkpoints that used different names + def key_mapping_backbone(key): + key = re.sub(r'^s4seq.encoder.', 'backbone.', key) + key = re.sub(r'^embedding.', 'backbone.embeddings.word_embeddings.', key) + key = re.sub(r'^backbone.norm', 'backbone.ln_0', key) + key = re.sub(r'^backbone.layers.(\d+).mixer.output_linear.', + r'backbone.layers.\1.mixer.out_proj.', key) + return key + state_dict = OrderedDict((key_mapping_backbone(k), v) for k, v in state_dict.items()) + # Remapping from our checkpoints that used a different ordering of layers in the block + # Previous: Mixer / MLP -> Dropout -> Add -> LN + # Current: Dropout -> Add -> LN -> Attn / MLP + if 'backbone.ln_0.weight' in state_dict: + n_layers = len(self.backbone.layers) + ln_weight = state_dict.pop(f'backbone.layers.{n_layers - 1}.norm2.weight') + ln_bias = state_dict.pop(f'backbone.layers.{n_layers - 1}.norm2.bias') + state_dict['backbone.ln_f.weight'] = ln_weight + state_dict['backbone.ln_f.bias'] = ln_bias + for l in reversed(range(n_layers)): + ln_weight = state_dict.pop(f'backbone.layers.{l}.norm1.weight') + ln_bias = state_dict.pop(f'backbone.layers.{l}.norm1.bias') + state_dict[f'backbone.layers.{l}.norm2.weight'] = ln_weight + state_dict[f'backbone.layers.{l}.norm2.bias'] = ln_bias + if l > 0: + ln_weight = state_dict.pop(f'backbone.layers.{l - 1}.norm2.weight') + ln_bias = state_dict.pop(f'backbone.layers.{l - 1}.norm2.bias') + state_dict[f'backbone.layers.{l}.norm1.weight'] = ln_weight + state_dict[f'backbone.layers.{l}.norm1.bias'] = ln_bias + ln_weight = state_dict.pop('backbone.ln_0.weight') + ln_bias = state_dict.pop('backbone.ln_0.bias') + state_dict[f'backbone.layers.0.norm1.weight'] = ln_weight + state_dict[f'backbone.layers.0.norm1.bias'] = ln_bias + # Previously we have separate projection matrices for q, k, v, now we stack them + if 'backbone.layers.0.mixer.q_proj.weight' in state_dict: + n_layers = len(self.backbone.layers) + for l in range(n_layers): + if f'backbone.layers.{l}.mixer.q_proj.weight' in state_dict: + Wq = state_dict.pop(f'backbone.layers.{l}.mixer.q_proj.weight') + Wk = state_dict.pop(f'backbone.layers.{l}.mixer.k_proj.weight') + Wv = state_dict.pop(f'backbone.layers.{l}.mixer.v_proj.weight') + bq = state_dict.pop(f'backbone.layers.{l}.mixer.q_proj.bias') + bk = state_dict.pop(f'backbone.layers.{l}.mixer.k_proj.bias') + bv = state_dict.pop(f'backbone.layers.{l}.mixer.v_proj.bias') + state_dict[f'backbone.layers.{l}.mixer.Wqkv.weight'] = torch.cat( + [Wq, Wk, Wv], dim=0 + ) + 
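+                    # The stacked Wqkv.weight is (3 * d_model, d_model) and the stacked bias is
+                    # (3 * d_model,), which appears to match the fused QKV layout expected by the
+                    # flash-attn MHA mixer used in this backbone.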
state_dict[f'backbone.layers.{l}.mixer.Wqkv.bias'] = torch.cat( + [bq, bk, bv], dim=0 + ) + return super().load_state_dict(state_dict, strict=strict) + + +def shard_state_dict_tp(state_dict, world_size, rank, pad_vocab_size_multiple=1): + """Convert the state_dict of a standard SSM model to the state_dict of a SSM model + with tensor parallel. + """ + layer_idx_match = [re.search(r'backbone\.layers\.(\d+)\.', k) for k in state_dict.keys()] + num_hidden_layers = len(set(m.group(1) for m in layer_idx_match if m is not None)) + vocab_size = state_dict['backbone.embeddings.word_embeddings.weight'].shape[0] + inner_dim, hidden_size = state_dict['backbone.layers.0.mlp.fc1.weight'].shape + vocab_size = (math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) + assert vocab_size % world_size == 0 + assert hidden_size % world_size == 0 + assert inner_dim % world_size == 0 + + def shard_dim(state_dict, key, dim=0): + x = state_dict[key] + dimension = x.shape[dim] // world_size + state_dict[key] = x.narrow(dim, rank * dimension, dimension) + + def shard_qkv_headdim(state_dict, key): + x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3) + dim = x.shape[1] // world_size + state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim], + 'three d ... -> (three d) ...') + + shard_dim(state_dict, 'backbone.embeddings.word_embeddings.weight', 0) + if 'lm_head.weight' in state_dict: + shard_dim(state_dict, 'lm_head.weight', 0) + if 'backbone.embeddings.position_embeddings.weight' in state_dict: + shard_dim(state_dict, 'backbone.embeddings.position_embeddings.weight', -1) + for i in range(num_hidden_layers): + shard_qkv_headdim(state_dict, f'backbone.layers.{i}.mixer.Wqkv.weight') + shard_qkv_headdim(state_dict, f'backbone.layers.{i}.mixer.Wqkv.bias') + shard_dim(state_dict, f'backbone.layers.{i}.mixer.out_proj.weight', -1) + if rank != 0: + state_dict.pop(f'backbone.layers.{i}.mixer.out_proj.bias') + shard_dim(state_dict, f'backbone.layers.{i}.mlp.fc1.weight', 0) + shard_dim(state_dict, f'backbone.layers.{i}.mlp.fc1.bias', 0) + shard_dim(state_dict, f'backbone.layers.{i}.mlp.fc2.weight', -1) + if rank != 0: + state_dict.pop(f'backbone.layers.{i}.mlp.fc2.bias') + if f'backbone.layers.{i}.mixer.kernel.kernel.B' in state_dict: + for name in ['D', 'ssm_k_D', 'kernel.kernel.B', 'kernel.kernel.inv_A_real', + 'kernel.kernel.A_imag', 'ssm_k_kernel.kernel.B', 'kernel.kernel.log_dt']: + if f'backbone.layers.{i}.mixer.{name}' in state_dict: + shard_dim(state_dict, f'backbone.layers.{i}.mixer.{name}', 0) + for name in ['kernel.kernel.C', 'ssm_k_kernel.kernel.C']: + if f'backbone.layers.{i}.mixer.{name}' in state_dict: + shard_dim(state_dict, f'backbone.layers.{i}.mixer.{name}', 1) + return state_dict diff --git a/src/clm/src/models/sequence/mha.py b/src/clm/src/models/sequence/mha.py new file mode 100644 index 00000000..12d55fda --- /dev/null +++ b/src/clm/src/models/sequence/mha.py @@ -0,0 +1,122 @@ +""" Wrapper around nn.MultiheadAttention to adhere to SequenceModule interface. 
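+
+Example usage (an illustrative sketch; shapes and arguments are assumptions for
+demonstration and rely on the default transposed=False behaviour of the TransposedModule wrapper):
+
+    import torch
+
+    mha = MultiheadAttention(d_model=64, n_heads=4, causal=True)
+    y, _ = mha(torch.randn(2, 128, 64))   # (batch, length, d_model) -> same shape, state is None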
""" + +import torch +import torch.nn.functional as F +from torch import nn +import hydra +from clm.src.models.sequence.base import SequenceModule, TransposedModule +import clm.src.models.nn.utils as U +from einops import rearrange + +@TransposedModule +class MultiheadAttention(SequenceModule): + """ Simple wrapper for MultiheadAttention """ + def __init__(self, d_model, n_heads, *args, causal=True, **kwargs): + super().__init__() + self.d_model = d_model + self.d_output = d_model + self.mha = nn.MultiheadAttention(d_model, n_heads, *args, batch_first=True, **kwargs) + self.causal = causal + + def forward(self, src, attn_mask=None, key_padding_mask=None, state=None, **kwargs): + """ state should represent a mask and key padding mask """ + if self.causal and attn_mask is None: + attn_mask = torch.triu(torch.ones(src.size(-2), clm.src.size(-2), + dtype=torch.bool, device=src.device), + diagonal=1) + # attn_mask, key_padding_mask = state + # Note that this returns None for the second argument + y, _ = self.mha(src, src, src, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) + return y, None + + def step(self, x, state): + # TODO proper cached inference + # x: (B, D) + pass + + +class VitAttention(SequenceModule): + """Copied from implementation for ViT: only used for ViT model + + This attention class makes several simplifying assumptions (commonly satisfied in vision + applications): + 1. q = k = v + 2. No masks: no attention mask, no key padding mask + 3. Embed dimension = Input dimension, i.e. projection matrices are square. + """ + + @property + def d_output(self): + return self.dim + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_scale=None, + attn_drop=0., + # proj_drop=0., + packed_linear=True, + linear_cfg=None, + **kwargs, + ): + """packed_linear: whether to pack all 3 q_proj, k_proj, v_proj into 2 matrix. + This option is to be compatible with T2T-ViT pretrained weights, where there's only one + projection weight matrix. 
+ """ + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = dim // num_heads + + self.scale = qk_scale or head_dim ** -0.5 + + if linear_cfg is not None: + packed_linear = False + self.packed_linear = packed_linear + if packed_linear: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + else: + if linear_cfg is None: + linear_cfg = {'_target_': 'torch.nn.Linear'} + self.q_proj = hydra.utils.instantiate(linear_cfg, dim, dim, bias=qkv_bias, + _recursive_=False) + self.k_proj = hydra.utils.instantiate(linear_cfg, dim, dim, bias=qkv_bias, + _recursive_=False) + self.v_proj = hydra.utils.instantiate(linear_cfg, dim, dim, bias=qkv_bias, + _recursive_=False) + + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + # Removing this dropout because we do this in SequenceResidualBlock + # self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, state=None): + B, N, C = x.shape + if self.packed_linear: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + else: + q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x) + q, k, v = [rearrange(x, 'b n (h d) -> b h n d', h=self.num_heads) for x in (q, k, v)] + + # attn = (q @ k.transpose(-2, -1) * self.scale) + # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) + bsz, num_heads, q_seq_len, dk = q.size() + _, _, k_seq_len, _ = k.size() + q = rearrange(q, 'b h t d -> (b h) t d') + k = rearrange(k, 'b h s d -> (b h) d s') + # Preallocate attn_weights for `baddbmm` + attn = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=q.dtype, device=q.device) + attn = rearrange(torch.baddbmm(attn, q, k, beta=0, alpha=self.scale), + '(b h) t s -> b h t s', h = self.num_heads) + + attn = F.softmax(attn, dim=-1, dtype=v.dtype) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + # x = self.proj_drop(x) + return x, None diff --git a/src/clm/src/models/sequence/model.py b/src/clm/src/models/sequence/model.py new file mode 100644 index 00000000..930c19c6 --- /dev/null +++ b/src/clm/src/models/sequence/model.py @@ -0,0 +1,134 @@ +""" Isotropic deep sequence model backbone, in the style of ResNets / Transformers. + +The SequenceModel class implements a generic (batch, length, d_input) -> (batch, length, d_output) transformation +""" + +from functools import partial + +import torch +import torch.nn as nn +from einops import rearrange + +from clm.src.utils.config import to_list, to_dict +from clm.src.models.sequence.block import SequenceResidualBlock +from clm.src.models.sequence.base import SequenceModule +from clm.src.models.nn.components import Normalization, DropoutNd + + +class SequenceModel(SequenceModule): + def __init__( + self, + d_model, # Resize input (useful for deep models with residuals) + n_layers=1, # Number of layers + transposed=False, # Transpose inputs so each layer receives (batch, dim, length) + dropout=0.0, # Dropout parameter applied on every residual and every layer + tie_dropout=False, # Tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d + prenorm=True, # Pre-norm vs. post-norm + n_repeat=1, # Each layer is repeated n times per stage before applying pooling + layer=None, # Layer config, must be specified + residual=None, # Residual config + norm=None, # Normalization config (e.g. 
layer vs batch) + pool=None, # Config for pooling layer per stage + track_norms=True, # Log norms of each layer output + dropinp=0.0, # Input dropout + ): + super().__init__() + # Save arguments needed for forward pass + self.d_model = d_model + self.transposed = transposed + self.track_norms = track_norms + + # Input dropout (not really used) + dropout_fn = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout + self.drop = dropout_fn(dropinp) if dropinp > 0.0 else nn.Identity() + + layer = to_list(layer, recursive=False) + + # Some special arguments are passed into each layer + for _layer in layer: + # If layers don't specify dropout, add it + if _layer.get('dropout', None) is None: + _layer['dropout'] = dropout + # Ensure all layers are shaped the same way + _layer['transposed'] = transposed + + # Duplicate layers + layers = layer * n_layers * n_repeat + + # Instantiate layers + _layers = [] + d = d_model + for l, layer in enumerate(layers): + # Pool at the end of every n_repeat blocks + pool_cfg = pool if (l+1) % n_repeat == 0 else None + block = SequenceResidualBlock(d, l+1, prenorm=prenorm, dropout=dropout, tie_dropout=tie_dropout, transposed=transposed, layer=layer, residual=residual, norm=norm, pool=pool_cfg) + _layers.append(block) + d = block.d_output + + self.d_output = d + self.layers = nn.ModuleList(_layers) + if prenorm: + if norm is None: + self.norm = None + elif isinstance(norm, str): + self.norm = Normalization(self.d_output, transposed=self.transposed, _name_=norm) + else: + self.norm = Normalization(self.d_output, transposed=self.transposed, **norm) + else: + self.norm = nn.Identity() + + def forward(self, inputs, *args, state=None, **kwargs): + """ Inputs assumed to be (batch, sequence, dim) """ + if self.transposed: inputs = rearrange(inputs, 'b ... d -> b d ...') + inputs = self.drop(inputs) + + # Track norms + if self.track_norms: output_norms = [torch.mean(inputs.detach() ** 2)] + + # Apply layers + outputs = inputs + prev_states = [None] * len(self.layers) if state is None else state + next_states = [] + for layer, prev_state in zip(self.layers, prev_states): + outputs, state = layer(outputs, *args, state=prev_state, **kwargs) + next_states.append(state) + if self.track_norms: output_norms.append(torch.mean(outputs.detach() ** 2)) + if self.norm is not None: outputs = self.norm(outputs) + + if self.transposed: outputs = rearrange(outputs, 'b d ... -> b ... 
d') + + if self.track_norms: + metrics = to_dict(output_norms, recursive=False) + self.metrics = {f'norm/{i}': v for i, v in metrics.items()} + + return outputs, next_states + + @property + def d_state(self): + d_states = [layer.d_state for layer in self.layers] + return sum([d for d in d_states if d is not None]) + + @property + def state_to_tensor(self): + # Slightly hacky way to implement this in a curried manner (so that the function can be extracted from an instance) + # Somewhat more sound may be to turn this into a @staticmethod and grab subclasses using hydra.utils.get_class + def fn(state): + x = [_layer.state_to_tensor(_state) for (_layer, _state) in zip(self.layers, state)] + x = [_x for _x in x if _x is not None] + return torch.cat( x, dim=-1) + return fn + + def default_state(self, *batch_shape, device=None): + return [layer.default_state(*batch_shape, device=device) for layer in self.layers] + + def step(self, x, state, **kwargs): + # Apply layers + prev_states = [None] * len(self.layers) if state is None else state + next_states = [] + for layer, prev_state in zip(self.layers, prev_states): + x, state = layer.step(x, state=prev_state, **kwargs) + next_states.append(state) + + x = self.norm(x) + + return x, next_states diff --git a/src/clm/src/models/sequence/pool.py b/src/clm/src/models/sequence/pool.py new file mode 100644 index 00000000..e5b8f4c6 --- /dev/null +++ b/src/clm/src/models/sequence/pool.py @@ -0,0 +1,459 @@ +"""Implements downsampling and upsampling on sequences.""" + +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange, repeat, reduce + +from clm.src.models.sequence import SequenceModule +from clm.src.models.nn import LinearActivation + +""" Simple pooling functions that just downsample or repeat + +stride: Subsample on the layer dimension +expand: Repeat on the feature dimension +""" + + +class DownSample(SequenceModule): + def __init__(self, d_input, stride=1, expand=1, transposed=True): + super().__init__() + self.d_input = d_input + self.stride = stride + self.expand = expand + self.transposed = transposed + + def forward(self, x): + if x is None: return None + if self.stride > 1: + assert x.ndim == 3, "Downsampling with higher-dimensional inputs is currently not supported. It is recommended to use average or spectral pooling instead." + if self.transposed: + x = x[..., 0::self.stride] + else: + x = x[..., 0::self.stride, :] + + if self.expand > 1: + if self.transposed: + x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand) + else: + x = repeat(x, 'b ... d -> b ... (d e)', e=self.expand) + + return x, None + + + def step(self, x, state, **kwargs): + if self.stride > 1 or self.expand > 1: + raise NotImplementedError + return x, state + + @property + def d_output(self): + return self.d_input * self.expand + +class DownAvgPool(SequenceModule): + def __init__(self, d_input, stride=1, expand=None, transposed=True): + super().__init__() + self.d_input = d_input + self.stride = stride + self.expand = expand + self.transposed = transposed + + if self.expand is not None: + self.linear = LinearActivation( + d_input, + d_input * expand, + transposed=transposed, + ) + + def forward(self, x): + if not self.transposed: + x = rearrange(x, 'b ... d -> b d ...') + + if self.stride > 1: + # einops appears slower than F + if x.ndim == 3: + x = F.avg_pool1d(x, self.stride, self.stride) + elif x.ndim == 4: + x = F.avg_pool2d(x, self.stride, self.stride) + else: + # Reduction string e.g. 
"b d (l1 2) (l2 2) -> b d l1 l2" + reduce_str = "b d " + " ".join([f"(l{i} {self.stride})" for i in range(x.ndim-2)]) \ + + " -> b d " + " ".join([f"l{i}" for i in range(x.ndim-2)]) + x = reduce(x, reduce_str, 'mean') + + # if self.expand > 1: + # x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand) + + if not self.transposed: + x = rearrange(x, 'b d ... -> b ... d') + if self.expand is not None: + x = self.linear(x) + return x, None + + def step(self, x, state, **kwargs): + if self.stride > 1 or self.expand > 1: + raise NotImplementedError + return x, state + + @property + def d_output(self): + if self.expand is None: + return self.d_input + else: + return self.d_input * self.expand + +class DownSpectralPool(SequenceModule): + def __init__(self, d_input, stride=1, expand=1, transposed=True): + super().__init__() + self.d_input = d_input + self.stride = stride + self.expand = expand + self.transposed = transposed + + def forward(self, x): + """ + x: (B, L..., D) + """ + if not self.transposed: + x = rearrange(x, 'b ... d -> b d ...') + shape = x.shape[2:] + x_f = torch.fft.ifftn(x, s=shape) + + for axis, l in enumerate(shape): + assert l % self.stride == 0, 'input length must be divisible by stride' + new_l = l // self.stride + idx = torch.cat([torch.arange(0, new_l-new_l//2), l+torch.arange(-new_l//2, 0)]).to(x_f.device) + x_f = torch.index_select(x_f, 2+axis, idx) + x = torch.fft.ifftn(x_f, s=[l//self.stride for l in shape]) + x = x.real + + if self.expand > 1: + x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand) + if not self.transposed: + x = rearrange(x, 'b d ... -> b ... d') + return x, None + + def step(self, x, state, **kwargs): + if self.stride > 1 or self.expand > 1: + raise NotImplementedError + return x, state + + @property + def d_output(self): + return self.d_input * self.expand + +class UpSample(SequenceModule): + def __init__(self, d_input, stride=1, expand=1, transposed=True): + super().__init__() + self.d_input = d_input + self.stride = stride + self.expand = expand + self.transposed = transposed + + def forward(self, x): + if x is None: return None + if self.expand > 1: + if self.transposed: + x = reduce(x, '... (d e) l -> ... d l', 'mean', e=self.expand) + else: + x = reduce(x, '... (d e) -> ... d', 'mean', e=self.expand) + if self.stride > 1: + if self.transposed: + x = repeat(x, '... l -> ... (l e)', e=self.stride) + else: + x = repeat(x, '... l d -> ... (l e) d', e=self.stride) + return x, None + + @property + def d_output(self): + return self.d_input // self.expand + def step(self, x, state, **kwargs): + if self.stride > 1 or self.expand > 1: + raise NotImplementedError + return x, state +class UpAvgPool(SequenceModule): + def __init__(self, d_input, stride=1, expand=1, causal=False, transposed=True): + super().__init__() + assert d_input % expand == 0 + self.d_input = d_input + self.stride = stride + self.expand = expand + self.causal = causal + self.transposed = transposed + + self.linear = LinearActivation( + d_input, + d_input // expand, + transposed=transposed, + ) + + def forward(self, x): + # TODO only works for 1D right now + if x is None: return None + x = self.linear(x) + if self.stride > 1: + if self.transposed: + if self.causal: + x = F.pad(x[..., :-1], (1, 0)) # Shift to ensure causality + x = repeat(x, '... l -> ... (l e)', e=self.stride) + else: + if self.causal: + x = F.pad(x[..., :-1, :], (0, 0, 1, 0)) # Shift to ensure causality + x = repeat(x, '... l d -> ... 
(l e) d', e=self.stride) + return x, None + + @property + def d_output(self): + return self.d_input // self.expand + def step(self, x, state, **kwargs): + if self.stride > 1 or self.expand > 1: + raise NotImplementedError + return x, state + +class DownLinearPool(SequenceModule): + def __init__(self, d_model, stride=1, expand=1, causal=False, transposed=True): + super().__init__() + + self.d_model = d_model + self.stride = stride + self.expand = expand + self.transposed = transposed + + self.linear = LinearActivation( + d_model * stride, + d_model * expand, + transposed=transposed, + ) + + def forward(self, x): + if self.transposed: + x = rearrange(x, '... h (l s) -> ... (h s) l', s=self.stride) + else: + x = rearrange(x, '... (l s) h -> ... l (h s)', s=self.stride) + x = self.linear(x) + return x, None + + def step(self, x, state, **kwargs): + # if self.stride > 1 or self.expand > 1: + # raise NotImplementedError + # return x, state + if x is None: return None, state + state.append(x) + if len(state) == self.stride: + x = rearrange(torch.stack(state, dim=-1), '... h s -> ... (h s)') + if self.transposed: x = x.unsqueeze(-1) + x = self.linear(x) + if self.transposed: x = x.squeeze(-1) + return x, [] + else: + return None, state + + def default_state(self, *batch_shape, device=None): + return [] + + @property + def d_output(self): + return self.d_input * self.expand + +class UpLinearPool(SequenceModule): + def __init__(self, d, stride=1, expand=1, causal=False, transposed=True): + super().__init__() + + # self.d_model = d * expand + # self.d_output = d + assert d % expand == 0 + self.d_model = d + self.d_output = d // expand + # self._d_output = d_output + self.stride = stride + self.causal = causal + self.transposed = transposed + + self.linear = LinearActivation( + self.d_model, + self.d_output * stride, + transposed=transposed, + ) + + def forward(self, x, skip=None): + x = self.linear(x) + if self.transposed: + if self.causal: + x = F.pad(x[..., :-1], (1, 0)) # Shift to ensure causality + x = rearrange(x, '... (h s) l -> ... h (l s)', s=self.stride) + else: + if self.causal: + x = F.pad(x[..., :-1, :], (0, 0, 1, 0)) # Shift to ensure causality + x = rearrange(x, '... l (h s) -> ... (l s) h', s=self.stride) + if skip is not None: + x = x + skip + return x, None + + def step(self, x, state, **kwargs): + """ + x: (..., H) + """ + + assert len(state) > 0 + y, state = state[0], state[1:] + if len(state) == 0: + assert x is not None + if self.transposed: x = x.unsqueeze(-1) + x = self.linear(x) + if self.transposed: x = x.squeeze(-1) + x = rearrange(x, '... (h s) -> ... 
h s', s=self.stride) + state = list(torch.unbind(x, dim=-1)) + else: assert x is None + return y, state + + def default_state(self, *batch_shape, device=None): + state = torch.zeros(batch_shape + (self.d_output, self.stride), device=device) # (batch, h, s) + state = list(torch.unbind(state, dim=-1)) # List of (..., H) + return state + + # @property + # def d_output(self): return self._d_output + +""" Pooling functions with trainable parameters """ # TODO make d_output expand instead + +class DownPool2d(SequenceModule): + + def __init__(self, d_input, d_output, stride=1, transposed=True, weight_norm=True): + super().__init__() + + self.linear = LinearActivation( + d_input, + d_output, + transposed=transposed, + weight_norm=weight_norm, + ) + + self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride), + + def forward(self, x): + if self.transposed: + x = self.pool(x) + +# TODO DownPool/UpPool are currently used by unet/sashimi backbones +# DownLinearPool is used by the registry (for isotropic backbone) +# DownPool is essentially the same as DownLinearPool. These should be consolidated +class DownPool(SequenceModule): + def __init__(self, d_input, d_output=None, expand=None, stride=1, transposed=True, weight_norm=True, initializer=None, activation=None): + super().__init__() + assert (d_output is None) + (expand is None) == 1 + if d_output is None: d_output = d_input * expand + + self.d_output = d_output + self.stride = stride + self.transposed = transposed + + self.linear = LinearActivation( + d_input * stride, + d_output, + transposed=transposed, + initializer=initializer, + weight_norm = weight_norm, + activation=activation, + activate=True if activation is not None else False, + ) + + def forward(self, x): + if self.transposed: + x = rearrange(x, '... h (l s) -> ... (h s) l', s=self.stride) + else: + x = rearrange(x, '... (l s) h -> ... l (h s)', s=self.stride) + x = self.linear(x) + return x, None + + def step(self, x, state, **kwargs): + """ + x: (..., H) + """ + + if x is None: return None, state + state.append(x) + if len(state) == self.stride: + x = rearrange(torch.stack(state, dim=-1), '... h s -> ... (h s)') + if self.transposed: x = x.unsqueeze(-1) + x = self.linear(x) + if self.transposed: x = x.squeeze(-1) + return x, [] + else: + return None, state + + def default_state(self, *batch_shape, device=None): + return [] + + +class UpPool(SequenceModule): + def __init__(self, d_input, d_output, stride, transposed=True, weight_norm=True, initializer=None, activation=None): + super().__init__() + + self.d_input = d_input + self._d_output = d_output + self.stride = stride + self.transposed = transposed + + self.linear = LinearActivation( + d_input, + d_output * stride, + transposed=transposed, + initializer=initializer, + weight_norm = weight_norm, + activation=activation, + activate=True if activation is not None else False, + ) + + def forward(self, x, skip=None): + x = self.linear(x) + if self.transposed: + x = F.pad(x[..., :-1], (1, 0)) # Shift to ensure causality + x = rearrange(x, '... (h s) l -> ... h (l s)', s=self.stride) + else: + x = F.pad(x[..., :-1, :], (0, 0, 1, 0)) # Shift to ensure causality + x = rearrange(x, '... l (h s) -> ... 
(l s) h', s=self.stride) + if skip is not None: + x = x + skip + return x, None + + def step(self, x, state, **kwargs): + """ + x: (..., H) + """ + + assert len(state) > 0 + y, state = state[0], state[1:] + if len(state) == 0: + assert x is not None + if self.transposed: x = x.unsqueeze(-1) + x = self.linear(x) + if self.transposed: x = x.squeeze(-1) + x = rearrange(x, '... (h s) -> ... h s', s=self.stride) + state = list(torch.unbind(x, dim=-1)) + else: assert x is None + return y, state + + def default_state(self, *batch_shape, device=None): + state = torch.zeros(batch_shape + (self.d_output, self.stride), device=device) # (batch, h, s) + state = list(torch.unbind(state, dim=-1)) # List of (..., H) + return state + + @property + def d_output(self): return self._d_output + +registry = { + 'sample': DownSample, + 'pool': DownAvgPool, + 'avg': DownAvgPool, + 'linear': DownLinearPool, + 'spectral': DownSpectralPool, +} + +up_registry = { + # 'sample': UpSample, + 'pool': UpAvgPool, + 'avg': UpAvgPool, + 'linear': UpLinearPool, + # 'spectral': UpSpectralPool, # Not implemented and no way to make this causal +} + diff --git a/src/clm/src/models/sequence/simple_lm.py b/src/clm/src/models/sequence/simple_lm.py new file mode 100644 index 00000000..bc525d55 --- /dev/null +++ b/src/clm/src/models/sequence/simple_lm.py @@ -0,0 +1,469 @@ +# Copyright (c) 2023, Tri Dao, Dan Fu. +# Simplified, mostly standalone version of LongConvLM for synthetics. + +import math +from functools import partial + +from collections import namedtuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.ops import StochasticDepth + +from einops import rearrange + +from clm.src.utils import instantiate +import clm.src.utils.registry as registry + +class LinearResidual(nn.Linear): + """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense. + """ + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return super().forward(input), input + +class SelfAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.dropout_p = attention_dropout + + def forward(self, qkv, causal=None, key_padding_mask=None): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) + causal: if passed, will override self.causal + key_padding_mask: boolean mask to apply to the attention weights. True means to keep, + False means to mask out. 
(B, S) + """ + batch_size, seqlen = qkv.shape[0], qkv.shape[1] + causal = self.causal if causal is None else causal + q, k, v = qkv.unbind(dim=2) + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale) + if key_padding_mask is not None: + padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, + device=scores.device) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s') + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0) + output = torch.einsum('bhts,bshd->bthd', attention_drop, v) + return output + +class MHA(nn.Module): + """Multi-head self-attention and cross-attention + """ + + def __init__(self, embed_dim, num_heads, bias=True, dropout=0.0, + softmax_scale=None, causal=False, layer_idx=None, dwconv=False,return_residual=False,device=None, dtype=None) -> None: + """ + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.embed_dim = embed_dim + self.causal = causal + self.layer_idx = layer_idx + self.dwconv = dwconv + self.return_residual = return_residual + + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + + linear_cls = nn.Linear + linear_resid_cls = LinearResidual + inner_attn_cls = SelfAttention + + if not self.return_residual: + self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + else: + self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + if self.dwconv: + self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2, + groups=3 * embed_dim) + + self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, + attention_dropout=dropout) + + # output projection always have the bias (for now) + self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs) + + def forward(self, x, key_padding_mask=None, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if + cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total + is the is the sum of the sequence lengths in the batch. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into x. Only applicable when using + FlashAttention. + max_seqlen: int. Maximum sequence length in the batch. + key_padding_mask: boolean mask, True means to keep, False means to mask out. + (batch, seqlen). Only applicable when not using FlashAttention. + mixer_subset: for cross-attention only. 
If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + inference_params: for generation. Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + """ + + kwargs = ({'key_padding_mask': key_padding_mask, **kwargs}) + + if not self.return_residual: + qkv = self.Wqkv(x) + else: + qkv, x = self.Wqkv(x) + if self.dwconv: + qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2], + 'b d s -> b s d').contiguous() + qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim) + + context = self.inner_attn(qkv, **kwargs) + + out = self.out_proj(rearrange(context, '... h d -> ... (h d)')) + return out if not self.return_residual else (out, x) + + +class GPT2Embeddings(nn.Module): + + def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None, + word_embed_proj_dim=None, device=None, dtype=None): + """ + If max_position_embeddings <= 0, there's no position embeddings + If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension + the project up to embed_dim + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + if word_embed_proj_dim is None: + self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx, + **factory_kwargs) + self.project_in = None + else: + self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim, + padding_idx=padding_idx, **factory_kwargs) + self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False, + **factory_kwargs) + self.max_position_embeddings = max_position_embeddings + if self.max_position_embeddings > 0: + self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim, + **factory_kwargs) + + def forward(self, input_ids, position_ids=None): + """ + input_ids: (batch, seqlen) + position_ids: (batch, seqlen) + """ + batch_size, seqlen = input_ids.shape + embeddings = self.word_embeddings(input_ids) + if self.project_in is not None: + embeddings = self.project_in(embeddings) + if self.max_position_embeddings > 0: + if position_ids is None: + position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + return embeddings + +class Mlp(nn.Module): + + def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu, + return_residual=False, device=None, dtype=None): + """ + From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py + """ + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.return_residual = return_residual + self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs) + self.activation = activation + self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs) + + def forward(self, x): + y = self.fc1(x) + y = self.activation(y) + y = self.fc2(y) + return y if not self.return_residual else (y, x) + +class Block(nn.Module): + + def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0., + drop_path1=0., drop_path2=0., + return_residual=False, + 
residual_in_fp32=False): + """ + From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py + For prenorm=True, this Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both + the hidden_states (output of the MLP) and the residual. + This is for performance reasons, as we can fuse the dropout, add and LayerNorm. + The residual needs to be provided (except for the very first block). + For prenorm=False, this Block has the same structure as a regular postnorm Transformer + block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. + return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. + This is for performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. + """ + super().__init__() + self.prenorm = prenorm + self.return_residual = return_residual + self.residual_in_fp32 = residual_in_fp32 + if self.residual_in_fp32: + assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True' + if mixer_cls is None: + mixer_cls = partial(MHA, num_heads=dim // 64) + if mlp_cls is None: + mlp_cls = partial(Mlp, hidden_features=4 * dim) + self.mixer = mixer_cls(dim) + self.dropout1 = dropout_cls(resid_dropout1) + self.drop_path1 = StochasticDepth(drop_path1, mode='row') + self.norm1 = norm_cls(dim) + self.mlp = mlp_cls(dim) + if not isinstance(self.mlp, nn.Identity): + self.dropout2 = dropout_cls(resid_dropout2) + self.drop_path2 = StochasticDepth(drop_path2, mode='row') + self.norm2 = norm_cls(dim) + + def forward(self, hidden_states, residual = None, + mixer_subset=None, mixer_kwargs=None): + r"""Pass the input through the encoder layer. + Args: + hidden_states: the sequence to the encoder layer (required). + residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. 
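+            mixer_kwargs: optional dict of extra keyword arguments forwarded to the mixer
+                call (for example a key_padding_mask when the mixer is MHA).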
+ """ + if self.prenorm: + dropped = self.drop_path1(self.dropout1(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + if mixer_kwargs is None: + mixer_kwargs = {} + if mixer_subset is not None: + mixer_kwargs['mixer_subset'] = mixer_subset + hidden_states = self.mixer(hidden_states, **mixer_kwargs) + if mixer_subset is not None: + residual = residual[:, mixer_subset] + if not isinstance(self.mlp, nn.Identity): + dropped = self.drop_path2(self.dropout2(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + else: + assert residual is None + mixer_out = self.mixer( + hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {}) + ) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + + hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out)) + + hidden_states).to(dtype=self.norm1.weight.dtype)) + + if not isinstance(self.mlp, nn.Identity): + mlp_out = self.mlp(hidden_states) + if self.return_residual: # mlp out is actually a pair here + mlp_out, hidden_states = mlp_out + + hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out)) + + hidden_states).to(dtype=self.norm2.weight.dtype)) + + return hidden_states + +def create_mixer_cls(layer=None, + attn_layer_idx=None, attn_cfg=None, layer_idx=None, + device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + if attn_layer_idx is not None and layer_idx in attn_layer_idx: + causal = True if attn_cfg is None else attn_cfg.pop('causal', True) + + mha_cls = MHA + + mixer_cls = partial(mha_cls, causal=causal, layer_idx=layer_idx, + **(attn_cfg if attn_cfg is not None else {}),**factory_kwargs) + else: + mixer_cls = instantiate(registry.layer, layer, partial=True, layer_idx=layer_idx, **factory_kwargs) + return mixer_cls + + +def create_mlp_cls(d_model, d_inner=None, device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + inner_dim = d_inner if d_inner is not None else 4 * d_model + + mlp_cls = partial(Mlp, hidden_features=inner_dim, + activation=partial(F.gelu, approximate='tanh'), **factory_kwargs) + + return mlp_cls + + +def create_block(d_model, d_inner=None, + layer=None, attn_layer_idx=None, + attn_cfg=None, layer_norm_epsilon=1e-5, + resid_dropout1=0.0, resid_dropout2=0.0, residual_in_fp32=False, + layer_idx=None, + device=None, dtype=None): + factory_kwargs = {'device': device, 'dtype': dtype} + mixer_cls = create_mixer_cls(layer=layer, + attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, layer_idx=layer_idx, + **factory_kwargs) + mlp_cls = create_mlp_cls(d_model, d_inner=d_inner, + **factory_kwargs) + norm_cls = partial(nn.LayerNorm, eps=layer_norm_epsilon, **factory_kwargs) + block = Block(d_model, mixer_cls, mlp_cls, norm_cls=norm_cls, + prenorm=True, resid_dropout1=resid_dropout1, resid_dropout2=resid_dropout2,residual_in_fp32=residual_in_fp32) + block.layer_idx = layer_idx + return block + + +# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 +def _init_weights(module, n_layer, 
initializer_range=0.02, rescale_prenorm_residual=True, + glu_act=False): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, std=initializer_range) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.normal_(module.weight, std=initializer_range) + + if rescale_prenorm_residual: + # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: + # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale + # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. + # > -- GPT-2 :: https://openai.com/blog/better-language-models/ + # + # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py + for name, p in module.named_parameters(): + if name in ["out_proj.weight", "fc2.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) + # If using GLU activation for now, we scale the std by 2 + elif name in ["output_linear.0.weight"]: + # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block + if not glu_act: + nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) + else: + out_features = p.shape[0] + # Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5 + # on average. + nn.init.normal_(p[:out_features // 2], mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) * 2) + + +class LMBackbone(nn.Module): + + def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, + process_group=None, layer=None, + attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, + resid_dropout: float = 0.0, embed_dropout: float = 0.1, + layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False, + device=None, dtype=None, **kwargs) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.process_group = process_group + self.residual_in_fp32 = residual_in_fp32 + self.embeddings = GPT2Embeddings(d_model, vocab_size, max_position_embeddings, + **factory_kwargs) + + + self.layers = nn.ModuleList([create_block( + d_model, d_inner=d_inner, + layer=layer, attn_layer_idx=attn_layer_idx, + attn_cfg=attn_cfg, layer_norm_epsilon=layer_norm_epsilon, + resid_dropout1=embed_dropout if i == 0 else resid_dropout, + resid_dropout2=resid_dropout, residual_in_fp32=residual_in_fp32,layer_idx=i, + **factory_kwargs, + ) for i in range(n_layer)]) + + self.drop_f = nn.Dropout(resid_dropout) + self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs) + + self.apply(partial(_init_weights, n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}))) + + def forward(self, input_ids, position_ids=None): + hidden_states = self.embeddings(input_ids, position_ids=position_ids,) + residual = None + + for layer in self.layers: + hidden_states, residual = layer(hidden_states, residual) + + dropped = self.drop_f(hidden_states) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) + + return hidden_states + + +class SimpleLMHeadModel(nn.Module): + + def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, + layer=None, + attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, + resid_dropout: float = 0.0, 
embed_dropout: float = 0.1, + layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False, + pad_vocab_size_multiple: int = 1, + device=None, dtype=None, **kwargs) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + if vocab_size % pad_vocab_size_multiple != 0: + vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) + self.backbone = LMBackbone( + d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size, + layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, + max_position_embeddings=max_position_embeddings, + resid_dropout=resid_dropout, embed_dropout=embed_dropout, + layer_norm_epsilon=layer_norm_epsilon, + initializer_cfg=initializer_cfg, residual_in_fp32=residual_in_fp32, + **factory_kwargs, **kwargs + ) + self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) + + # Initialize weights and apply final processing + self.apply(partial(_init_weights, n_layer=n_layer, + **(initializer_cfg if initializer_cfg is not None else {}))) + self.tie_weights() + + def tie_weights(self): + self.lm_head.weight = self.backbone.embeddings.word_embeddings.weight + + def forward(self, input_ids, position_ids=None, state=None): # state for the repo interface + hidden_states = self.backbone(input_ids, position_ids=position_ids) + lm_logits = self.lm_head(hidden_states) + CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) + return CausalLMOutput(logits=lm_logits), None diff --git a/src/clm/src/models/sequence/ssm/dplr.py b/src/clm/src/models/sequence/ssm/dplr.py new file mode 100644 index 00000000..a817cc07 --- /dev/null +++ b/src/clm/src/models/sequence/ssm/dplr.py @@ -0,0 +1,107 @@ +# Copied from https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/dplr.py + +"""Initializations of structured state space models""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +from clm.src.models.sequence.ssm import hippo + + +def dplr(scaling='linear', N=64, rank=1, H=1, dtype=torch.float, real_scale=1.0, imag_scale=1.0, random_real=False, random_imag=False, normalize=False, diagonal=True, random_B=False): + assert dtype == torch.float or dtype == torch.double + dtype = torch.cfloat if dtype == torch.float else torch.cdouble + + pi = torch.tensor(math.pi) + if random_real: + real_part = torch.rand(H, N//2) + else: + real_part = .5 * torch.ones(H, N//2) + if random_imag: + imag_part = N//2 * torch.rand(H, N//2) + else: + imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H) + + real_part = real_scale * real_part + if scaling == 'random': + imag_part = torch.randn(H, N//2) + elif scaling == 'real': + imag_part = 0 * imag_part + real_part = 1 + repeat(torch.arange(N//2), 'n -> h n', h=H) + elif scaling in ['linear', 'lin']: + imag_part = pi * imag_part + elif scaling in ['inverse', 'inv']: # Based on asymptotics of the default HiPPO matrix + imag_part = 1/pi * N * (N/(1+2*imag_part)-1) + elif scaling in ['inverse2', 'inv2']: + imag_part = 1/pi * N * (N/(1+imag_part)-1) + elif scaling in ['quadratic', 'quad']: + imag_part = 1/pi * (1+2*imag_part)**2 + elif scaling in ['legs', 'hippo']: + w, _, _, _ = hippo.nplr('legsd', N) + imag_part = w.imag + + else: raise NotImplementedError + imag_part = imag_scale * imag_part + w = -real_part + 1j * imag_part + + # Initialize B + if random_B: + B = torch.randn(H, N//2, dtype=dtype) + else: + B = torch.ones(H, N//2, 
dtype=dtype) + + if normalize: + norm = -B/w # (H, N) # Result if you integrate the kernel with constant 1 function + zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True) # Variance with a random C vector + B = B / zeta**.5 + + P = torch.randn(rank, H, N//2, dtype=dtype) + if diagonal: P = P * 0.0 + V = torch.eye(N, dtype=dtype)[:, :N//2] # Only used in testing + V = repeat(V, 'n m -> h n m', h=H) + + return w, P, B, V + +def ssm(measure, N, R, H, **ssm_args): + """Dispatcher to create single SSM initialization + N: state size + R: rank (for DPLR parameterization) + H: number of independent SSM copies + """ + + if measure == "dplr": + w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args) + elif measure.startswith("diag"): + args = measure.split("-") + assert args[0] == "diag" and len(args) > 1 + scaling = args[1] + w, P, B, V = dplr(scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args) + else: + w, P, B, V = hippo.nplr(measure, N, R, **ssm_args) + w = repeat(w, 'n -> s n', s=H) + P = repeat(P, 'r n -> r s n', s=H) + B = repeat(B, 'n -> s n', s=H) + V = repeat(V, 'n m -> s n m', s=H) + return w, P, B, V + +combinations = { + 'hippo': ['legs', 'fourier'], + 'diag': ['diag-inv', 'diag-lin'], + 'all': ['legs', 'fourier', 'diag-inv', 'diag-lin'], +} + +def combination(measures, N, R, S, **ssm_args): + if isinstance(measures, str): + measures = combinations[measures] if measures in combinations else [measures] + + assert S % len(measures) == 0, f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures" + w, P, B, V = zip( + *[ssm(measure, N, R, S // len(measures), **ssm_args) for measure in measures] + ) + w = torch.cat(w, dim=0) # (S N) + P = torch.cat(P, dim=1) # (R S N) + B = torch.cat(B, dim=0) # (S N) + V = torch.cat(V, dim=0) # (S N N) + return w, P, B, V diff --git a/src/clm/src/models/sequence/ssm/hippo.py b/src/clm/src/models/sequence/ssm/hippo.py new file mode 100644 index 00000000..07707b65 --- /dev/null +++ b/src/clm/src/models/sequence/ssm/hippo.py @@ -0,0 +1,259 @@ +# Copied from https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/hippo/hippo.py + +""" Definitions of A and B matrices for various HiPPO operators. """ + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from scipy import special as ss +from einops import rearrange, repeat +from opt_einsum import contract + +def embed_c2r(A): + A = rearrange(A, '... m n -> ... 
m () n ()') + A = np.pad(A, ((0, 0), (0, 1), (0, 0), (0, 1))) + \ + np.pad(A, ((0, 0), (1, 0), (0, 0), (1,0))) + return rearrange(A, 'm x n y -> (m x) (n y)') + +# TODO take in 'torch' option to return torch instead of numpy, and converts the shape of B from (N, 1) to (N) +def transition(measure, N, **measure_args): + """ A, B transition matrices for different measures + measure: the type of measure + legt - Legendre (translated) + legs - Legendre (scaled) + glagt - generalized Laguerre (translated) + lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization + """ + # Laguerre (translated) + if measure == 'lagt': + b = measure_args.get('beta', 1.0) + A = np.eye(N) / 2 - np.tril(np.ones((N, N))) + B = b * np.ones((N, 1)) + # Generalized Laguerre + # alpha 0, beta small is most stable (limits to the 'lagt' measure) + # alpha 0, beta 1 has transition matrix A = [lower triangular 1] + elif measure == 'glagt': + alpha = measure_args.get('alpha', 0.0) + beta = measure_args.get('beta', 0.01) + A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) + B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] + + L = np.exp(.5 * (ss.gammaln(np.arange(N)+alpha+1) - ss.gammaln(np.arange(N)+1))) + A = (1./L[:, None]) * A * L[None, :] + B = (1./L[:, None]) * B * np.exp(-.5 * ss.gammaln(1-alpha)) * beta**((1-alpha)/2) + # Legendre (translated) + elif measure == 'legt': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1) ** .5 + j, i = np.meshgrid(Q, Q) + A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :] + B = R[:, None] + A = -A + + # Halve again for timescale correctness + A *= 0.5 + B *= 0.5 + # LMU: equivalent to LegT up to normalization + elif measure == 'lmu': + Q = np.arange(N, dtype=np.float64) + R = (2*Q + 1)[:, None] # / theta + j, i = np.meshgrid(Q, Q) + A = np.where(i < j, -1, (-1.)**(i-j+1)) * R + B = (-1.)**Q[:, None] * R + # Legendre (scaled) + elif measure == 'legs': + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) + elif measure == 'legsd': + q = np.arange(N, dtype=np.float64) + col, row = np.meshgrid(q, q) + r = 2 * q + 1 + M = -(np.where(row >= col, r, 0) - np.diag(q)) + T = np.sqrt(np.diag(2 * q + 1)) + A = T @ M @ np.linalg.inv(T) + B = np.diag(T)[:, None] + B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." 
after torch.as_tensor(B) + A += .5 * B*B[None, :, 0] + B = B / 2.0 + elif measure in ['fourier_diag', 'foud']: + freqs = np.arange(N//2) + d = np.stack([freqs, np.zeros(N//2)], axis=-1).reshape(-1)[:-1] + A = 2*np.pi*(-np.diag(d, 1) + np.diag(d, -1)) + A = A - .5 * np.eye(N) + B = np.zeros(N) + B[0::2] = 2**.5 + B[0] = 1 + B = B[:, None] + elif measure in ['fourier', 'fout']: + freqs = np.arange(N//2) + d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:] + A = np.pi*(-np.diag(d, 1) + np.diag(d, -1)) + B = np.zeros(N) + B[0::2] = 2**.5 + B[0] = 1 + + # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case + A = A - B[:, None] * B[None, :] + B = B[:, None] + elif measure == 'fourier_decay': + freqs = np.arange(N//2) + d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:] + A = np.pi*(-np.diag(d, 1) + np.diag(d, -1)) + B = np.zeros(N) + B[0::2] = 2**.5 + B[0] = 1 + + # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case + A = A - .5 * B[:, None] * B[None, :] + B = .5 * B[:, None] + elif measure == 'fourier2': # Double everything: orthonormal on [0, 1] + freqs = 2*np.arange(N//2) + d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:] + A = np.pi*(-np.diag(d, 1) + np.diag(d, -1)) + B = np.zeros(N) + B[0::2] = 2**.5 + B[0] = 1 + + # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case + A = A - B[:, None] * B[None, :] * 2 + B = B[:, None] * 2 + elif measure == 'random': + A = np.random.randn(N, N) / N + B = np.random.randn(N, 1) + elif measure == 'diagonal': + A = -np.diag(np.exp(np.random.randn(N))) + B = np.random.randn(N, 1) + else: + raise NotImplementedError + + return A, B + +def rank_correction(measure, N, rank=1, dtype=torch.float): + """ Return low-rank matrix L such that A + L is normal """ + + if measure == 'legs': + assert rank >= 1 + P = torch.sqrt(.5+torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) + elif measure == 'legt': + assert rank >= 2 + P = torch.sqrt(1+2*torch.arange(N, dtype=dtype)) # (N) + P0 = P.clone() + P0[0::2] = 0. + P1 = P.clone() + P1[1::2] = 0. + P = torch.stack([P0, P1], dim=0) # (2 N) + P *= 2**(-0.5) # Halve the rank correct just like the original matrix was halved + elif measure == 'lagt': + assert rank >= 1 + P = .5**.5 * torch.ones(1, N, dtype=dtype) + elif measure in ['fourier', 'fout']: + P = torch.zeros(N) + P[0::2] = 2**.5 + P[0] = 1 + P = P.unsqueeze(0) + elif measure == 'fourier_decay': + P = torch.zeros(N) + P[0::2] = 2**.5 + P[0] = 1 + P = P.unsqueeze(0) + P = P / 2**.5 + elif measure == 'fourier2': + P = torch.zeros(N) + P[0::2] = 2**.5 + P[0] = 1 + P = 2**.5 * P.unsqueeze(0) + elif measure in ['fourier_diag', 'foud', 'legsd']: + P = torch.zeros(1, N, dtype=dtype) + else: raise NotImplementedError + + d = P.size(0) + if rank > d: + P = torch.cat([P, torch.zeros(rank-d, N, dtype=dtype)], dim=0) # (rank N) + return P + +def initial_C(measure, N, dtype=torch.float): + """ Return C that captures the other endpoint in the HiPPO approximation """ + + if measure == 'legt': + C = (torch.arange(N, dtype=dtype)*2+1)**.5 * (-1)**torch.arange(N) + elif measure == 'fourier': + C = torch.zeros(N) + C[0::2] = 2**.5 + C[0] = 1 + else: + C = torch.zeros(N, dtype=dtype) # (N) + + return C + + +def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True): + """ Return w, p, q, V, B such that + (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V + i.e. 
A = V[w - p q^*]V^*, B = V B + """ + assert dtype == torch.float or dtype == torch.double + cdtype = torch.cfloat if dtype == torch.float else torch.cdouble + + A, B = transition(measure, N) + A = torch.as_tensor(A, dtype=dtype) # (N, N) + B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) + + P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N) + AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3) + + # We require AP to be nearly skew-symmetric + _A = AP + AP.transpose(-1, -2) + if (err := torch.sum((_A - _A[0,0]*torch.eye(N))**2) / N) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5): + print("WARNING: HiPPO matrix not skew symmetric", err) + + + # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately + # Imaginary part can use eigh instead of eig + w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True) + + # Diagonalize in double precision + if diagonalize_precision: AP = AP.to(torch.double) + # w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) + w_im, V = torch.linalg.eigh(AP*-1j) # (..., N) (..., N, N) + if diagonalize_precision: w_im, V = w_im.to(cdtype), V.to(cdtype) + w = w_re + 1j * w_im + # Check: V w V^{-1} = A + # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) + + + # Only keep half of each conjugate pair + _, idx = torch.sort(w.imag) + w_sorted = w[idx] + V_sorted = V[:, idx] + + # There is an edge case when eigenvalues can be 0, which requires some machinery to handle + # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case) + V = V_sorted[:, :N//2] + w = w_sorted[:N//2] + assert w[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A" + if w[-1].abs() < 1e-4: + V[:, -1] = 0. + V[0, -1] = 2**-0.5 + V[1, -1] = 2**-0.5 * 1j + + _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2) + if ((err := torch.sum((2*_AP.real-AP)**2)/N) > 1e-5): + print("Warning: Diagonalization of A matrix not numerically precise - error", err) + # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) + + V_inv = V.conj().transpose(-1, -2) + + # C = initial_C(measure, N, dtype=dtype) + B = contract('ij, j -> i', V_inv, B.to(V)) # V^* B + # C = contract('ij, j -> i', V_inv, C.to(V)) # V^* C + P = contract('ij, ...j -> ...i', V_inv, P.to(V)) # V^* P + + # return w, P, B, C, V + return w, P, B, V diff --git a/src/clm/src/models/sequence/ssm/s4_simple.py b/src/clm/src/models/sequence/ssm/s4_simple.py new file mode 100644 index 00000000..2176b481 --- /dev/null +++ b/src/clm/src/models/sequence/ssm/s4_simple.py @@ -0,0 +1,262 @@ +import torch +import torch.nn as nn +from clm.src.models.nn import LinearActivation, Activation, DropoutNd +from einops import rearrange, repeat +import opt_einsum as oe + +import math +class OurModule(nn.Module): + def __init__(self): super().__init__() + + def register(self, name, tensor, trainable=False, lr=None, wd=None): + """Utility method: register a tensor as a buffer or trainable parameter""" + + if trainable: + self.register_parameter(name, nn.Parameter(tensor)) + else: + self.register_buffer(name, tensor) + + optim = {} + if trainable and lr is not None: optim["lr"] = lr + if trainable and wd is not None: optim["weight_decay"] = wd + if len(optim) > 0: setattr(getattr(self, name), "_optim", optim) + +# +# This is intended to match np.convolve(x,w)[:len(w)] +# That is, (u \ast v)[k] = sum_{j} u[k-j]v[j] +# Here y = (u \ask v) on return. 
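+# The product is taken in the frequency domain: both inputs are zero-padded to
+# length 2L before the FFT, so the circular convolution computed by rfft/irfft
+# reduces to the desired linear (causal) convolution once the result is
+# truncated back to the first L samples.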
+# We assume the inputs are: +# u (B H L) +# v (C H L) +# and we want to produce y that is (B C H L) +# + + +def fft_conv(u,v): + L = u.shape[-1] + u_f = torch.fft.rfft(u, n=2*L) # (B H L) + v_f = torch.fft.rfft(v, n=2*L) # (C H L) + + y_f = oe.contract('bhl,chl->bchl', u_f, v_f) + y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L) + return y + +def normalize_param(a, method, norm_const=None): + if method == "l1": + if norm_const is not None: + return a/((1+norm_const)*torch.linalg.norm(a,ord=1,dim=2).unsqueeze(2)) + return a/torch.linalg.norm(a,ord=1,dim=2).unsqueeze(2) + if method == "l2": + return a/torch.linalg.norm(a,ord=2,dim=2).unsqueeze(2) + if method == "max": + return 0.1*a/torch.max(a,dim=2)[0].unsqueeze(2) + if method == "none": + return a + raise ValueError(f"{method} normalization not implemented") + +class SimpleS4(OurModule): + def __init__(self, + nHippos, + d_state=64, + channels=1, + use_initial=True, # Use the initial state? + zero_order_hold=False, # Use zero-order hold approximation + trap_rule=True, + dt_min=0.001, + dt_max=0.1, + lr=None, # Hook to set LR of SSM parameters differently + learn_a=True, + learn_theta=True, + learn_dt=False, # whether to learn separate dt for each hippo + theta_scale=False, + skip_connection=True, + repr='cont', # representation to use: ['cont','disc','comp'] + param_norm = 'none', # for normalizing parameters for stability + **kernel_args,): # Use the trapezoid rule + super().__init__() + # H is number of hippos + # D is the dimension (also shockingly n other places) + # B is the batch + # L is the length + self.h = nHippos + self.d = d_state // 2 + self.channels = channels + self.use_initial = use_initial + self.zero_order_hold = zero_order_hold + # + # Use the trapezoid rule correct or just do zero-order hold. + self.trap_rule = trap_rule + self.repr = repr + self.learn_dt = learn_dt + self.shift = 'shift' in self.repr + self.param_norm = param_norm + + _fp = (self.channels, self.h, self.d) + + # Chebyshev initialization + h_scale = torch.exp(torch.arange(self.h)/self.h * math.log(dt_max/dt_min)) + angles = torch.arange(self.d)*torch.pi + t_scale = h_scale if theta_scale else torch.ones(self.h) + theta = oe.contract('c,h,d->chd', torch.ones(self.channels), t_scale, angles) + if self.repr == 'disc': + # discrete diagonal representation + a = torch.randn(*_fp).abs() + #a = 2*torch.rand(*_fp)-1 # init randomly from [-1,1] + else: + # default continuous diagonal representation + a = -repeat(h_scale, 'h -> c h d', c=self.channels, d=self.d) + + self.register("theta", theta,learn_theta,lr=lr, wd=None) + self.register("a", a, learn_a,lr=lr, wd=None) + + if self.learn_dt: + log_dt = torch.rand(self.h) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) + self.register("log_dt", log_dt, True,lr=lr, wd=None) + + # The other maps + if not skip_connection: + self.register("D", torch.zeros((channels, self.h)), False) + else: + self.D = nn.Parameter(torch.randn(channels, self.h)) + + if use_initial or 'comp' in self.repr: + if self.shift: + b = torch.zeros(*_fp) + b[:,:,0] = 1 + self.register("b", b, False) + else: + self.b = nn.Parameter(torch.randn(*_fp)) + self.c = nn.Parameter(torch.randn(*_fp)) + self.x0 = nn.Parameter(torch.randn(*_fp)) + else: + # This is an optimization that we combine q = c * b + # It's as if we're setting x0 = 0. 
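+            # q is then used directly in place of the elementwise product c*b when
+            # building the kernel below, so the initial-state term is dropped entirely.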
+ self.q = nn.Parameter(torch.randn(*_fp)) + + + def quadrature_method(self, u, horizon): + # The input is now Batch x Hippos x Length + l = u.size(-1) + + dt = 1/(l-1) # the step size + if self.learn_dt: + dt = torch.exp(self.log_dt).view(1,-1,1, 1) + + # q and a are both C x H x D + # zk is of length l we want a C x H x L matrix + zk = dt*torch.arange(l, device=u.device).view(1,1,-1,1) + + if self.repr == 'disc': + # discrete diagonal representation + a_ = (self.a).abs() + base_term = 2 * dt * torch.pow(a_.unsqueeze(2), zk) * torch.cos(self.theta.unsqueeze(2) * zk) + else: + # continuous diagonal representation + a_ = self.a #/torch.linalg.norm(self.a,ord=1,dim=2).unsqueeze(2) + a_ = -a_.abs() + # a_ = -self.a.abs() + base_term = 2*dt*torch.exp(a_.unsqueeze(2) * zk)*torch.cos( self.theta.unsqueeze(2) * zk) + + q = self.b*self.c if self.use_initial else self.q + f = (q.unsqueeze(2)*base_term).sum(-1) + + y = fft_conv(u,f) + # Add in the skip connection with per-channel D matrix + y = y + oe.contract('bhl,ch->bchl', u, self.D) + # Add back the initial state + if self.use_initial: + y = y + (2*(self.c*self.x0).unsqueeze(2)*base_term).sum(-1) + + return rearrange(y, 'b c h l-> b (c h) l'), None # flatten the channels. + + def forward(self, u, horizon=None): + return self.quadrature_method(u, horizon) + + +# Below here are standard wrapper classes to handle +# (1) Non-linearity +# (2) Integration with the Hippo Code base +class NonLinear(nn.Module): + def __init__(self, h, channels, + ln=False, # Extra normalization + transposed=True, + dropout=0.0, + postact=None, # activation after FF + activation='gelu', # activation in between SS and FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + ): + super().__init__() + dropout_fn = DropoutNd # nn.Dropout2d bugged in PyTorch 1.11 + dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + #norm = Normalization(h*channels, transposed=transposed) if ln else nn.Identity() + + activation_fn = Activation(activation) + + output_linear = LinearActivation( + h*channels, + h, + transposed=transposed, + initializer=initializer, + activation=postact, + activate=True, + weight_norm=weight_norm, + ) + #self.f = nn.Sequential(activation_fn, dropout, norm, output_linear) + self.f = nn.Sequential(activation_fn, dropout, output_linear) + def forward(self,x): # Always (B H L) + return self.f(x) + +class SimpleS4Wrapper(nn.Module): + def __init__( + self, + d_model, + d_state=64, + channels=1, + bidirectional=False, + dropout=0.0, + transposed=True, # axis ordering (B, L, D) or (B, D, L) + ln=True, # IGNORED: Extra normalization + postact=None, # activation after FF + activation='gelu', # activation in between SS and FF + initializer=None, # initializer on FF + weight_norm=False, # weight normalization on FF + linear=False, + # SSM Kernel arguments + **kernel_args, + ): + super().__init__() + self.h = d_model + self.d = d_state + self.channels = channels + #self.shift = shift + #self.linear = linear + self.out_d = self.h + self.transposed = transposed + self.bidirectional = bidirectional + assert not bidirectional, f"Bidirectional NYI" + self.s4 = SimpleS4(nHippos=d_model, d_state=d_state, + channels=channels, **kernel_args) + # the mapping + # We transpose if it's not in the forward. 
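+        # NonLinear applies activation -> dropout -> an output projection (a 1x1 Conv1d
+        # here since transposed=True), mapping the flattened (channels * h) features back to h.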
+ nl = NonLinear(self.h, channels=self.channels, ln=ln, # Extra normalization + dropout=dropout, postact=postact, activation=activation, transposed=True, + initializer=initializer, weight_norm=weight_norm) + self.out = nn.Identity() if linear else nl + + def forward(self, u, *w, state=None, horizon=None): + # u: (B H L) if self.transposed else (B L H) + if not self.transposed: u = u.transpose(-1, -2) + # We only pass BHL, and it is as if transposed is True. + y, state = self.s4(u,horizon=horizon) + ret = self.out(y) + if not self.transposed: ret = ret.transpose(-1, -2) + return ret, state + + @property + def d_state(self): return self.h * self.d + + @property + def d_output(self): return self.out_d \ No newline at end of file diff --git a/src/clm/src/models/sequence/ssm/s4d.py b/src/clm/src/models/sequence/ssm/s4d.py new file mode 100644 index 00000000..643e1a55 --- /dev/null +++ b/src/clm/src/models/sequence/ssm/s4d.py @@ -0,0 +1,404 @@ +""" Standalone version of Structured (Sequence) State Space (S4) model. """ + + +import logging +from functools import partial +import math +import numpy as np +from scipy import special as ss +import torch +import torch.nn as nn +import torch.nn.functional as F +from pytorch_lightning.utilities import rank_zero_only +from einops import rearrange, repeat +import opt_einsum as oe + +contract = oe.contract +contract_expression = oe.contract_expression + + +_c2r = torch.view_as_real +_r2c = torch.view_as_complex +if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10): + _resolve_conj = lambda x: x.conj().resolve_conj() +else: + _resolve_conj = lambda x: x.conj() + + + +""" simple nn.Module components """ + +def Activation(activation=None, dim=-1): + if activation in [ None, 'id', 'identity', 'linear' ]: + return nn.Identity() + elif activation == 'tanh': + return nn.Tanh() + elif activation == 'relu': + return nn.ReLU() + elif activation == 'gelu': + return nn.GELU() + elif activation in ['swish', 'silu']: + return nn.SiLU() + elif activation == 'glu': + return nn.GLU(dim=dim) + elif activation == 'sigmoid': + return nn.Sigmoid() + else: + raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) + +def LinearActivation( + d_input, d_output, bias=True, + transposed=False, + activation=None, + activate=False, # Apply activation as part of this module + **kwargs, + ): + """ Returns a linear nn.Module with control over axes order, initialization, and activation """ + + # Construct core module + linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear + if activation == 'glu': d_output *= 2 + linear = linear_cls(d_input, d_output, bias=bias, **kwargs) + + if activate and activation is not None: + activation = Activation(activation, dim=-2 if transposed else -1) + linear = nn.Sequential(linear, activation) + return linear + + +""" HiPPO utilities """ + +def random_dplr(N, H=1, scaling='inverse', real_scale=1.0, imag_scale=1.0): + dtype = torch.cfloat + + pi = torch.tensor(np.pi) + real_part = .5 * torch.ones(H, N//2) + imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H) + + real_part = real_scale * real_part + if scaling == 'random': + imag_part = torch.randn(H, N//2) + elif scaling == 'linear': + imag_part = pi * imag_part + elif scaling == 'inverse': # Based on asymptotics of the default HiPPO matrix + imag_part = 1/pi * N * (N/(1+2*imag_part)-1) + else: raise NotImplementedError + imag_part = imag_scale * imag_part + w = -real_part + 1j * imag_part + + + B = torch.randn(H, N//2, dtype=dtype) + + 
norm = -B/w # (H, N) # Result if you integrate the kernel with constant 1 function + zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True) # Variance with a random C vector + B = B / zeta**.5 + + return w, B + + +class SSKernelDiag(nn.Module): + """ Version using (complex) diagonal state matrix. Note that it is slower and less memory efficient than the NPLR kernel because of lack of kernel support. + + """ + + def __init__( + self, + w, C, log_dt, + lr=None, + train_w = True, + train_dt = True, + **kwargs # For compatibility with other kernels + ): + + super().__init__() + + # Rank of low-rank correction + assert w.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = w.size(-1) + assert self.H % w.size(0) == 0 + self.copies = self.H // w.size(0) + + # Broadcast everything to correct shapes + C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (H, C, N) + + # Register parameters + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + self.register("log_dt", log_dt, train_dt, lr, 0.0) + + log_w_real = torch.log(-w.real + 1e-4) + w_imag = w.imag + self.register("log_w_real", log_w_real, train_w, lr, 0.0) + self.register("w_imag", w_imag, train_w, lr, 0.0) + + + def _w(self): + # Get the internal w (diagonal) parameter + w_real = -torch.exp(self.log_w_real) + w_imag = self.w_imag + w = w_real + 1j * w_imag + w = repeat(w, 't n -> (v t) n', v=self.copies) # (H N) + return w + + def forward(self, L): + """ + returns: (..., c, L) where c is number of channels (default 1) + """ + + dt = torch.exp(self.log_dt) # (H) + C = _r2c(self.C) # (C H N) + w = self._w() # (H N) + + # Incorporate dt into A + dtA = w * dt.unsqueeze(-1) # (H N) + + # Power up + K = dtA.unsqueeze(-1) * torch.arange(L, device=w.device) # (H N L) + C = C * (torch.exp(dtA)-1.) / w + K = contract('chn, hnl -> chl', C, torch.exp(K)) + K = 2*K.real + # Keops version is more memory efficient + # C = C * (torch.exp(dtA)-1.) / w + # K = log_vandermonde(C, dtA, L) # (H L) + + return K + + def setup_step(self): + dt = torch.exp(self.log_dt) # (H) + C = _r2c(self.C) # (C H N) + w = self._w() # (H N) + + # Incorporate dt into A + dtA = w * dt.unsqueeze(-1) # (H N) + self.dA = torch.exp(dtA) # (H N) + self.dC = C * (torch.exp(dtA)-1.) 
/ w # (C H N) + self.dB = self.dC.new_ones(self.H, self.N) # (H N) + + def default_state(self, *batch_shape): + C = _r2c(self.C) + state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device) + return state + + def step(self, u, state): + next_state = contract("h n, b h n -> b h n", self.dA, state) \ + + contract("h n, b h -> b h n", self.dB, u) + y = contract("c h n, b h n -> b c h", self.dC, next_state) + return 2*y.real, next_state + + + def register(self, name, tensor, trainable=False, lr=None, wd=None): + """Utility method: register a tensor as a buffer or trainable parameter""" + + if trainable: + self.register_parameter(name, nn.Parameter(tensor)) + else: + self.register_buffer(name, tensor) + + optim = {} + if trainable and lr is not None: + optim["lr"] = lr + if trainable and wd is not None: + optim["weight_decay"] = wd + if len(optim) > 0: + setattr(getattr(self, name), "_optim", optim) + +class S4DKernel(nn.Module): + """Wrapper around SSKernelDiag that generates the diagonal SSM parameters + """ + + def __init__( + self, + H, + N=64, + scaling="inverse", + channels=1, # 1-dim to C-dim map; can think of C as having separate "heads" + dt_min=0.001, + dt_max=0.1, + lr=None, # Hook to set LR of SSM parameters differently + n_ssm=1, # Copies of the ODE parameters A and B. Must divide H + **kernel_args, + ): + super().__init__() + self.N = N + self.H = H + dtype = torch.float + cdtype = torch.cfloat + self.channels = channels + self.n_ssm = n_ssm + + # Generate dt + log_dt = torch.rand(self.H, dtype=dtype) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) + + # Compute the preprocessed representation + # Generate low rank correction p for the measure + w, B = random_dplr(self.N, H=n_ssm, scaling=scaling) + + C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) + + # Broadcast tensors to n_ssm copies + # These will be the parameters, so make sure tensors are materialized and contiguous + B = repeat(B, 't n -> (v t) n', v=self.n_ssm // B.size(-2)).clone().contiguous() + w = repeat(w, 't n -> (v t) n', v=self.n_ssm // w.size(-2)).clone().contiguous() + + # Combine B and C using structure of diagonal SSM + C = C * repeat(B, 't n -> (v t) n', v=H//self.n_ssm) + self.kernel = SSKernelDiag( + w, C, log_dt, + lr=lr, + **kernel_args, + ) + + def forward(self, L=None): + k = self.kernel(L=L) + return k.float() + + def setup_step(self): + self.kernel.setup_step() + + def step(self, u, state, **kwargs): + u, state = self.kernel.step(u, state, **kwargs) + return u.float(), state + + def default_state(self, *args, **kwargs): + return self.kernel.default_state(*args, **kwargs) + + +class S4D(nn.Module): + + def __init__( + self, + d_model, + d_state=64, + channels=1, # maps 1-dim to C-dim + bidirectional=False, + # Arguments for FF + activation='gelu', # activation in between SS and FF + postact=None, # activation after FF + dropout=0.0, + transposed=True, # axis ordering (B, L, D) or (B, D, L) + return_state=True, # return state in addition to output + # SSM Kernel arguments + **kernel_args, + ): + """ + d_state: the dimension of the state, also denoted by N + channels: can be interpreted as a number of "heads" + bidirectional: bidirectional + dropout: standard dropout argument + transposed: choose backbone axis ordering of (B, L, H) or (B, H, L) [B=batch size, L=sequence length, H=hidden dimension] + + Other options are all experimental and should not need to be configured + """ + + super().__init__() + + self.h = d_model + self.n = d_state + 
self.bidirectional = bidirectional + self.channels = channels + self.transposed = transposed + self.return_state = return_state + + self.D = nn.Parameter(torch.randn(channels, self.h)) + + if self.bidirectional: + channels *= 2 + + # SSM Kernel + self.kernel = S4DKernel(self.h, N=self.n, channels=channels, **kernel_args) + + # Pointwise + self.activation = Activation(activation) + dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout + self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() + + # position-wise output transform to mix features + self.output_linear = LinearActivation( + self.h*self.channels, + self.h, + transposed=self.transposed, + activation=postact, + activate=True, + ) + + + def forward(self, u, **kwargs): # absorbs return_output and transformer src mask + """ + u: (B H L) if self.transposed else (B L H) + state: (H N) never needed unless you know what you're doing + + Returns: same shape as u + """ + if not self.transposed: u = u.transpose(-1, -2) + L = u.size(-1) + + # Compute SS Kernel + k = self.kernel(L=L) # (C H L) (B C H L) + + # Convolution + if self.bidirectional: + k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2) + k = F.pad(k0, (0, L)) \ + + F.pad(k1.flip(-1), (L, 0)) \ + + k_f = torch.fft.rfft(k, n=2*L) # (C H L) + u_f = torch.fft.rfft(u, n=2*L) # (B H L) + y_f = contract('bhl,chl->bchl', u_f, k_f) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) + y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L) + + + # Compute D term in state space equation - essentially a skip connection + y = y + contract('bhl,ch->bchl', u, self.D) # u.unsqueeze(-3) * self.D.unsqueeze(-1) + + # Reshape to flatten channels + y = rearrange(y, '... c h l -> ... (c h) l') + + y = self.dropout(self.activation(y)) + + if not self.transposed: y = y.transpose(-1, -2) + + y = self.output_linear(y) + + if self.return_state: + return y, None # Return a None to satisfy this repo's interface, but this can be modified + else: + return y + + def setup_step(self): + self.kernel.setup_step() + + def step(self, u, state): + """ Step one time step as a recurrent model. Intended to be used during validation. + + u: (B H) + state: (B H N) + Returns: output (B H), state (B H N) + """ + assert not self.training + + y, next_state = self.kernel.step(u, state) # (B C H) + y = y + u.unsqueeze(-2) * self.D + y = rearrange(y, '... c h -> ... (c h)') + y = self.activation(y) + if self.transposed: + y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) + else: + y = self.output_linear(y) + return y, next_state + + def default_state(self, *batch_shape, device=None): + return self.kernel.default_state(*batch_shape) + + @property + def d_state(self): + return self.h * self.n + + @property + def d_output(self): + return self.h + + @property + def state_to_tensor(self): + return lambda state: rearrange('... h n -> ... (h n)', state) diff --git a/src/clm/src/models/sequence/ssm/ss_kernel.py b/src/clm/src/models/sequence/ssm/ss_kernel.py new file mode 100644 index 00000000..b0079898 --- /dev/null +++ b/src/clm/src/models/sequence/ssm/ss_kernel.py @@ -0,0 +1,180 @@ +# TD: [2023-01-05]: Extracted the SSKernel class from +# https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/kernel.py +# We add option to use the shift kernel, and remove the option of SSKernelNPLR + +"""SSM convolution kernels. +SSKernel wraps different kernels with common options and handles the initialization. 
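+
+Only the kernels wired up below are available here: mode 'diag' builds an
+SSKernelDiag (S4D) kernel, 'shift' builds an SSKernelShift kernel, and 'ema'
+builds an EMAKernel; the NPLR kernel of full S4 is not included.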
+""" + +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat +from opt_einsum import contract + +from clm.src.models.sequence.ssm.ss_kernel_diag import SSKernelDiag, EMAKernel +from clm.src.models.sequence.ssm.ss_kernel_shift import SSKernelShift + +from clm.src.models.sequence.ssm import hippo +from clm.src.models.sequence.ssm import dplr +from clm.src.ops.krylov import power + +from clm.src.utils.train import get_logger + +log = get_logger(__name__) + + +_conj = lambda x: torch.cat([x, x.conj()], dim=-1) + + +class SSKernel(nn.Module): + """Wrapper around SSKernel parameterizations. + + The SSKernel is expected to support the interface + forward() + default_state() + _setup_step() + step() + """ + + def __init__( + self, + H, + N=64, + L=None, + measure="diag-lin", + rank=1, + channels=1, + dt_min=0.001, + dt_max=0.1, + deterministic=False, + lr=None, + mode="diag", + n_ssm=None, + verbose=False, + measure_args={}, + **kernel_args, + ): + """State Space Kernel which computes the convolution kernel $\\bar{K}$ + + H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config. + N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much. + L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known. + measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin) + rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt" + channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead + dt_min, dt_max: min and max values for the step size dt (\Delta) + mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing + n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H + lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. 
+ """ + super().__init__() + self.N = N + self.H = H + dtype, cdtype = torch.float, torch.cfloat + self.channels = channels + self.n_ssm = n_ssm if n_ssm is not None else H + self.mode = mode + self.verbose = verbose + self.kernel_args = kernel_args + + # Generate dt + if deterministic: + log_dt = torch.exp(torch.linspace(math.log(dt_min), math.log(dt_max), H)) + else: + log_dt = torch.rand(self.H, dtype=dtype) * ( + math.log(dt_max) - math.log(dt_min) + ) + math.log(dt_min) + + # Compute the preprocessed representation + if mode == "ema": + self.kernel = EMAKernel(H, N=N, channels=channels, **kernel_args) + else: + w, P, B, V = dplr.combination(measure, self.N, rank, self.n_ssm, **measure_args) + + # Broadcast C to have H channels + if deterministic: + C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype) + C[:, :, :1] = 1. + C = contract('hmn, chn -> chm', V.conj().transpose(-1, -2), C) # V^* C + C = repeat(C, 'c t n -> c (v t) n', v=self.n_ssm // C.size(-2)).clone().contiguous() + else: + C = torch.randn(channels, self.H, self.N//2, dtype=cdtype) + + # Broadcast other parameters to have n_ssm copies + assert self.n_ssm % B.size(-2) == 0 \ + and self.n_ssm % P.size(-2) == 0 \ + and self.n_ssm % w.size(-2) == 0 + # Broadcast tensors to n_ssm copies + # These will be the parameters, so make sure tensors are materialized and contiguous + B = repeat(B, 't n -> (v t) n', v=self.n_ssm // B.size(-2)).clone().contiguous() + P = repeat(P, 'r t n -> r (v t) n', v=self.n_ssm // P.size(-2)).clone().contiguous() + w = repeat(w, 't n -> (v t) n', v=self.n_ssm // w.size(-2)).clone().contiguous() + + if mode == "diag": + if not measure.startswith("diag"): + log.warning("Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv.") + C = C * repeat(B, 't n -> (v t) n', v=H//self.n_ssm) + self.kernel = SSKernelDiag( + w, B, C, log_dt, L=L, + lr=lr, + **kernel_args, + ) + elif mode == 'shift': + # Initializing B to be e_1 + B = torch.zeros(self.H, self.N) + B[..., 0] = 1.0 + # Match torch.Conv1d init + C = torch.randn(self.H, self.channels, self.N) + nn.init.kaiming_uniform_(C, a=math.sqrt(5)) + C = rearrange(C, 'h c n -> c h n') + self.kernel = SSKernelShift(B, C, L=L, lr=lr, **kernel_args) + else: + raise NotImplementedError(f"{mode=} is not valid") + + def forward(self, state=None, L=None, rate=None): + return self.kernel(state=state, L=L, rate=rate) + + @torch.no_grad() + def forward_state(self, u, state): + """ Forward the state through a sequence, i.e. 
computes the state after passing chunk through SSM + + state: (B, H, N) + u: (B, H, L) + + Returns: (B, H, N) + """ + + if hasattr(self.kernel, "forward_state"): + return self.kernel.forward_state(u, state) + + dA, dB = self.kernel._setup_state() # Construct dA, dB matrices + # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N) + + conj = state.size(-1) != dA.size(-1) + if conj: state = _conj(state) + + v = contract('h n, b h l -> b h n l', dB, u.flip(-1)) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2) + AL, v = power(u.size(-1), dA, v) + next_state = contract("h m n, b h n -> b h m", AL, state) + next_state = next_state + v + + if conj: next_state = next_state[..., : next_state.size(-1) // 2] + return next_state + + def _setup_step(self, **kwargs): + # This method is intended to be private so that setting up an S4 module with + # ``` + # if hasattr(module, 'setup_step'): module.setup_step() + # ``` + # will not trigger this method multiple times + self.kernel._setup_step(**kwargs) + + def step(self, u, state, **kwargs): + y, state = self.kernel.step(u, state, **kwargs) + return y, state + + def default_state(self, *args, **kwargs): + return self.kernel.default_state(*args, **kwargs) diff --git a/src/clm/src/models/sequence/ssm/ss_kernel_diag.py b/src/clm/src/models/sequence/ssm/ss_kernel_diag.py new file mode 100644 index 00000000..49ab0118 --- /dev/null +++ b/src/clm/src/models/sequence/ssm/ss_kernel_diag.py @@ -0,0 +1,331 @@ +# TD: [2023-01-05]: Extracted the SSKernelDiag class from +# https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/kernel.py +# We make a small change to use the log_vandermonde CUDA code. + +"""SSKernelDiag is the S4D kernel, a simpler algorithm for computing the kernel for the case of diagonal state matrices A. +""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat +from opt_einsum import contract + +from clm.src.utils.train import OptimModule + +from clm.src.utils.train import get_logger + +log = get_logger(__name__) + +# This could be None if the CUDA import fails +from clm.src.ops.vandermonde import log_vandermonde_fast +try: + import pykeops + from clm.src.ops.vandermonde import log_vandermonde, log_vandermonde_transpose + has_pykeops = True + log.info("Pykeops installation found.") +except ImportError: + has_pykeops = False + from clm.src.ops.vandermonde import log_vandermonde_naive as log_vandermonde + from clm.src.ops.vandermonde import log_vandermonde_transpose_naive as log_vandermonde_transpose + log.warning( + "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency." 
+ ) + + +_c2r = torch.view_as_real +_r2c = torch.view_as_complex + +if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10): + _resolve_conj = lambda x: x.conj().resolve_conj() +else: + _resolve_conj = lambda x: x.conj() + + +class SSKernelDiag(OptimModule): + """Version using (complex) diagonal state matrix (S4D)""" + + def __init__( + self, + A, B, C, log_dt, + L=None, + disc='bilinear', + real_type='exp', + lr=None, + bandlimit=None, + force_real=False, + ): + + super().__init__() + self.L = L + self.disc = disc + self.bandlimit = bandlimit + self.real_type = real_type + self.force_real = force_real + + # Rank of low-rank correction + assert A.size(-1) == C.size(-1) + self.H = log_dt.size(-1) + self.N = A.size(-1) + assert A.size(-2) == B.size(-2) # Number of independent SSMs trained + assert self.H % A.size(-2) == 0 + self.n_ssm = A.size(-2) + self.repeat = self.H // A.size(0) + + self.channels = C.shape[0] + self.C = nn.Parameter(_c2r(_resolve_conj(C))) + + # Register parameters + if lr is None or isinstance(lr, float): lr_dict = {} + else: lr_dict, lr = lr, None + + self.register("log_dt", log_dt, lr_dict.get('dt', lr)) + self.register("B", _c2r(B), lr_dict.get('B', lr)) + self.register("inv_A_real", self._A_init(A.real), lr_dict.get('A', lr)) + self.register("A_imag", A.imag, lr_dict.get('A', lr)) + + def _A_init(self, A_real): + A_real = torch.clamp(A_real, max=-1e-4) + if self.real_type == 'none': + return -A_real + elif self.real_type == 'exp': + return torch.log(-A_real) # Some of the HiPPO methods have real part 0 + elif self.real_type == 'relu': + return -A_real + elif self.real_type == 'sigmoid': + return torch.logit(-A_real) + elif self.real_type == 'softplus': + return torch.log(torch.exp(-A_real)-1) + else: raise NotImplementedError + + def _A(self): + # Get the internal A (diagonal) parameter + if self.real_type == 'none': + A_real = -self.inv_A_real + elif self.real_type == 'exp': + A_real = -torch.exp(self.inv_A_real) + elif self.real_type == 'relu': + # JAX version seems to NaN if you alloA 0's, although this code Aas fine Aithout it + A_real = -F.relu(self.inv_A_real)-1e-4 + elif self.real_type == 'sigmoid': + A_real = -F.sigmoid(self.inv_A_real) + elif self.real_type == 'softplus': + A_real = -F.softplus(self.inv_A_real) + else: raise NotImplementedError + A = A_real + 1j * self.A_imag + return A + + def forward(self, L, state=None, rate=1.0, u=None): + """ + state: (B, H, N) initial state + rate: sampling rate factor + L: target length + returns: + (C, H, L) convolution kernel (generally C=1) + (B, H, L) output from initial state + """ + + dt = torch.exp(self.log_dt) * rate # (H) + C = _r2c(self.C) # (C H N) + A = self._A() # (H N) + + B = _r2c(self.B) + B = repeat(B, 't n -> 1 (v t) n', v=self.repeat) + + # Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA" + if self.force_real: + A = A.real + 0j + + if self.bandlimit is not None: + freqs = dt[:, None] / rate * A.imag.abs() / (2*math.pi) # (H, N) + mask = torch.where(freqs < self.bandlimit * .5, 1, 0) + C = C * mask + + # Incorporate dt into A + A = repeat(A, 't n -> (v t) n', v=self.repeat) + dtA = A * dt.unsqueeze(-1) # (H N) + + + # Augment B with state + if state is not None: + s = state / dt.unsqueeze(-1) + if self.disc == 'bilinear': + s = s * (1. + dtA/2) + elif self.disc == 'zoh': + s = s * dtA * dtA.exp() / (dtA.exp() - 1.) 
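+ # Folding the (rescaled) initial state into B lets a single Vandermonde product
+ # below produce both the convolution kernel and the state's contribution; the two
+ # are split back apart into K and K_state at the end of this method.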
+ B = torch.cat([s, B], dim=-3) # (1+B H N) + + C = (B[:, None, :, :] * C).view(-1, self.H, self.N) + if self.disc == 'zoh': + # Power up + C = C * (torch.exp(dtA)-1.) / A + # TODO (TD): make it work for C.shape[0] > 1 + if log_vandermonde_fast is not None and C.shape[0] == 1: + K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(0) # (H L) + else: + K = log_vandermonde(C, dtA, L) # (H L) + elif self.disc == 'bilinear': + C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A + dA = (1. + dtA/2) / (1. - dtA/2) + if log_vandermonde_fast is not None: + dA_log = repeat(dA.log(), 'h d -> (c h) d', c=C.shape[0]) + K = rearrange(log_vandermonde_fast(rearrange(C, 'c h d -> (c h) d'), dA_log, L), + '(c h) d -> c h d', c=C.shape[0]) + else: + K = log_vandermonde(C, dA.log(), L) + elif self.disc == 'dss': + # Implementation from DSS meant for case when real eigenvalues can be positive + P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L] + A_gt_0 = A.real > 0 # [N] + if A_gt_0.any(): + with torch.no_grad(): + P_max = dtA * (A_gt_0 * (L-1)) # [H N] + P = P - P_max.unsqueeze(-1) # [H N L] + S = P.exp() # [H N L] + + dtA_neg = dtA * (1 - 2*A_gt_0) # [H N] + num = dtA_neg.exp() - 1 # [H N] + den = (dtA_neg * L).exp() - 1 # [H N] + + # Inline reciprocal function for DSS logic + x = den * A + x_conj = _resolve_conj(x) + r = x_conj / (x*x_conj + 1e-7) + + C = C * num * r # [C H N] + K = contract('chn,hnl->chl', C, S).float() + else: assert False, f"{self.disc} not supported" + + K = K.view(-1, self.channels, self.H, L) # (1+B C H L) + if state is not None: + K_state = K[:-1, :, :, :] # (B C H L) + else: + K_state = None + K = K[-1, :, :, :] # (C H L) + return K, K_state + + def _setup_step(self): + # These methods are organized like this to be compatible with the NPLR kernel interface + dt = torch.exp(self.log_dt) # (H) + B = _r2c(self.B) # (H N) + C = _r2c(self.C) # (C H N) + self.dC = C + A = self._A() # (H N) + + A = repeat(A, 't n -> (v t) n', v=self.repeat) + B = repeat(B, 't n -> (v t) n', v=self.repeat) + + # Incorporate dt into A + dtA = A * dt.unsqueeze(-1) # (H N) + if self.disc == 'zoh': + self.dA = torch.exp(dtA) # (H N) + self.dB = B * (torch.exp(dtA)-1.) / A # (C H N) + elif self.disc == 'bilinear': + self.dA = (1. + dtA/2) / (1. - dtA/2) + self.dB = B * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A + + + def default_state(self, *batch_shape): + C = _r2c(self.C) + state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device) + return state + + def step(self, u, state): + next_state = contract("h n, b h n -> b h n", self.dA, state) \ + + contract("h n, b h -> b h n", self.dB, u) + y = contract("c h n, b h n -> b c h", self.dC, next_state) + return 2*y.real, next_state + + def forward_state(self, u, state): + self._setup_step() + AL = self.dA ** u.size(-1) + u = u.flip(-1).to(self.dA).contiguous() # (B H L) + v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1)) + next_state = AL * state + v + return next_state + + +class EMAKernel(OptimModule): + """Translation of Mega's MultiHeadEMA. + This is a minimal implementation of the convolution kernel part of the module. + This module, together with the main S4 block in clm.src.models.sequence.ss.s4 + (which is really just a fft-conv wrapper around any convolution kernel, + such as this one), should be exactly equivalent to using the original Mega + EMA module in clm.src.models.sequence.ss.ema. 
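+
+ Concretely, with p = sigmoid(delta) and q = 1 - p * sigmoid(alpha) (see coeffs()),
+ forward() returns the damped-exponential kernel
+ k[d, l] = sum_n gamma[d, n] * scale * p[d, n] * beta[d, n] * q[d, n]**l,
+ truncated to length L.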
+ Two additional flags have been provided to resolve discrepencies in parameter + count between S4(D) and EMA + - `dt_tie` makes the shape of the step size \Delta (H, 1) instead of (H, N) + - `efficient_bidirectional` ties the A/B/dt parameters for the conv kernels + in both forwards and backwards directions. This should have exactly the same + speed, slightly more parameter efficiency, and unchanged performance. + """ + + def __init__( + self, + H, + N=2, + channels=1, + l_max=None, + dt_tie=False, + efficient_bidirectional=False, + ): + super().__init__() + + self.H = H + self.N = N + self.channels = channels + self.l_max = l_max + self.scale = math.sqrt(1.0 / self.N) + + # Exactly match the parameter count of S4(D) when bididirectional is on + self.efficient_bidirectional = efficient_bidirectional + if self.efficient_bidirectional: + H_C = H * channels + else: + H *= channels + H_C = H + + self.delta = nn.Parameter(torch.Tensor(H, 1 if dt_tie else N, 1)) + self.alpha = nn.Parameter(torch.Tensor(H, N, 1)) + self.beta = nn.Parameter(torch.Tensor(H, N, 1)) + self.gamma = nn.Parameter(torch.Tensor(H_C, N)) + # self.omega = nn.Parameter(torch.Tensor(H)) # D skip connection handled by outside class + + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + nn.init.normal_(self.delta, mean=0.0, std=0.2) + nn.init.normal_(self.alpha, mean=0.0, std=0.2) + # Mega comment: beta [1, -1, 1, -1, ...] seems more stable. + val = torch.ones(self.N, 1) + if self.N > 1: + idx = torch.tensor(list(range(1, self.N, 2))) + val.index_fill_(0, idx, -1.0) + self.beta.normal_(mean=0.0, std=0.02).add_(val) + nn.init.normal_(self.gamma, mean=0.0, std=1.0) + # nn.init.normal_(self.omega, mean=0.0, std=1.0) + + def coeffs(self): # Same as discretize + p = torch.sigmoid(self.delta) # (H N 1) + alpha = torch.sigmoid(self.alpha) + q = 1.0 - p * alpha + return p, q + + def forward(self, L=None, state=None, rate=1.0): + L = L if self.l_max is None else min(self.l_max, L) + p, q = self.coeffs() # (H N 1) + vander = torch.arange(L).to(p).view(1, 1, L) * torch.log(q) # (H N L) + kernel = (p * self.beta) * torch.exp(vander) + if self.efficient_bidirectional: + C = rearrange(self.gamma * self.scale, '(c h) n -> c h n', c=self.channels) + kernel = torch.einsum('dnl,cdn->cdl', kernel, C) + # kernel = rearrange(kernel, 'c d l -> (c d) l') + else: + kernel = torch.einsum('dnl,dn->dl', kernel, self.gamma * self.scale) + kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels) + + kernel = kernel[..., :L] + # kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels) + return kernel, None # k_state diff --git a/src/clm/src/models/sequence/ssm/ss_kernel_shift.py b/src/clm/src/models/sequence/ssm/ss_kernel_shift.py new file mode 100644 index 00000000..b926297a --- /dev/null +++ b/src/clm/src/models/sequence/ssm/ss_kernel_shift.py @@ -0,0 +1,83 @@ +# TD: [2023-01-05]: Extracted the SSKernelDiag class from +# https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/kernel.py +# We make a small change to use the log_vandermonde CUDA code. + +"""SSKernelDiag is the S4D kernel, a simpler algorithm for computing the kernel for the case of diagonal state matrices A. 
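+
+The class defined here, SSKernelShift, is the kernel of a shift SSM: A is the N x N
+shift matrix (see step()), B is typically initialized to the first basis vector e_1
+by SSKernel's 'shift' mode, and the kernel is a causal correlation of B and C
+computed with FFTs.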
+""" +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange, repeat +from opt_einsum import contract + +from clm.src.utils.train import OptimModule + + +class SSKernelShift(OptimModule): + + def __init__(self, B, C, L=None, lr=None, **kwargs): + """ + B: (H, d), real + C: (channel, H, d), real + """ + super().__init__() + self.L = L + self.N = B.size(-1) + self.H = B.shape[0] + + # Register parameters + if lr is None or isinstance(lr, float): lr_dict = {} + else: lr_dict, lr = lr, None + self.register("B", B, lr_dict.get('B', lr)) + self.C = nn.Parameter(C) + + def forward(self, state=None, rate=1.0, L=None): + if L is None: + L = self.L + # This class doesn't support variable length functionalities, since it's a discrete SSM + assert rate == 1.0 and L is not None + + # Augment B with state + B = self.B + if state is not None: + B = rearrange(torch.cat([rearrange(B, 'h n -> 1 h n'), state], dim=-3), + 'bp1 h n -> bp1 1 h n') # (1 + B, 1, H, N) + B_f = torch.fft.rfft(B, n=2*self.N) + C_f = torch.fft.rfft(self.C, n=2*self.N) + k = torch.fft.irfft(B_f.conj() * C_f, n=2*self.N)[..., :min(self.N, L)] + # If self.N < L, need to pad with zeros to reach length L + if self.N < L: + k = F.pad(k, (0, L - self.N)) + k = k.float() # Otherwise it could be dtype half + if state is not None: + k, k_state = k[0], k[1:] + else: + k_state = None + return k, k_state + + def _setup_step(self): + # Just here to conform to the interface, eventually we should refactor out + pass + + def default_state(self, *batch_shape): + return torch.zeros(*batch_shape, self.H, self.N, dtype=self.C.dtype, device=self.C.device) + + def step(self, u, state): + """u: (B, H), state: (B, H, N)""" + next_state = F.pad(state, (1, -1)) + contract("h n, b h -> b h n", self.B, u) + y = contract("c h n, b h n -> b c h", self.C, next_state) + return y, next_state + + def forward_state(self, u, state): + """u: (B, H, L), state: (B, H, N)""" + L = u.shape[-1] + B_f = torch.fft.rfft(self.B, n=2 * self.N) + u_f = torch.fft.rfft(u[..., -self.N:].flip(-1).to(dtype=self.B.dtype), n=2 * self.N) + v = torch.fft.irfft(B_f * u_f, n=2 * self.N)[..., :self.N] + if L < self.N: + next_state = F.pad(state, (L, -L)) + v + else: + next_state = v + return next_state diff --git a/src/clm/src/ops/fftconv.py b/src/clm/src/ops/fftconv.py new file mode 100644 index 00000000..b5d2749b --- /dev/null +++ b/src/clm/src/ops/fftconv.py @@ -0,0 +1,103 @@ +import math + +import torch +import torch.nn.functional as F + +from einops import rearrange + +from fftconv import fftconv_fwd, fftconv_bwd + +@torch.jit.script +def _mul_sum(y, q): + return (y * q).sum(dim=1) + +# reference convolution with residual connection +def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): + seqlen = u.shape[-1] + fft_size = 2 * seqlen + k_f = torch.fft.rfft(k, n=fft_size) / fft_size + if k_rev is not None: + k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size + k_f = k_f + k_rev_f.conj() + u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) + y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] + out = y + u * D.unsqueeze(-1) + if gelu: + out = F.gelu(out) + if dropout_mask is not None: + return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) + else: + return out.to(dtype=u.dtype) + +# reference H3 forward pass +def fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=None): + seqlen = k.shape[-1] + fft_size = 2 * seqlen + kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', 
d1=head_dim) + * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=head_dim)) # b d1 d2 h l + kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size + ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 + if ssm_kernel_rev is not None: + ssm_kernel_rev_f = torch.fft.rfft(ssm_kernel_rev, n=fft_size) # h L+1 + ssm_kernel_f = ssm_kernel_f + ssm_kernel_rev_f.conj() + y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :seqlen] # b d1 d2 h l + out = y + kv * D.unsqueeze(-1) # b d1 d2 h l + q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=head_dim) + if head_dim > 1: + out = _mul_sum(out, q) + return rearrange(out, 'b d2 h l -> b (h d2) l').to(dtype=k.dtype) + else: + return rearrange(out * q, 'b 1 1 h l -> b h l').to(dtype=k.dtype) + + +class FFTConvFunc(torch.autograd.Function): + + @staticmethod + def forward(ctx, u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False, + output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None): + seqlen = u.shape[-1] + fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) + k_f = torch.fft.rfft(k, n=fft_size) + if k_rev is not None: + k_f = k_f + torch.fft.rfft(k_rev, n=fft_size).conj() + if u.stride(-1) != 1: + u = u.contiguous() + k_f = k_f.contiguous() + D = D.contiguous() + if v is not None and v.stride(-1) != 1: + v = v.contiguous() + if q is not None and q.stride(-1) != 1: + q = q.contiguous() + if dropout_mask is not None: + dropout_mask = dropout_mask.contiguous() + ctx.save_for_backward(u, k_f, D, dropout_mask, v, q) + ctx.output_hbl_layout = output_hbl_layout + ctx.head_dim = head_dim + ctx.gelu = gelu + ctx.fftfp16 = fftfp16 + ctx.has_k_rev = k_rev is not None + out = fftconv_fwd(u, k_f, D, v, head_dim, q, dropout_mask, gelu, False, False, fft_size, force_fp16_output, output_hbl_layout, fftfp16) + return out + + @staticmethod + def backward(ctx, dout): + if ctx.output_hbl_layout: + dout = rearrange(rearrange(dout, 'b h l -> h b l').contiguous(), 'h b l -> b h l') + else: + dout = dout.contiguous() + u, k_f, D, dropout_mask, v, q = ctx.saved_tensors + seqlen = u.shape[-1] + fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) + du, dk_f, dD, dv, dq = fftconv_bwd(dout, u, k_f, D, v, ctx.head_dim, q, dropout_mask, ctx.gelu, False, False, fft_size, + ctx.output_hbl_layout, ctx.fftfp16) + dk = torch.fft.irfft(dk_f, n=fft_size, norm='forward')[..., :seqlen] + dk_rev = (None if not ctx.has_k_rev + else torch.fft.irfft(dk_f.conj(), n=fft_size, norm='forward')[..., :seqlen]) + if v is not None: + dv = dv.to(dtype=v.dtype) # We do atomicAdd in fp32 so might need to convert to fp16 + return du, dk, dD, None, None, None, None, dv if v is not None else None, None, dq if q is not None else None, None, dk_rev + +def fftconv_func(u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False, + output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None): + return FFTConvFunc.apply(u, k, D, dropout_mask, gelu, force_fp16_output, + output_hbl_layout, v, head_dim, q, fftfp16, k_rev) diff --git a/src/clm/src/ops/krylov.py b/src/clm/src/ops/krylov.py new file mode 100644 index 00000000..34544252 --- /dev/null +++ b/src/clm/src/ops/krylov.py @@ -0,0 +1,198 @@ +""" Compute a Krylov function efficiently. 
(S4 renames the Krylov function to a "state space kernel") + +A : (N, N) +b : (N,) +c : (N,) +Return: [c^T A^i b for i in [L]] +""" + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from clm.src.ops.toeplitz import causal_convolution + +def krylov_sequential(L, A, b, c=None): + """ Constant matrix A + + A : (..., N, N) + b : (..., N) + c : (..., N) + + Returns + if c: + x : (..., L) + x[i, l] = c[i] @ A^l @ b[i] + + else: + x : (..., N, L) + x[i, l] = A^l @ b[i] + """ + + # Check which of dim b and c is smaller to save memory + if c is not None and c.numel() < b.numel(): + return krylov_sequential(L, A.transpose(-1, -2), c, b) + + b_ = b + x = [] + for _ in range(L): + if c is not None: + x_ = torch.sum(c*b_, dim=-1) # (...) # could be faster with matmul or einsum? + else: + x_ = b_ + x.append(x_) + b_ = (A @ b_.unsqueeze(-1)).squeeze(-1) + + x = torch.stack(x, dim=-1) + return x + + +def krylov(L, A, b, c=None, return_power=False): + """ + Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick. + + If return_power=True, return A^{L-1} as well + """ + # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises + + x = b.unsqueeze(-1) # (..., N, 1) + A_ = A + + AL = None + if return_power: + AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device) + _L = L-1 + + done = L == 1 + # loop invariant: _L represents how many indices left to compute + while not done: + if return_power: + if _L % 2 == 1: AL = A_ @ AL + _L //= 2 + + # Save memory on last iteration + l = x.shape[-1] + if L - l <= l: + done = True + _x = x[..., :L-l] + else: _x = x + + _x = A_ @ _x + x = torch.cat([x, _x], dim=-1) # there might be a more efficient way of ordering axes + if not done: A_ = A_ @ A_ + + assert x.shape[-1] == L + + if c is not None: + x = torch.einsum('...nl, ...n -> ...l', x, c) + x = x.contiguous() # WOW!! + if return_power: + return x, AL + else: + return x + +@torch.no_grad() +def power(L, A, v=None): + """ Compute A^L and the scan sum_i A^i v_i + + A: (..., N, N) + v: (..., N, L) + """ + + I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) + + powers = [A] + l = 1 + while True: + if L % 2 == 1: I = powers[-1] @ I + L //= 2 + if L == 0: break + l *= 2 + if v is None: + powers = [powers[-1] @ powers[-1]] + else: + powers.append(powers[-1] @ powers[-1]) + + if v is None: return I + + # Invariants: + # powers[-1] := A^l + # l := largest po2 at most L + + # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A + # We do this reverse divide-and-conquer for efficiency reasons: + # 1) it involves fewer padding steps for non-po2 L + # 2) it involves more contiguous arrays + + # Take care of edge case for non-po2 arrays + # Note that this initial step is a no-op for the case of power of 2 (l == L) + k = v.size(-1) - l + v_ = powers.pop() @ v[..., l:] + v = v[..., :l] + v[..., :k] = v[..., :k] + v_ + + # Handle reduction for power of 2 + while v.size(-1) > 1: + v = rearrange(v, '... (z l) -> ... 
z l', z=2) + v = v[..., 0, :] + powers.pop() @ v[..., 1, :] + return I, v.squeeze(-1) + +def krylov_toeplitz(L, A, b, c=None): + """ Specializes to lower triangular Toeplitz matrix A represented by its diagonals + + A : (..., N) + b : (..., N) + c : (..., N) + + Returns + x : (..., N, L) + x[i, l] = A^l @ b[i] + """ + x = b.unsqueeze(0) # (1, ..., N) + A_ = A + while x.shape[0] < L: + xx = causal_convolution(A_, x) + x = torch.cat([x, xx], dim=0) # there might be a more efficient way of ordering axes + A_ = causal_convolution(A_, A_) + x = x[:L, ...] # (L, ..., N) + if c is not None: + x = torch.einsum('l...n, ...n -> ...l', x, c) + else: + x = rearrange(x, 'l ... n -> ... n l') + x = x.contiguous() + return x + +def krylov_toeplitz_(L, A, b, c=None): + """ Padded version of krylov_toeplitz that saves some fft's + + TODO currently not faster than original version, not sure why + """ + N = A.shape[-1] + + x = b.unsqueeze(0) # (1, ..., N) + x = F.pad(x, (0, N)) + A = F.pad(A, (0, N)) + done = L == 1 + while not done: + l = x.shape[0] + # Save memory on last iteration + if L - l <= l: + done = True + _x = x[:L-l] + else: _x = x + Af = torch.fft.rfft(A, n=2*N, dim=-1) + xf = torch.fft.rfft(_x, n=2*N, dim=-1) + xf_ = Af * xf + x_ = torch.fft.irfft(xf_, n=2*N, dim=-1) + x_[..., N:] = 0 + x = torch.cat([x, x_], dim=0) # there might be a more efficient way of ordering axes + if not done: + A = torch.fft.irfft(Af*Af, n=2*N, dim=-1) + A[..., N:] = 0 + x = x[:L, ..., :N] # (L, ..., N) + if c is not None: + x = torch.einsum('l...n, ...n -> ...l', x, c) + else: + x = rearrange(x, 'l ... n -> ... n l') + x = x.contiguous() + return x diff --git a/src/clm/src/ops/toeplitz.py b/src/clm/src/ops/toeplitz.py new file mode 100644 index 00000000..af007390 --- /dev/null +++ b/src/clm/src/ops/toeplitz.py @@ -0,0 +1,157 @@ +""" Utilities for computing convolutions. + +There are 3 equivalent views: + 1. causal convolution + 2. multiplication of (lower) triangular Toeplitz matrices + 3. polynomial multiplication (mod x^N) +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def construct_toeplitz(v, f=0.0): + """Explicit construction of Krylov matrix [v A @ v A^2 @ v ... A^{n-1} @ v] + where A = Z_f. This uses vectorized indexing and cumprod so it's much + faster than using the Krylov function. + Parameters: + v: the starting vector of size n or (rank, n). + f: real number + Returns: + K: Krylov matrix of size (n, n) or (rank, n, n). + """ + n = v.shape[-1] + a = torch.arange(n, device=v.device) + b = -a + indices = a[:, None] + b[None] + K = v[..., indices] + K[..., indices < 0] *= f + return K + +def triangular_toeplitz_multiply_(u, v, sum=None): + n = u.shape[-1] + u_expand = F.pad(u, (0, n)) + v_expand = F.pad(v, (0, n)) + u_f = torch.fft.rfft(u_expand, n=2*n, dim=-1) + v_f = torch.fft.rfft(v_expand, n=2*n, dim=-1) + uv_f = u_f * v_f + if sum is not None: + uv_f = uv_f.sum(dim=sum) + output = torch.fft.irfft(uv_f, n=2*n, dim=-1)[..., :n] + return output + +def triangular_toeplitz_multiply_padded_(u, v): + """ Same as triangular_toeplitz_multiply but inputs and output assume to be 0-padded already. 
""" + n = u.shape[-1] + assert n % 2 == 0 + u_f = torch.fft.rfft(u, n=n, dim=-1) + v_f = torch.fft.rfft(v, n=n, dim=-1) + uv_f = u_f * v_f + output = torch.fft.irfft(uv_f, n=n, dim=-1) + output[..., n:] = 0 + return output + +class TriangularToeplitzMult(torch.autograd.Function): + @staticmethod + def forward(ctx, u, v): + ctx.save_for_backward(u, v) + return triangular_toeplitz_multiply_(u, v) + + @staticmethod + def backward(ctx, grad): + u, v = ctx.saved_tensors + d_u = triangular_toeplitz_multiply_(grad.flip(-1), v).flip(-1) + d_v = triangular_toeplitz_multiply_(grad.flip(-1), u).flip(-1) + return d_u, d_v + +class TriangularToeplitzMultFast(torch.autograd.Function): + @staticmethod + def forward(ctx, u, v): + n = u.shape[-1] + u_expand = F.pad(u, (0, n)) + v_expand = F.pad(v, (0, n)) + u_f = torch.fft.rfft(u_expand, n=2*n, dim=-1) + v_f = torch.fft.rfft(v_expand, n=2*n, dim=-1) + + ctx.save_for_backward(u_f, v_f) + + uv_f = u_f * v_f + output = torch.fft.irfft(uv_f, n=2*n, dim=-1)[..., :n] + return output + + @staticmethod + def backward(ctx, grad): + u_f, v_f = ctx.saved_tensors + n = grad.shape[-1] + g_expand = F.pad(grad.flip(-1), (0, n)) + g_f = torch.fft.rfft(g_expand, n=2*n, dim=-1) + gu_f = g_f * u_f + gv_f = g_f * v_f + d_u = torch.fft.irfft(gv_f, n=2*n, dim=-1)[..., :n] + d_v = torch.fft.irfft(gu_f, n=2*n, dim=-1)[..., :n] + d_u = d_u.flip(-1) + d_v = d_v.flip(-1) + return d_u, d_v + +class TriangularToeplitzMultPadded(torch.autograd.Function): + @staticmethod + def forward(ctx, u, v): + ctx.save_for_backward(u, v) + output = triangular_toeplitz_multiply_(u, v) + return output + + @staticmethod + def backward(ctx, grad): + u, v = ctx.saved_tensors + d_u = triangular_toeplitz_multiply_padded_(grad.flip(-1), v).flip(-1) + d_v = triangular_toeplitz_multiply_padded_(grad.flip(-1), u).flip(-1) + return d_u, d_v + +class TriangularToeplitzMultPaddedFast(torch.autograd.Function): + """ Trade off speed (20-25% faster) for more memory (20-25%) """ + + @staticmethod + def forward(ctx, u, v): + n = u.shape[-1] + u_f = torch.fft.rfft(u, n=n, dim=-1) + v_f = torch.fft.rfft(v, n=n, dim=-1) + + ctx.save_for_backward(u_f, v_f) + + uv_f = u_f * v_f + output = torch.fft.irfft(uv_f, n=n, dim=-1) + output[..., n//2:].zero_() + return output + + @staticmethod + def backward(ctx, grad): + u_f, v_f = ctx.saved_tensors + n = grad.shape[-1] + g_expand = F.pad(grad[..., :n//2].flip(-1), (0, n//2)) + g_f = torch.fft.rfft(g_expand, n=n, dim=-1) + gu_f = g_f * u_f + gv_f = g_f * v_f + d_u = torch.fft.irfft(gv_f, n=n, dim=-1) + d_v = torch.fft.irfft(gu_f, n=n, dim=-1) + d_u[..., n//2:].zero_() + d_v[..., n//2:].zero_() + d_u[..., :n//2] = d_u[..., :n//2].flip(-1) # TODO + d_v[..., :n//2] = d_v[..., :n//2].flip(-1) # TODO + return d_u, d_v + +# triangular_toeplitz_multiply = triangular_toeplitz_multiply_ +triangular_toeplitz_multiply = TriangularToeplitzMult.apply +triangular_toeplitz_multiply_fast = TriangularToeplitzMultFast.apply +triangular_toeplitz_multiply_padded = TriangularToeplitzMultPadded.apply +triangular_toeplitz_multiply_padded_fast = TriangularToeplitzMultPaddedFast.apply + +def causal_convolution(u, v, fast=True, pad=False): + if not pad and not fast: + return triangular_toeplitz_multiply(u, v) + if not pad and fast: + return triangular_toeplitz_multiply_fast(u, v) + if pad and not fast: + return triangular_toeplitz_multiply_padded(u, v) + if pad and fast: + return triangular_toeplitz_multiply_padded_fast(u, v) diff --git a/src/clm/src/ops/unroll.py b/src/clm/src/ops/unroll.py new file mode 
100644 index 00000000..b8f8c8db --- /dev/null +++ b/src/clm/src/ops/unroll.py @@ -0,0 +1,421 @@ +""" Old utilities for parallel scan implementation of Linear RNNs. """ +# TODO this file could use much cleanup + +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import math + +from clm.src.models.functional.toeplitz import triangular_toeplitz_multiply, triangular_toeplitz_multiply_padded +from clm.src.utils.permutations import bitreversal_po2, bitreversal_permutation + + + +### Utilities + + +def shift_up(a, s=None, drop=True, dim=0): + assert dim == 0 + if s is None: + s = torch.zeros_like(a[0, ...]) + s = s.unsqueeze(dim) + if drop: + a = a[:-1, ...] + return torch.cat((s, a), dim=dim) + +def interleave(a, b, uneven=False, dim=0): + """ Interleave two tensors of same shape """ + # assert(a.shape == b.shape) + assert dim == 0 # TODO temporary to make handling uneven case easier + if dim < 0: + dim = N + dim + if uneven: + a_ = a[-1:, ...] + a = a[:-1, ...] + c = torch.stack((a, b), dim+1) + out_shape = list(a.shape) + out_shape[dim] *= 2 + c = c.view(out_shape) + if uneven: + c = torch.cat((c, a_), dim=dim) + return c + +def batch_mult(A, u, has_batch=None): + """ Matrix mult A @ u with special case to save memory if u has additional batch dim + + The batch dimension is assumed to be the second dimension + A : (L, ..., N, N) + u : (L, [B], ..., N) + has_batch: True, False, or None. If None, determined automatically + + Output: + x : (L, [B], ..., N) + A @ u broadcasted appropriately + """ + + if has_batch is None: + has_batch = len(u.shape) >= len(A.shape) + + if has_batch: + u = u.permute([0] + list(range(2, len(u.shape))) + [1]) + else: + u = u.unsqueeze(-1) + v = (A @ u) + if has_batch: + v = v.permute([0] + [len(u.shape)-1] + list(range(1, len(u.shape)-1))) + else: + v = v[..., 0] + return v + + + +### Main unrolling functions + +def unroll(A, u): + """ + A : (..., N, N) # TODO I think this can't take batch dimension? + u : (L, ..., N) + output : x (..., N) # TODO a lot of these shapes are wrong + x[i, ...] = A^{i} @ u[0, ...] + ... + A @ u[i-1, ...] + u[i, ...] + """ + + m = u.new_zeros(u.shape[1:]) + outputs = [] + for u_ in torch.unbind(u, dim=0): + m = F.linear(m, A) + u_ + outputs.append(m) + + output = torch.stack(outputs, dim=0) + return output + + +def parallel_unroll_recursive(A, u): + """ Bottom-up divide-and-conquer version of unroll. """ + + # Main recursive function + def parallel_unroll_recursive_(A, u): + if u.shape[0] == 1: + return u + + u_evens = u[0::2, ...] + u_odds = u[1::2, ...] + + # u2 = F.linear(u_evens, A) + u_odds + u2 = (A @ u_evens.unsqueeze(-1)).squeeze(-1) + u_odds + A2 = A @ A + + x_odds = parallel_unroll_recursive_(A2, u2) + # x_evens = F.linear(shift_up(x_odds), A) + u_evens + x_evens = (A @ shift_up(x_odds).unsqueeze(-1)).squeeze(-1) + u_evens + + x = interleave(x_evens, x_odds, dim=0) + return x + + # Pad u to power of 2 + n = u.shape[0] + m = int(math.ceil(math.log(n)/math.log(2))) + N = 1 << m + u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0) + + return parallel_unroll_recursive_(A, u)[:n, ...] + + + +def parallel_unroll_recursive_br(A, u): + """ Same as parallel_unroll_recursive but uses bit reversal for locality. """ + + # Main recursive function + def parallel_unroll_recursive_br_(A, u): + n = u.shape[0] + if n == 1: + return u + + m = n//2 + u_0 = u[:m, ...] + u_1 = u[m:, ...] 
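+ # One level of the scan: fold the first half into the second (u2 = A @ u_0 + u_1)
+ # and square A, halving the problem; the first-half outputs are then recovered
+ # from the recursive result via x_0 = A @ shift_up(x_1) + u_0.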
+ + u2 = F.linear(u_0, A) + u_1 + A2 = A @ A + + x_1 = parallel_unroll_recursive_br_(A2, u2) + x_0 = F.linear(shift_up(x_1), A) + u_0 + + # x = torch.cat((x_0, x_1), dim=0) # is there a way to do this with cat? + x = interleave(x_0, x_1, dim=0) + return x + + # Pad u to power of 2 + n = u.shape[0] + m = int(math.ceil(math.log(n)/math.log(2))) + N = 1 << m + u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0) + + # Apply bit reversal + br = bitreversal_po2(N) + u = u[br, ...] + + x = parallel_unroll_recursive_br_(A, u) + return x[:n, ...] + +def parallel_unroll_iterative(A, u): + """ Bottom-up divide-and-conquer version of unroll, implemented iteratively """ + + # Pad u to power of 2 + n = u.shape[0] + m = int(math.ceil(math.log(n)/math.log(2))) + N = 1 << m + u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0) + + # Apply bit reversal + br = bitreversal_po2(N) + u = u[br, ...] + + # Main recursive loop, flattened + us = [] # stores the u_0 terms in the recursive version + N_ = N + As = [] # stores the A matrices + for l in range(m): + N_ = N_ // 2 + As.append(A) + u_0 = u[:N_, ...] + us.append(u_0) + u = F.linear(u_0, A) + u[N_:, ...] + A = A @ A + x_0 = [] + x = u # x_1 + for l in range(m-1, -1, -1): + x_0 = F.linear(shift_up(x), As[l]) + us[l] + x = interleave(x_0, x, dim=0) + + return x[:n, ...] + + +def variable_unroll_sequential(A, u, s=None, variable=True): + """ Unroll with variable (in time/length) transitions A. + + A : ([L], ..., N, N) dimension L should exist iff variable is True + u : (L, [B], ..., N) updates + s : ([B], ..., N) start state + output : x (..., N) + x[i, ...] = A[i]..A[0] @ s + A[i..1] @ u[0] + ... + A[i] @ u[i-1] + u[i] + """ + + if s is None: + s = torch.zeros_like(u[0]) + + if not variable: + A = A.expand((u.shape[0],) + A.shape) + has_batch = len(u.shape) >= len(A.shape) + + outputs = [] + for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)): + # s = F.linear(s, A_) + u_ + s = batch_mult(A_.unsqueeze(0), s.unsqueeze(0), has_batch)[0] + s = s + u_ + outputs.append(s) + + output = torch.stack(outputs, dim=0) + return output + + + +def variable_unroll(A, u, s=None, variable=True, recurse_limit=16): + """ Bottom-up divide-and-conquer version of variable_unroll. """ + + if u.shape[0] <= recurse_limit: + return variable_unroll_sequential(A, u, s, variable) + + if s is None: + s = torch.zeros_like(u[0]) + + uneven = u.shape[0] % 2 == 1 + has_batch = len(u.shape) >= len(A.shape) + + u_0 = u[0::2, ...] + u_1 = u[1::2, ...] + + if variable: + A_0 = A[0::2, ...] + A_1 = A[1::2, ...] + else: + A_0 = A + A_1 = A + + u_0_ = u_0 + A_0_ = A_0 + if uneven: + u_0_ = u_0[:-1, ...] + if variable: + A_0_ = A_0[:-1, ...] + + u_10 = batch_mult(A_1, u_0_, has_batch) + u_10 = u_10 + u_1 + A_10 = A_1 @ A_0_ + + # Recursive call + x_1 = variable_unroll(A_10, u_10, s, variable, recurse_limit) + + x_0 = shift_up(x_1, s, drop=not uneven) + x_0 = batch_mult(A_0, x_0, has_batch) + x_0 = x_0 + u_0 + + + x = interleave(x_0, x_1, uneven, dim=0) # For some reason this interleave is slower than in the (non-multi) unroll_recursive + return x + +def variable_unroll_general_sequential(A, u, s, op, variable=True): + """ Unroll with variable (in time/length) transitions A with general associative operation + + A : ([L], ..., N, N) dimension L should exist iff variable is True + u : (L, [B], ..., N) updates + s : ([B], ..., N) start state + output : x (..., N) + x[i, ...] = A[i]..A[0] s + A[i..1] u[0] + ... 
+ A[i] u[i-1] + u[i] + """ + + if not variable: + A = A.expand((u.shape[0],) + A.shape) + + outputs = [] + for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)): + s = op(A_, s) + s = s + u_ + outputs.append(s) + + output = torch.stack(outputs, dim=0) + return output + +def variable_unroll_matrix_sequential(A, u, s=None, variable=True): + if s is None: + s = torch.zeros_like(u[0]) + + if not variable: + A = A.expand((u.shape[0],) + A.shape) + # has_batch = len(u.shape) >= len(A.shape) + + # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0] + op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0))[0] + + return variable_unroll_general_sequential(A, u, s, op, variable=True) + +def variable_unroll_toeplitz_sequential(A, u, s=None, variable=True, pad=False): + if s is None: + s = torch.zeros_like(u[0]) + + if not variable: + A = A.expand((u.shape[0],) + A.shape) + # has_batch = len(u.shape) >= len(A.shape) + + # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0] + # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0))[0] + + if pad: + n = A.shape[-1] + A = F.pad(A, (0, n)) + u = F.pad(u, (0, n)) + s = F.pad(s, (0, n)) + ret = variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply_padded, variable=True) + ret = ret[..., :n] + return ret + + return variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply, variable=True) + + + +### General parallel scan functions with generic binary composition operators + +def variable_unroll_general(A, u, s, op, compose_op=None, sequential_op=None, variable=True, recurse_limit=16): + """ Bottom-up divide-and-conquer version of variable_unroll. + + compose is an optional function that defines how to compose A without multiplying by a leaf u + """ + + if u.shape[0] <= recurse_limit: + if sequential_op is None: + sequential_op = op + return variable_unroll_general_sequential(A, u, s, sequential_op, variable) + + if compose_op is None: + compose_op = op + + uneven = u.shape[0] % 2 == 1 + # has_batch = len(u.shape) >= len(A.shape) + + u_0 = u[0::2, ...] + u_1 = u[1::2, ...] + + if variable: + A_0 = A[0::2, ...] + A_1 = A[1::2, ...] + else: + A_0 = A + A_1 = A + + u_0_ = u_0 + A_0_ = A_0 + if uneven: + u_0_ = u_0[:-1, ...] + if variable: + A_0_ = A_0[:-1, ...] 
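+ # Same even/odd divide-and-conquer as variable_unroll, but with a generic
+ # associative operation: `op` applies a transition to an update/state and
+ # `compose_op` composes two transitions (e.g. matmul in variable_unroll_matrix,
+ # causal Toeplitz multiplication in variable_unroll_toeplitz).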
+ + u_10 = op(A_1, u_0_) # batch_mult(A_1, u_0_, has_batch) + u_10 = u_10 + u_1 + A_10 = compose_op(A_1, A_0_) + + # Recursive call + x_1 = variable_unroll_general(A_10, u_10, s, op, compose_op, sequential_op, variable=variable, recurse_limit=recurse_limit) + + x_0 = shift_up(x_1, s, drop=not uneven) + x_0 = op(A_0, x_0) # batch_mult(A_0, x_0, has_batch) + x_0 = x_0 + u_0 + + + x = interleave(x_0, x_1, uneven, dim=0) # For some reason this interleave is slower than in the (non-multi) unroll_recursive + return x + +def variable_unroll_matrix(A, u, s=None, variable=True, recurse_limit=16): + if s is None: + s = torch.zeros_like(u[0]) + has_batch = len(u.shape) >= len(A.shape) + op = lambda x, y: batch_mult(x, y, has_batch) + sequential_op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0] + matmul = lambda x, y: x @ y + return variable_unroll_general(A, u, s, op, compose_op=matmul, sequential_op=sequential_op, variable=variable, recurse_limit=recurse_limit) + +def variable_unroll_toeplitz(A, u, s=None, variable=True, recurse_limit=8, pad=False): + """ Unroll with variable (in time/length) transitions A with general associative operation + + A : ([L], ..., N) dimension L should exist iff variable is True + u : (L, [B], ..., N) updates + s : ([B], ..., N) start state + output : x (L, [B], ..., N) same shape as u + x[i, ...] = A[i]..A[0] s + A[i..1] u[0] + ... + A[i] u[i-1] + u[i] + """ + # Add the batch dimension to A if necessary + A_batch_dims = len(A.shape) - int(variable) + u_batch_dims = len(u.shape)-1 + if u_batch_dims > A_batch_dims: + # assert u_batch_dims == A_batch_dims + 1 + if variable: + while len(A.shape) < len(u.shape): + A = A.unsqueeze(1) + # else: + # A = A.unsqueeze(0) + + if s is None: + s = torch.zeros_like(u[0]) + + if pad: + n = A.shape[-1] + A = F.pad(A, (0, n)) + u = F.pad(u, (0, n)) + s = F.pad(s, (0, n)) + op = triangular_toeplitz_multiply_padded + ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit) + ret = ret[..., :n] + return ret + + op = triangular_toeplitz_multiply + ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit) + return ret diff --git a/src/clm/src/ops/vandermonde.py b/src/clm/src/ops/vandermonde.py new file mode 100644 index 00000000..0325b4ec --- /dev/null +++ b/src/clm/src/ops/vandermonde.py @@ -0,0 +1,167 @@ +"""pykeops implementations of the Vandermonde matrix multiplication kernel used in the S4D kernel.""" +import math +import torch + +from einops import rearrange, repeat +from opt_einsum import contract + +import os + +try: + import pykeops + from pykeops.torch import LazyTensor, Genred +except: + pass + +try: + from cauchy_mult import vand_log_mult_sym_fwd, vand_log_mult_sym_bwd +except: + vand_log_mult_sym_fwd, vand_log_mult_sym_bwd = None, None + +_conj = lambda x: torch.cat([x, x.conj()], dim=-1) +def _broadcast_dims(*tensors): + max_dim = max([len(tensor.shape) for tensor in tensors]) + tensors = [tensor.view((1,)*(max_dim-len(tensor.shape))+tensor.shape) for tensor in tensors] + return tensors + +def _c2r(x): return torch.view_as_real(x) +def _r2c(x): return torch.view_as_complex(x) + +def vandermonde_naive(v, x, L, conj=True): + """ + v: (..., N) + x: (..., N) + returns: (..., L) \sum v x^l + """ + if conj: + x = _conj(x) + v = _conj(v) + vandermonde_matrix = x.unsqueeze(-1) ** torch.arange(L).to(x) # (... N L) + vandermonde_prod = torch.sum(v.unsqueeze(-1) * vandermonde_matrix, dim=-2) # (... 
L) + return vandermonde_prod + +def log_vandermonde_naive(v, x, L, conj=True): + """ + v: (..., N) + x: (..., N) + returns: (..., L) \sum v x^l + """ + vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) + vandermonde_prod = contract('... n, ... n l -> ... l', v, vandermonde_matrix) # (... L) + if conj: + return 2*vandermonde_prod.real + else: + return vandermonde_prod + +def log_vandermonde_lazy(v, x, L, conj=True): + if conj: + v = _conj(v) + x = _conj(x) + l = torch.arange(L).to(x) + v, x, l = _broadcast_dims(v, x, l) + v_l = LazyTensor(rearrange(v, '... N -> ... N 1 1')) + x_l = LazyTensor(rearrange(x, '... N -> ... N 1 1')) + l_l = LazyTensor(rearrange(l, '... L -> ... 1 L 1')) + # exp + vand = (x_l * l_l).exp() + s = (v_l*vand).sum(dim=len(v_l.shape)-2) + return s.squeeze(-1) + +def log_vandermonde(v, x, L, conj=True): + expr = 'ComplexMult(v, ComplexExp(ComplexMult(x, l)))' + vandermonde_mult = Genred( + expr, + [ + 'v = Vj(2)', + 'x = Vj(2)', + 'l = Vi(2)', + ], + reduction_op='Sum', + axis=1, + ) + + l = torch.arange(L).to(x) + v, x, l = _broadcast_dims(v, x, l) + v = _c2r(v) + x = _c2r(x) + l = _c2r(l) + + r = vandermonde_mult(v, x, l, backend='GPU') + if conj: + return 2*_r2c(r).real + else: + return _r2c(r) + +def log_vandermonde_transpose_naive(u, v, x, L): + vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) + vandermonde_prod = contract('... l, ... n, ... n l -> ... n', u.to(x), v.to(x), vandermonde_matrix) # (... L) + return vandermonde_prod + +def log_vandermonde_transpose(u, v, x, L): + """ + u: ... H L + v: ... H N + x: ... H N + Returns: ... H N + + V = Vandermonde(a, L) : (H N L) + contract_L(V * u * v) + """ + expr = 'ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))' + vandermonde_mult = Genred( + expr, + [ + 'u = Vj(2)', + 'v = Vi(2)', + 'x = Vi(2)', + 'l = Vj(2)', + ], + reduction_op='Sum', + axis=1, + ) + + l = torch.arange(L).to(x) + u, v, x, l = _broadcast_dims(u, v, x, l) + u = _c2r(u) + v = _c2r(v) + x = _c2r(x) + l = _c2r(l) + + r = vandermonde_mult(u, v, x, l, backend='GPU') + return _r2c(r) + +def _log_vandermonde_matmul(x, L): + vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... 
N L) + return vandermonde_matrix + +def log_vandermonde_matmul(v, K): + prod = contract('...n, ...nl -> ...l', v, K) + return 2*prod.real + +class LogVandMultiplySymmetric(torch.autograd.Function): + + @staticmethod + def forward(ctx, v, x, L): + batch, N = v.shape + supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] + if not N in supported_N_values: + raise NotImplementedError(f'Only support N values in {supported_N_values}') + max_L_value = 32 * 1024 * 64 * 1024 + if L > max_L_value: + raise NotImplementedError(f'Only support L values <= {max_L_value}') + if not v.is_cuda and x.is_cuda: + raise NotImplementedError(f'Only support CUDA tensors') + ctx.save_for_backward(v, x) + return vand_log_mult_sym_fwd(v, x, L) + + @staticmethod + def backward(ctx, dout): + v, x = ctx.saved_tensors + dv, dx = vand_log_mult_sym_bwd(v, x, dout) + return dv, dx, None + + +if vand_log_mult_sym_fwd and vand_log_mult_sym_bwd is not None: + log_vandermonde_fast = LogVandMultiplySymmetric.apply +else: + log_vandermonde_fast = None \ No newline at end of file diff --git a/src/clm/src/retnet/__init__.py b/src/clm/src/retnet/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/clm/src/retnet/complex/retention.py b/src/clm/src/retnet/complex/retention.py new file mode 100644 index 00000000..9a61a2d9 --- /dev/null +++ b/src/clm/src/retnet/complex/retention.py @@ -0,0 +1,177 @@ +import math + +import torch +import torch.nn as nn + +from util import ComplexGroupNorm + +class SimpleRetention(nn.Module): + def __init__(self, hidden_size, gamma, precision="single"): + """ + Simple retention mechanism based on the paper + "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] + """ + super(SimpleRetention, self).__init__() + + if precision == "half": + raise NotImplementedError("batchmm does not support half precision complex yet.") + self.complex_type = torch.complex32 + self.real_type = torch.float16 + elif precision == "single": + self.complex_type = torch.complex64 + self.real_type = torch.float32 + + self.precision = precision + self.hidden_size = hidden_size + self.gamma = gamma + + self.i = torch.complex(torch.tensor(0.0), torch.tensor(1.0)) + + self.W_Q = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.real_type) / hidden_size) + self.W_K = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.real_type) / hidden_size) + self.W_V = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.real_type) / hidden_size) + + + self.theta = torch.randn(hidden_size) / hidden_size + self.theta = nn.Parameter(self.theta) + + + + def forward(self, X): + """ + Parallel (default) representation of the retention mechanism. 
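+        Computes retention as (Q K^T * D) V, where Q = (X W_Q) * Theta,
+        K = (X W_K) * conj(Theta), and D is the causal decay mask from _get_D.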
+ X: (batch_size, sequence_length, hidden_size) + """ + sequence_length = X.shape[1] + D = self._get_D(sequence_length).to(X.device) + + if X.dtype != self.complex_type: + X = torch.complex(X, torch.zeros_like(X)).to(self.complex_type) + + i = self.i.to(X.device) + ns = torch.arange(1, sequence_length + 1, dtype=self.real_type, device=X.device) + ns = torch.complex(ns, torch.zeros_like(ns)).to(self.complex_type) + Theta = [] + + for n in ns: + Theta.append(torch.exp(i * n * self.theta)) + + Theta = torch.stack(Theta, dim=0) + + Theta_bar = Theta.conj() + + Q = (X @ self.W_Q.to(self.complex_type)) * Theta.unsqueeze(0) + K = (X @ self.W_K.to(self.complex_type)) * Theta_bar.unsqueeze(0) + V = X @ self.W_V.to(self.complex_type) + att = (Q @ K.permute(0, 2, 1)) * D.unsqueeze(0) + + return att @ V + + def forward_recurrent(self, x_n, s_n_1, n): + """ + Recurrent representation of the retention mechanism. + x_n: (batch_size, hidden_size) + s_n_1: (batch_size, hidden_size) + """ + if x_n.dtype != self.complex_type: + x_n = torch.complex(x_n, torch.zeros_like(x_n)).to(self.complex_type) + + n = torch.tensor(n, dtype=self.complex_type, device=x_n.device) + + Theta = torch.exp(self.i * n * self.theta) + Theta_bar = Theta.conj() + + Q = (x_n @ self.W_Q.to(self.complex_type)) * Theta + K = (x_n @ self.W_K.to(self.complex_type)) * Theta_bar + V = x_n @ self.W_V.to(self.complex_type) + + # K: (batch_size, hidden_size) + # V: (batch_size, hidden_size) + # s_n_1: (batch_size, hidden_size, hidden_size) + # s_n = gamma * s_n_1 + K^T @ V + + s_n = self.gamma * s_n_1 + K.unsqueeze(2) @ V.unsqueeze(1) + + return (Q.unsqueeze(1) @ s_n).squeeze(1), s_n + + def _get_D(self, sequence_length): + n = torch.arange(sequence_length).unsqueeze(1) + m = torch.arange(sequence_length).unsqueeze(0) + + # Broadcast self.gamma ** (n - m) with appropriate masking to set values where n < m to 0 + D = (self.gamma ** (n - m)) * (n >= m).float() #this results in some NaN when n is much larger than m + # fill the NaN with 0 + D[D != D] = 0 + + return D + +class MultiScaleRetention(nn.Module): + def __init__(self, hidden_size, heads, precision="single"): + """ + Multi-scale retention mechanism based on the paper + "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] + """ + super(MultiScaleRetention, self).__init__() + self.hidden_size = hidden_size + self.heads = heads + self.precision = precision + assert hidden_size % heads == 0, "hidden_size must be divisible by heads" + self.head_size = hidden_size // heads + + if precision == "half": + raise NotImplementedError("batchmm does not support half precision complex yet.") + self.complex_type = torch.complex32 + self.real_type = torch.float16 + elif precision == "single": + self.complex_type = torch.complex64 + self.real_type = torch.float32 + + self.gammas = (1 - torch.exp(torch.linspace(math.log(1/32), math.log(1/512), heads, dtype=self.real_type))).detach().cpu().tolist() + + self.swish = lambda x: x * torch.sigmoid(x) + self.W_G = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.complex_type) / hidden_size) + self.W_O = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.complex_type) / hidden_size) + self.group_norm = ComplexGroupNorm(heads, hidden_size) + + self.retentions = nn.ModuleList([ + SimpleRetention(self.head_size, gamma) for gamma in self.gammas + ]) + + def forward(self, X): + """ + parallel representation of the multi-scale retention mechanism + """ + if X.dtype != self.complex_type: + 
X = torch.complex(X, torch.zeros_like(X)).to(self.complex_type) + + # apply each individual retention mechanism to a slice of X + Y = [] + for i in range(self.heads): + Y.append(self.retentions[i](X[:, :, i*self.head_size:(i+1)*self.head_size])) + + Y = torch.cat(Y, dim=2) + Y = self.group_norm(Y.reshape(-1, self.hidden_size)).reshape(X.shape) + + return (self.swish(X @ self.W_G.to(self.complex_type)) * Y) @ self.W_O.to(self.complex_type) + + def forward_recurrent(self, x_n, s_n_1s, n): + """ + recurrent representation of the multi-scale retention mechanism + """ + if x_n.dtype != self.complex_type: + x_n = torch.complex(x_n, torch.zeros_like(x_n)).to(self.complex_type) + n = torch.tensor(n, dtype=self.complex_type, device=x_n.device) + + # apply each individual retention mechanism to a slice of X + Y = [] + s_ns = [] + for i in range(self.heads): + y, s_n = self.retentions[i].forward_recurrent( + x_n[:, i*self.head_size:(i+1)*self.head_size], s_n_1s[i], n + ) + Y.append(y) + s_ns.append(s_n) + + Y = torch.cat(Y, dim=1) + Y = self.group_norm(Y) + return (self.swish(x_n @ self.W_G.to(self.complex_type)) * Y) @ self.W_O.to(self.complex_type), s_ns diff --git a/src/clm/src/retnet/complex/retnet.py b/src/clm/src/retnet/complex/retnet.py new file mode 100644 index 00000000..4582c859 --- /dev/null +++ b/src/clm/src/retnet/complex/retnet.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn + +from retention import MultiScaleRetention +from util import ComplexFFN, ComplexGroupNorm, ComplexLayerNorm + +class RetNet(nn.Module): + def __init__(self, layers, hidden_dim, ffn_size, heads): + super(RetNet, self).__init__() + self.layers = layers + self.hidden_dim = hidden_dim + self.ffn_size = ffn_size + self.heads = heads + + self.retentions = nn.ModuleList([ + MultiScaleRetention(hidden_dim, heads) + for _ in range(layers) + ]) + self.ffns = nn.ModuleList([ + ComplexFFN(hidden_dim, ffn_size) + for _ in range(layers) + ]) + self.layer_norm = ComplexLayerNorm(hidden_dim) + + def forward(self, X): + """ + X: (batch_size, sequence_length, hidden_size) + """ + for i in range(self.layers): + Y = self.retentions[i](self.layer_norm(X)) + X + X = self.ffns[i](self.layer_norm(Y)) + Y + + return X + + def forward_recurrent(self, x_n, s_n_1s, n): + """ + X: (batch_size, sequence_length, hidden_size) + s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) + + """ + s_ns = [] + for i in range(self.layers): + o_n, s_n = self.retentions[i].forward_recurrent(self.layer_norm(x_n), s_n_1s[i], n) + y_n = o_n + x_n + s_ns.append(s_n) + x_n = self.ffns[i](self.layer_norm(y_n)) + y_n + + return x_n, s_ns + +class RetNetCLM(nn.Module): + def __init__(self, layers, hidden_dim, ffn_size, heads, vocab_size): + """ + NOTE: softmax not included! 
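+        forward() returns real-valued logits (the real part of X @ proj);
+        apply softmax / cross-entropy externally.
+        Illustrative usage (shapes only):
+            >>> model = RetNetCLM(layers=2, hidden_dim=16, ffn_size=32, heads=4, vocab_size=10)
+            >>> logits = model(torch.randint(0, 10, (2, 6)))  # (2, 6, 10), real dtype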
+ """ + super(RetNetCLM, self).__init__() + self.layers = layers + self.hidden_dim = hidden_dim + self.ffn_size = ffn_size + self.heads = heads + self.vocab_size = vocab_size + + self.retnet = RetNet(layers, hidden_dim, ffn_size, heads) + self.embed = nn.Embedding(vocab_size, hidden_dim) + self.proj = nn.Parameter(torch.randn(hidden_dim, vocab_size, dtype=torch.float32) / hidden_dim) + + def forward(self, input_ids): + """ + input_ids: (batch_size, sequence_length) + """ + X = self.embed(input_ids) + X = self.retnet(X) + X = X @ self.proj.to(X.dtype) + + return X.real + + def forward_recurrent(self, input_ids, s_n_1s, n): + """ + input_ids: (batch_size) + s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) + """ + X = self.embed(input_ids) + X, s_ns = self.retnet.forward_recurrent(X, s_n_1s, n) + X = X @ self.proj.to(X.dtype) + + return X.real, s_ns + + def sample(self, input_ids, sample_length, temperature=1.0): + """ + input_ids: (batch_size, sequence_length) + s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) + """ + s_n_1s = [ + [ + torch.zeros(self.hidden_dim // self.heads, self.hidden_dim // self.heads, dtype=torch.complex64).unsqueeze(0).repeat(input_ids.shape[0], 1, 1) + for _ in range(self.heads) + ] for _ in range(self.layers) + ] + for i in range(input_ids.shape[1]): + X, s_n_1s = self.forward_recurrent(input_ids[:, i], s_n_1s, i+1) + + # get softmax of x (real part only) + X = X.real / temperature + X = torch.softmax(X, dim=-1) + X = torch.multinomial(X, num_samples=1) + next_char = X[:, -1] + output_ids = [] + # now start sampling! + for i in range(sample_length): + X, s_n_1s = self.forward_recurrent(next_char, s_n_1s, i+1) + X = X.real / temperature + X = torch.softmax(X, dim=-1) + X = torch.multinomial(X, num_samples=1) + next_char = X[:, -1] + output_ids.append(next_char) + + output_ids = torch.stack(output_ids, dim=1) + + return output_ids \ No newline at end of file diff --git a/src/clm/src/retnet/complex/test_retention.py b/src/clm/src/retnet/complex/test_retention.py new file mode 100644 index 00000000..07e30d6a --- /dev/null +++ b/src/clm/src/retnet/complex/test_retention.py @@ -0,0 +1,119 @@ +import unittest +import torch +from retention import SimpleRetention, MultiScaleRetention + +class TestSimpleRetention(unittest.TestCase): + def test_simple_retention_parallel(self): + batch_size = 4 + hidden_size = 8 + sequence_length = 16 + gamma = 0.9 + + X = torch.rand(batch_size, sequence_length, hidden_size) + retention = SimpleRetention(hidden_size, gamma) + + Y = retention(X) + self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) + + def test_simple_retention_recurrent(self): + batch_size = 4 + hidden_size = 8 + sequence_length = 16 + gamma = 0.9 + + X = torch.rand(batch_size, sequence_length, hidden_size) + retention = SimpleRetention(hidden_size, gamma) + + s_n_1 = torch.zeros(hidden_size, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) + Y = [] + for i in range(sequence_length): + y_n, s_n = retention.forward_recurrent(X[:, i, :], s_n_1, i+1) + Y.append(y_n) + s_n_1 = s_n + Y = torch.stack(Y, dim=1) + self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) + + def test_paradigms_identical(self): + """ + check that the parallel and recurrent paradigms have identical outputs + """ + batch_size = 1 + hidden_size = 8 + sequence_length = 4 + gamma = 0.90 + + X = torch.rand(batch_size, sequence_length, hidden_size) + retention = 
SimpleRetention(hidden_size, gamma) + + Y_parallel = retention(X) + + s_n_1 = torch.zeros(hidden_size, hidden_size, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) + Y_recurrent = [] + for i in range(sequence_length): + y_n, s_n = retention.forward_recurrent(X[:, i, :], s_n_1, i+1) + Y_recurrent.append(y_n) + s_n_1 = s_n + Y_recurrent = torch.stack(Y_recurrent, dim=1) + + self.assertTrue(torch.allclose(Y_parallel, Y_recurrent)) + +class TestMultiScaleRetention(unittest.TestCase): + def test_multiscale_retention_parallel(self): + batch_size = 4 + sequence_length = 5 + hidden_size = 32 + heads = 4 + retention = MultiScaleRetention(hidden_size, heads) + + X = torch.rand(batch_size, sequence_length, hidden_size) + Y = retention(X) + self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) + + def test_multiscale_retention_recurrent(self): + batch_size = 4 + sequence_length = 5 + hidden_size = 32 + heads = 4 + retention = MultiScaleRetention(hidden_size, heads) + + X = torch.rand(batch_size, sequence_length, hidden_size) + s_n_1s = [ + torch.zeros(hidden_size // heads, hidden_size // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) + for _ in range(heads) + ] + Y = [] + for i in range(sequence_length): + y_n, s_ns = retention.forward_recurrent(X[:, i, :], s_n_1s, i) + Y.append(y_n) + s_n_1s = s_ns + Y = torch.stack(Y, dim=1) + self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) + + def test_multiscale_paradigms_identical(self): + """ + check that the parallel and recurrent paradigms have identical outputs + """ + batch_size = 2 + hidden_size = 36 + sequence_length = 5 + heads = 3 + + X = torch.rand(batch_size, sequence_length, hidden_size) + retention = MultiScaleRetention(hidden_size, heads) + + Y_parallel = retention(X) + + s_n_1s = [ + torch.zeros(hidden_size // heads, hidden_size // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) + for _ in range(heads) + ] + Y_recurrent = [] + for i in range(sequence_length): + y_n, s_ns = retention.forward_recurrent(X[:, i, :], s_n_1s, i) + Y_recurrent.append(y_n) + s_n_1s = s_ns + Y_recurrent = torch.stack(Y_recurrent, dim=1) + + self.assertTrue(torch.allclose(Y_parallel, Y_recurrent)) + +unittest.main() \ No newline at end of file diff --git a/src/clm/src/retnet/complex/test_retnet.py b/src/clm/src/retnet/complex/test_retnet.py new file mode 100644 index 00000000..a2b1d2cd --- /dev/null +++ b/src/clm/src/retnet/complex/test_retnet.py @@ -0,0 +1,102 @@ +import unittest +import torch +from retnet import RetNet, RetNetCLM + +class TestRetNet(unittest.TestCase): + + def test_paradigms_equivalent(self): + batch_size = 2 + layers = 2 + hidden_dim = 8 + heads = 4 + sequence_length = 4 + ffn_size = 16 + + X = torch.rand(batch_size, sequence_length, hidden_dim) + + retnet = RetNet(layers, hidden_dim, ffn_size, heads) + Y_parallel = retnet(X) + + s_n_1s = [ + [ + torch.zeros(hidden_dim // heads, hidden_dim // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) + for _ in range(heads) + ] for _ in range(layers) + ] + + Y_recurrent = [] + for i in range(sequence_length): + Y, s_ns = retnet.forward_recurrent(X[:, i, :], s_n_1s, i+1) + Y_recurrent.append(Y) + s_n_1s = s_ns + + Y_recurrent = torch.stack(Y_recurrent, dim=1) + + print((Y_parallel - Y_recurrent).abs().max()) + + self.assertTrue((Y_parallel - Y_recurrent).abs().max() < 1e-4) + + def test_clm(self): + batch_size = 2 + layers = 2 + hidden_dim = 16 + heads = 4 + sequence_length = 6 + ffn_size = 32 + vocab_size 
= 10 + + X = torch.randint(0, vocab_size, (batch_size, sequence_length)) + + retnet = RetNetCLM(layers, hidden_dim, ffn_size, heads, vocab_size) + Y_parallel = retnet(X) + + s_n_1s = [ + [ + torch.zeros(hidden_dim // heads, hidden_dim // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) + for _ in range(heads) + ] for _ in range(layers) + ] + + Y_recurrent = [] + for i in range(sequence_length): + Y, s_ns = retnet.forward_recurrent(X[:, i], s_n_1s, i+1) + Y_recurrent.append(Y) + s_n_1s = s_ns + + Y_recurrent = torch.stack(Y_recurrent, dim=1) + + # test sample + Y_sample = retnet.sample(X, 5) + + self.assertTrue(Y_sample.shape == (batch_size, 5)) + + self.assertTrue((Y_parallel - Y_recurrent).abs().max() < 1e-4) + + def test_training(self): + batch_size = 2 + layers = 3 + hidden_dim = 16 + heads = 4 + sequence_length = 6 + ffn_size = 32 + vocab_size = 10 + bos_idx = 0 + + data = torch.randint(0, vocab_size, (batch_size, sequence_length - 1)) + X = torch.cat([torch.ones(batch_size, 1).long() * bos_idx, data[:,:-1]], dim=1) + Y = data + + # verify we can overfit autoregressive model + model = RetNetCLM(layers, hidden_dim, ffn_size, heads, vocab_size) + + optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) + criterion = torch.nn.CrossEntropyLoss() + initial_loss = criterion(model(X).reshape(-1, 10), Y.reshape(-1)) + for i in range(10): + optimizer.zero_grad() + output = model(X) + loss = criterion(output.reshape(-1, 10), Y.reshape(-1)) + loss.backward() + optimizer.step() + self.assertTrue((loss < initial_loss).item()) +unittest.main() \ No newline at end of file diff --git a/src/clm/src/retnet/complex/util.py b/src/clm/src/retnet/complex/util.py new file mode 100644 index 00000000..f7a89da0 --- /dev/null +++ b/src/clm/src/retnet/complex/util.py @@ -0,0 +1,71 @@ +import math +import torch +import torch.nn as nn + +class ComplexGroupNorm(nn.Module): + def __init__(self, num_groups, num_channels, eps=1e-5): + super(ComplexGroupNorm, self).__init__() + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_channels, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(num_channels, dtype=torch.float32)) + + def forward(self, X): + """ + X: (batch_size, sequence_length, hidden_size) + X is assumed to be complex + """ + X = X.reshape(-1, self.num_groups, self.num_channels // self.num_groups) + mean = X.mean(dim=2, keepdim=True) + var = X.var(dim=2, keepdim=True) + X = (X - mean) / torch.sqrt(var + self.eps) + X = X.reshape(-1, self.num_channels) + X = X * self.weight + self.bias + + return X + +class ComplexLayerNorm(nn.Module): + def __init__(self, num_channels, eps=1e-5): + super(ComplexLayerNorm, self).__init__() + self.num_channels = num_channels + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_channels, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(num_channels, dtype=torch.float32)) + + def forward(self, X): + """ + X: unknown shape ending in hidden_size + we treat the last dimension as the hidden_size + """ + X_shape = X.shape + X = X.reshape(-1, X_shape[-1]) + mean = X.mean(dim=1, keepdim=True) + var = X.abs().var(dim=1, keepdim=True) + X = (X - mean) / torch.sqrt(var + self.eps) + X = X * self.weight + self.bias + X = X.reshape(X_shape) + return X + + +class ComplexFFN(nn.Module): + """ + 2 linear layers with no bias + """ + def __init__(self, hidden_size, ffn_size): + super(ComplexFFN, self).__init__() + self.W1 = nn.Parameter(torch.randn(hidden_size, ffn_size, 
dtype=torch.float32) / math.sqrt(hidden_size)) + self.W2 = nn.Parameter(torch.randn(ffn_size, hidden_size, dtype=torch.float32) / math.sqrt(ffn_size)) + self.gelu = lambda x: 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + def forward(self, X): + """ + X: (batch_size, sequence_length, hidden_size) + X is assumed to be complex + """ + # reshaping + X = X @ self.W1.to(X) + X = self.gelu(X) + X = X @ self.W2.to(X) + + return X diff --git a/src/clm/src/retnet/example.py b/src/clm/src/retnet/example.py new file mode 100644 index 00000000..0dcaaedb --- /dev/null +++ b/src/clm/src/retnet/example.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + +import retnet + +if __name__ == "__main__": + # verify model size for hyperparameters in paper + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + # 1.3B model + layers = 24 + hidden_dim = 2048 + ffn_size = 4096 + heads = 16 + + retnet = retnet.RetNet(layers, hidden_dim, ffn_size, heads, double_v_dim=True).to(device) + print("1.3B model:",sum(p.numel() for p in retnet.parameters() if p.requires_grad)) diff --git a/src/clm/src/retnet/retention.py b/src/clm/src/retnet/retention.py new file mode 100644 index 00000000..e23e9e3c --- /dev/null +++ b/src/clm/src/retnet/retention.py @@ -0,0 +1,204 @@ +import math + +import torch +import torch.nn as nn + +from clm.src.retnet.xpos_relative_position import XPOS + +class SimpleRetention(nn.Module): + def __init__(self, hidden_size, gamma, head_size=None, double_v_dim=False): + """ + Simple retention mechanism based on the paper + "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] + """ + super(SimpleRetention, self).__init__() + + self.hidden_size = hidden_size + if head_size is None: + head_size = hidden_size + self.head_size = head_size + + self.v_dim = head_size * 2 if double_v_dim else head_size + self.gamma = gamma + + self.W_Q = nn.Parameter(torch.randn(hidden_size, head_size) / hidden_size) + self.W_K = nn.Parameter(torch.randn(hidden_size, head_size) / hidden_size) + self.W_V = nn.Parameter(torch.randn(hidden_size, self.v_dim) / hidden_size) + + self.xpos = XPOS(head_size) + + def forward(self, X): + """ + Parallel (default) representation of the retention mechanism. + X: (batch_size, sequence_length, hidden_size) + """ + sequence_length = X.shape[1] + D = self._get_D(sequence_length).to(self.W_Q.device) + + Q = (X @ self.W_Q) + K = (X @ self.W_K) + + Q = self.xpos(Q) + K = self.xpos(K, downscale=True) + + V = X @ self.W_V + ret = (Q @ K.permute(0, 2, 1)) * D.unsqueeze(0) + + return ret @ V + + def forward_recurrent(self, x_n, s_n_1, n): + """ + Recurrent representation of the retention mechanism. + x_n: (batch_size, 1, hidden_size) + s_n_1: (batch_size, hidden_size, v_dim) + """ + + Q = (x_n @ self.W_Q) + K = (x_n @ self.W_K) + + Q = self.xpos(Q, n+1) + K = self.xpos(K, n+1, downscale=True) + + V = x_n @ self.W_V + + # K: (batch_size, 1, hidden_size) + # V: (batch_size, 1, v_dim) + # s_n = gamma * s_n_1 + K^T @ V + + s_n = self.gamma * s_n_1 + (K.transpose(-1, -2) @ V) + + return (Q @ s_n), s_n + + def forward_chunkwise(self, x_i, r_i_1, i): + """ + Chunkwise representation of the retention mechanism. 
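+        Combines an exact inner-chunk (parallel) retention with a cross-chunk term
+        driven by the recurrent state r_i = K^T (V * decay) + gamma**chunk_size * r_{i-1}.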
+ x_i: (batch_size, chunk_size, hidden_size) + r_i_1: (batch_size, hidden_size, v_dim) + """ + batch, chunk_size, _ = x_i.shape + D = self._get_D(chunk_size) + + Q = (x_i @ self.W_Q) + K = (x_i @ self.W_K) + + Q = self.xpos(Q, i * chunk_size) + K = self.xpos(K, i * chunk_size, downscale=True) + + V = x_i @ self.W_V + + r_i =(K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))) + (self.gamma ** chunk_size) * r_i_1 + + inner_chunk = ((Q @ K.transpose(-1, -2)) * D.unsqueeze(0)) @ V + + #e[i,j] = gamma ** (i+1) + e = torch.zeros(batch, chunk_size, 1) + + for _i in range(chunk_size): + e[:, _i, :] = self.gamma ** (_i + 1) + + cross_chunk = (Q @ r_i_1) * e + + return inner_chunk + cross_chunk, r_i + + def _get_D(self, sequence_length): + n = torch.arange(sequence_length).unsqueeze(1) + m = torch.arange(sequence_length).unsqueeze(0) + + # Broadcast self.gamma ** (n - m) with appropriate masking to set values where n < m to 0 + D = (self.gamma ** (n - m)) * (n >= m).float() #this results in some NaN when n is much larger than m + # fill the NaN with 0 + D[D != D] = 0 + + return D + + + +class MultiScaleRetention(nn.Module): + def __init__(self, hidden_size, heads, double_v_dim=False): + """ + Multi-scale retention mechanism based on the paper + "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] + """ + super(MultiScaleRetention, self).__init__() + self.hidden_size = hidden_size + self.v_dim = hidden_size * 2 if double_v_dim else hidden_size + self.heads = heads + assert hidden_size % heads == 0, "hidden_size must be divisible by heads" + self.head_size = hidden_size // heads + self.head_v_dim = hidden_size * 2 if double_v_dim else hidden_size + + self.gammas = (1 - torch.exp(torch.linspace(math.log(1/32), math.log(1/512), heads))).detach().cpu().tolist() + + self.swish = lambda x: x * torch.sigmoid(x) + self.W_G = nn.Parameter(torch.randn(hidden_size, self.v_dim) / hidden_size) + self.W_O = nn.Parameter(torch.randn(self.v_dim, hidden_size) / hidden_size) + self.group_norm = nn.GroupNorm(heads, self.v_dim) + + self.retentions = nn.ModuleList([ + SimpleRetention(self.hidden_size, gamma, self.head_size, double_v_dim) for gamma in self.gammas + ]) + + def forward(self, X): + """ + parallel representation of the multi-scale retention mechanism + """ + + # apply each individual retention mechanism to X + Y = [] + for i in range(self.heads): + Y.append(self.retentions[i](X)) + + Y = torch.cat(Y, dim=2) + Y_shape = Y.shape + Y = self.group_norm(Y.reshape(-1, self.v_dim)).reshape(Y_shape) + + return (self.swish(X @ self.W_G) * Y) @ self.W_O + + def forward_recurrent(self, x_n, s_n_1s, n): + """ + recurrent representation of the multi-scale retention mechanism + x_n: (batch_size, 1, hidden_size) + s_n_1s: (batch_size, heads, head_size, head_size) + + """ + + # apply each individual retention mechanism to a slice of X + Y = [] + s_ns = [] + for i in range(self.heads): + y, s_n = self.retentions[i].forward_recurrent( + x_n[:, :, :], s_n_1s[i], n + ) + Y.append(y) + s_ns.append(s_n) + + Y = torch.cat(Y, dim=2) + Y_shape = Y.shape + Y = self.group_norm(Y.reshape(-1, self.v_dim)).reshape(Y_shape) + + return (self.swish(x_n @ self.W_G) * Y) @ self.W_O, s_ns + + def forward_chunkwise(self, x_i, r_i_1s, i): + """ + chunkwise representation of the multi-scale retention mechanism + x_i: (batch_size, chunk_size, hidden_size) + r_i_1s: (batch_size, heads, head_size, head_size) + """ + batch, chunk_size, _ = x_i.shape + + # apply each individual retention 
mechanism to a slice of X + Y = [] + r_is = [] + for j in range(self.heads): + y, r_i = self.retentions[j].forward_chunkwise( + x_i[:, :, :], r_i_1s[j], i + ) + Y.append(y) + r_is.append(r_i) + + + Y = torch.cat(Y, dim=2) + Y_shape = Y.shape + Y = self.group_norm(Y.reshape(-1, self.v_dim)).reshape(Y_shape) + + return (self.swish(x_i @ self.W_G) * Y) @ self.W_O, r_is diff --git a/src/clm/src/retnet/retnet.py b/src/clm/src/retnet/retnet.py new file mode 100644 index 00000000..dced11ec --- /dev/null +++ b/src/clm/src/retnet/retnet.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn + +from clm.src.retnet.retention import MultiScaleRetention + +class RetNet(nn.Module): + def __init__(self, layers, hidden_dim, ffn_size, heads, double_v_dim=False): + super(RetNet, self).__init__() + self.layers = layers + self.hidden_dim = hidden_dim + self.ffn_size = ffn_size + self.heads = heads + self.v_dim = hidden_dim * 2 if double_v_dim else hidden_dim + + self.retentions = nn.ModuleList([ + MultiScaleRetention(hidden_dim, heads, double_v_dim) + for _ in range(layers) + ]) + self.ffns = nn.ModuleList([ + nn.Sequential( + nn.Linear(hidden_dim, ffn_size), + nn.GELU(), + nn.Linear(ffn_size, hidden_dim) + ) + for _ in range(layers) + ]) + self.layer_norms_1 = nn.ModuleList([ + nn.LayerNorm(hidden_dim) + for _ in range(layers) + ]) + self.layer_norms_2 = nn.ModuleList([ + nn.LayerNorm(hidden_dim) + for _ in range(layers) + ]) + + def forward(self, X): + """ + X: (batch_size, sequence_length, hidden_size) + """ + for i in range(self.layers): + Y = self.retentions[i](self.layer_norms_1[i](X)) + X + + X = self.ffns[i](self.layer_norms_2[i](Y)) + Y + + return X + + def forward_recurrent(self, x_n, s_n_1s, n): + """ + X: (batch_size, sequence_length, hidden_size) + s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) + + """ + s_ns = [] + for i in range(self.layers): + # list index out of range + o_n, s_n = self.retentions[i].forward_recurrent(self.layer_norms_1[i](x_n), s_n_1s[i], n) + y_n = o_n + x_n + s_ns.append(s_n) + x_n = self.ffns[i](self.layer_norms_2[i](y_n)) + y_n + + return x_n, s_ns + + def forward_chunkwise(self, x_i, r_i_1s, i): + """ + X: (batch_size, sequence_length, hidden_size) + r_i_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) + + """ + r_is = [] + for j in range(self.layers): + o_i, r_i = self.retentions[j].forward_chunkwise(self.layer_norms_1[j](x_i), r_i_1s[j], i) + y_i = o_i + x_i + r_is.append(r_i) + x_i = self.ffns[j](self.layer_norms_2[j](y_i)) + y_i + + return x_i, r_is diff --git a/src/clm/src/retnet/tests.py b/src/clm/src/retnet/tests.py new file mode 100644 index 00000000..44c8fc6d --- /dev/null +++ b/src/clm/src/retnet/tests.py @@ -0,0 +1,154 @@ +import unittest + +import torch + +from clm.src.retnet.retention import SimpleRetention, MultiScaleRetention +from clm.src.retnet.retnet import RetNet + +class TestRetention(unittest.TestCase): + + def test_simple(self): + """ + verify that the three implementations of SimpleRetention are identical + """ + batch_size = 4 + sequence_length = 12 + hidden_size = 6 + chunk_size = 4 + + gamma = 0.9 + + X = torch.rand(batch_size, sequence_length, hidden_size) + sr = SimpleRetention(hidden_size, gamma, double_v_dim=True) + + Y_parallel = sr(X) + + s_n_1 = torch.zeros(hidden_size, sr.v_dim).unsqueeze(0).repeat(batch_size, 1, 1) + Y_recurrent = [] + for i in range(sequence_length): + y_n, s_n = sr.forward_recurrent(X[:, i:i+1, :], s_n_1, i) + 
Y_recurrent.append(y_n) + s_n_1 = s_n + + Y_recurrent = torch.concat(Y_recurrent, dim=1) + + r_n_1 = torch.zeros(hidden_size, sr.v_dim).unsqueeze(0).repeat(batch_size, 1, 1) + Y_chunkwise = [] + for i in range(sequence_length // chunk_size): + y_i, r_i = sr.forward_chunkwise(X[:, i*chunk_size:(i+1)*chunk_size, :], r_n_1, i) + Y_chunkwise.append(y_i) + r_n_1 = r_i + + + Y_chunkwise = torch.concat(Y_chunkwise, dim=1) + + + assert torch.allclose(Y_parallel, Y_recurrent, atol=1e-5) + assert torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5) + + + def test_multiscale(self): + """ + verify that the three implementations of MultiScaleRetention are identical + """ + batch_size = 2 + hidden_size = 6 + sequence_length = 12 + heads = 3 + chunk_size = 2 + + X = torch.rand(batch_size, sequence_length, hidden_size) + retention = MultiScaleRetention(hidden_size, heads, double_v_dim=False) + # print total number of parameters + print("Default v_dim:",sum(p.numel() for p in retention.parameters() if p.requires_grad)) + + retention = MultiScaleRetention(hidden_size, heads, double_v_dim=True) + print("Double v_dim:",sum(p.numel() for p in retention.parameters() if p.requires_grad)) + + Y_parallel = retention(X) + + s_n_1s = [ + torch.zeros(hidden_size // heads, retention.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) + for _ in range(heads) + ] + Y_recurrent = [] + for i in range(sequence_length): + y_n, s_ns = retention.forward_recurrent(X[:, i:i+1, :], s_n_1s, i) + Y_recurrent.append(y_n) + s_n_1s = s_ns + + Y_recurrent = torch.concat(Y_recurrent, dim=1) + + r_n_1s = [ + torch.zeros(hidden_size // heads, retention.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) + for _ in range(heads) + ] + Y_chunkwise = [] + for i in range(sequence_length // chunk_size): + y_i, r_i = retention.forward_chunkwise(X[:, i*chunk_size:(i+1)*chunk_size, :], r_n_1s, i) + Y_chunkwise.append(y_i) + r_n_1s = r_i + + Y_chunkwise = torch.concat(Y_chunkwise, dim=1) + + self.assertTrue(torch.allclose(Y_parallel, Y_recurrent, atol=1e-5)) + self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails + +class TestRetNet(unittest.TestCase): + + def test_retnet(self): + """ + verify that the three implementations of RetNet are identical + """ + batch_size = 2 + hidden_size = 36 + sequence_length = 5 + heads = 3 + layers = 4 + ffn_size = 128 + + X = torch.rand(batch_size, sequence_length, hidden_size) + retnet = RetNet(layers, hidden_size, ffn_size, heads, double_v_dim=False) + # print total number of parameters + print("Default v_dim:",sum(p.numel() for p in retnet.parameters() if p.requires_grad)) + + retnet = RetNet(layers, hidden_size, ffn_size, heads, double_v_dim=True) + print("Double v_dim:",sum(p.numel() for p in retnet.parameters() if p.requires_grad)) + + Y_parallel = retnet(X) + + s_n_1s = [ + [ + torch.zeros(hidden_size // heads, retnet.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) + for _ in range(heads) + ] + for _ in range(layers) + ] + Y_recurrent = [] + for i in range(sequence_length): + y_n, s_ns = retnet.forward_recurrent(X[:, i:i+1, :], s_n_1s, i) + Y_recurrent.append(y_n) + s_n_1s = s_ns + + Y_recurrent = torch.concat(Y_recurrent, dim=1) + + r_n_1s = [ + [ + torch.zeros(hidden_size // heads, retnet.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) + for _ in range(heads) + ] + for _ in range(layers) + ] + Y_chunkwise = [] + for i in range(sequence_length): + y_i, r_i = retnet.forward_chunkwise(X[:, i:i+1, :], r_n_1s, i) + Y_chunkwise.append(y_i) + r_n_1s = r_i + + Y_chunkwise = 
torch.concat(Y_chunkwise, dim=1) + + self.assertTrue(torch.allclose(Y_parallel, Y_recurrent, atol=1e-5)) + self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) + +if __name__ == "__main__": + unittest.main() diff --git a/src/clm/src/retnet/xpos_relative_position.py b/src/clm/src/retnet/xpos_relative_position.py new file mode 100644 index 00000000..1c445e5a --- /dev/null +++ b/src/clm/src/retnet/xpos_relative_position.py @@ -0,0 +1,94 @@ +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License (https://github.com/microsoft/torchscale/blob/main/LICENSE) +import torch +import torch.nn as nn + +def fixed_pos_embedding(x): + seq_len, dim = x.shape + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim)) + sinusoid_inp = ( + torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) + ) + return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) + +def rotate_every_two(x): + x1 = x[:, :, ::2] + x2 = x[:, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + if x.shape[-1]%2 == 1: + # fill last dim with zero if hidden_size is odd + x2 = torch.concat((x2, torch.zeros_like(x2[:, :, :1])), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ + +def duplicate_interleave(m): + """ + A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. + """ + dim0 = m.shape[0] + m = m.view(-1, 1) # flatten the matrix + m = m.repeat(1, 2) # repeat all elements into the 2nd dimension + m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy + return m + +def apply_rotary_pos_emb(x, sin, cos, scale=1): + sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos)) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos[:, :x.shape[-1]]) + (rotate_every_two(x) * sin)[:, :, :x.shape[-1]] + + +class XPOS(nn.Module): + def __init__( + self, head_dim, scale_base=512 + ): + super().__init__() + self.head_dim = head_dim + self.scale_base = scale_base + self.register_buffer( + "scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim) + ) + + def forward(self, x, offset=0, downscale=False): + length = x.shape[1] + min_pos = 0 + max_pos = length + offset + min_pos + scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] + sin, cos = fixed_pos_embedding(scale) + + if scale.shape[0] > length: + scale = scale[-length:] + sin = sin[-length:] + cos = cos[-length:] + + if downscale: + scale = 1 / scale + + x = apply_rotary_pos_emb(x, sin, cos, scale) + return x + + def forward_reverse(self, x, offset=0, downscale=False): + length = x.shape[1] + min_pos = -(length + offset) // 2 + max_pos = length + offset + min_pos + scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] + sin, cos = fixed_pos_embedding(scale) + + if scale.shape[0] > length: + scale = scale[-length:] + sin = sin[-length:] + cos = cos[-length:] + + if downscale: + scale = 1 / scale + + x = apply_rotary_pos_emb(x, -sin, cos, scale) + return x + +# test +if __name__ == "__main__": + x = torch.eye(4).unsqueeze(0) + xpos = XPOS(4) + x_rot = xpos(x) + # apply reverse + x_rot_rev = xpos.forward(x) + + print(x_rot @ x_rot_rev.transpose(-1, -2)) \ No newline at end of file diff --git a/src/clm/src/tasks/decoders.py b/src/clm/src/tasks/decoders.py new file mode 100644 index 00000000..f95e5005 --- /dev/null +++ b/src/clm/src/tasks/decoders.py @@ -0,0 +1,319 
@@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, reduce + +import clm.src.models.nn.utils as U +import clm.src.utils as utils +import clm.src.utils.config +import clm.src.utils.train + +log = clm.src.utils.train.get_logger(__name__) + + +class Decoder(nn.Module): + """This class doesn't do much but just signals the interface that Decoders are expected to adhere to + TODO: is there a way to enforce the signature of the forward method? + """ + + def forward(self, x, **kwargs): + """ + x: (batch, length, dim) input tensor + state: additional state from the model backbone + *args, **kwargs: additional info from the dataset + + Returns: + y: output tensor + *args: other arguments to pass into the loss function + """ + return x + + def step(self, x): + """ + x: (batch, dim) + """ + return self.forward(x.unsqueeze(1)).squeeze(1) + + +class SequenceDecoder(Decoder): + def __init__( + self, d_model, d_output=None, l_output=None, use_lengths=False, mode="last" + ): + super().__init__() + + self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output) + + if l_output is None: + self.l_output = None + self.squeeze = False + elif l_output == 0: + # Equivalent to getting an output of length 1 and then squeezing + self.l_output = 1 + self.squeeze = True + else: + assert l_output > 0 + self.l_output = l_output + self.squeeze = False + + self.use_lengths = use_lengths + self.mode = mode + + if mode == 'ragged': + assert not use_lengths + + def forward(self, x, state=None, lengths=None, l_output=None): + """ + x: (n_batch, l_seq, d_model) + Returns: (n_batch, l_output, d_output) + """ + + if self.l_output is None: + if l_output is not None: + assert isinstance(l_output, int) # Override by pass in + else: + # Grab entire output + l_output = x.size(-2) + squeeze = False + else: + l_output = self.l_output + squeeze = self.squeeze + + if self.mode == "last": + restrict = lambda x: x[..., -l_output:, :] + elif self.mode == "first": + restrict = lambda x: x[..., :l_output, :] + elif self.mode == "pool": + restrict = lambda x: ( + torch.cumsum(x, dim=-2) + / torch.arange( + 1, 1 + x.size(-2), device=x.device, dtype=x.dtype + ).unsqueeze(-1) + )[..., -l_output:, :] + + def restrict(x): + L = x.size(-2) + s = x.sum(dim=-2, keepdim=True) + if l_output > 1: + c = torch.cumsum(x[..., -(l_output - 1) :, :].flip(-2), dim=-2) + c = F.pad(c, (0, 0, 1, 0)) + s = s - c # (B, l_output, D) + s = s.flip(-2) + denom = torch.arange( + L - l_output + 1, L + 1, dtype=x.dtype, device=x.device + ) + s = s / denom + return s + + elif self.mode == "sum": + restrict = lambda x: torch.cumsum(x, dim=-2)[..., -l_output:, :] + # TODO use same restrict function as pool case + elif self.mode == 'ragged': + assert lengths is not None, "lengths must be provided for ragged mode" + # remove any additional padding (beyond max length of any sequence in the batch) + restrict = lambda x: x[..., : max(lengths), :] + else: + raise NotImplementedError( + "Mode must be ['last' | 'first' | 'pool' | 'sum']" + ) + + # Restrict to actual length of sequence + if self.use_lengths: + assert lengths is not None + x = torch.stack( + [ + restrict(out[..., :length, :]) + for out, length in zip(torch.unbind(x, dim=0), lengths) + ], + dim=0, + ) + else: + x = restrict(x) + + if squeeze: + assert x.size(-2) == 1 + x = x.squeeze(-2) + + x = self.output_transform(x) + + return x + + def step(self, x, state=None): + # Ignore all length logic + return self.output_transform(x) + +class 
NDDecoder(Decoder): + """Decoder for single target (e.g. classification or regression)""" + def __init__( + self, d_model, d_output=None, mode="pool" + ): + super().__init__() + + assert mode in ["pool", "full"] + self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output) + + self.mode = mode + + def forward(self, x, state=None): + """ + x: (n_batch, l_seq, d_model) + Returns: (n_batch, l_output, d_output) + """ + + if self.mode == 'pool': + x = reduce(x, 'b ... h -> b h', 'mean') + x = self.output_transform(x) + return x + +class StateDecoder(Decoder): + """Use the output state to decode (useful for stateful models such as RNNs or perhaps Transformer-XL if it gets implemented""" + + def __init__(self, d_model, state_to_tensor, d_output): + super().__init__() + self.output_transform = nn.Linear(d_model, d_output) + self.state_transform = state_to_tensor + + def forward(self, x, state=None): + return self.output_transform(self.state_transform(state)) + + +class RetrievalHead(nn.Module): + def __init__(self, d_input, d_model, n_classes, nli=True, activation="relu"): + super().__init__() + self.nli = nli + + if activation == "relu": + activation_fn = nn.ReLU() + elif activation == "gelu": + activation_fn = nn.GELU() + else: + raise NotImplementedError + + if ( + self.nli + ): # Architecture from https://github.com/mlpen/Nystromformer/blob/6539b895fa5f798ea0509d19f336d4be787b5708/reorganized_code/LRA/model_wrapper.py#L74 + self.classifier = nn.Sequential( + nn.Linear(4 * d_input, d_model), + activation_fn, + nn.Linear(d_model, n_classes), + ) + else: # Head from https://github.com/google-research/long-range-arena/blob/ad0ff01a5b3492ade621553a1caae383b347e0c1/lra_benchmarks/models/layers/common_layers.py#L232 + self.classifier = nn.Sequential( + nn.Linear(2 * d_input, d_model), + activation_fn, + nn.Linear(d_model, d_model // 2), + activation_fn, + nn.Linear(d_model // 2, n_classes), + ) + + def forward(self, x): + """ + x: (2*batch, dim) + """ + outs = rearrange(x, "(z b) d -> z b d", z=2) + outs0, outs1 = outs[0], outs[1] # (n_batch, d_input) + if self.nli: + features = torch.cat( + [outs0, outs1, outs0 - outs1, outs0 * outs1], dim=-1 + ) # (batch, dim) + else: + features = torch.cat([outs0, outs1], dim=-1) # (batch, dim) + logits = self.classifier(features) + return logits + + +class RetrievalDecoder(Decoder): + """Combines the standard FeatureDecoder to extract a feature before passing through the RetrievalHead""" + + def __init__( + self, + d_input, + n_classes, + d_model=None, + nli=True, + activation="relu", + *args, + **kwargs + ): + super().__init__() + if d_model is None: + d_model = d_input + self.feature = SequenceDecoder( + d_input, d_output=None, l_output=0, *args, **kwargs + ) + self.retrieval = RetrievalHead( + d_input, d_model, n_classes, nli=nli, activation=activation + ) + + def forward(self, x, state=None, **kwargs): + x = self.feature(x, state=state, **kwargs) + x = self.retrieval(x) + return x + +class PackedDecoder(Decoder): + def forward(self, x, state=None): + x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) + return x + + +# For every type of encoder/decoder, specify: +# - constructor class +# - list of attributes to grab from dataset +# - list of attributes to grab from model + +registry = { + "stop": Decoder, + "id": nn.Identity, + "linear": nn.Linear, + "sequence": SequenceDecoder, + "nd": NDDecoder, + "retrieval": RetrievalDecoder, + "state": StateDecoder, + "pack": PackedDecoder, +} +model_attrs = { + "linear": 
["d_output"], + "sequence": ["d_output"], + "nd": ["d_output"], + "retrieval": ["d_output"], + "state": ["d_state", "state_to_tensor"], + "forecast": ["d_output"], +} + +dataset_attrs = { + "linear": ["d_output"], + "sequence": ["d_output", "l_output"], + "nd": ["d_output"], + "retrieval": ["d_output"], + "state": ["d_output"], + "forecast": ["d_output", "l_output"], +} + + +def _instantiate(decoder, model=None, dataset=None): + """Instantiate a single decoder""" + if decoder is None: + return None + + if isinstance(decoder, str): + name = decoder + else: + name = decoder["_name_"] + + # Extract arguments from attribute names + dataset_args = utils.config.extract_attrs_from_obj( + dataset, *dataset_attrs.get(name, []) + ) + model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, [])) + # Instantiate decoder + obj = utils.instantiate(registry, decoder, *model_args, *dataset_args) + return obj + + +def instantiate(decoder, model=None, dataset=None): + """Instantiate a full decoder config, e.g. handle list of configs + Note that arguments are added in reverse order compared to encoder (model first, then dataset) + """ + decoder = utils.to_list(decoder) + return U.PassthroughSequential( + *[_instantiate(d, model=model, dataset=dataset) for d in decoder] + ) diff --git a/src/clm/src/tasks/encoders.py b/src/clm/src/tasks/encoders.py new file mode 100644 index 00000000..e6eac313 --- /dev/null +++ b/src/clm/src/tasks/encoders.py @@ -0,0 +1,358 @@ +import datetime +import math +from typing import ForwardRef + +import torch +from torch import nn +import torch.nn.functional as F +from einops import rearrange, repeat + +import clm.src.models.nn.utils as U +import clm.src.utils as utils +import clm.src.utils.config +from clm.src.models.sequence.block import SequenceResidualBlock +from clm.src.models.nn.components import Normalization + +class Encoder(nn.Module): + """Encoder abstraction + Accepts a tensor and optional kwargs. Outside of the main tensor, all other arguments should be kwargs. + Returns a tensor and optional kwargs. + Encoders are combined via U.PassthroughSequential which passes these kwargs through in a pipeline. The resulting kwargs are accumulated and passed into the model backbone. + + """ + + def forward(self, x, **kwargs): + """ + x: input tensor + *args: additional info from the dataset (e.g. sequence lengths) + + Returns: + y: output tensor + *args: other arguments to pass into the model backbone + """ + return x, {} + +class PositionalIDEncoder(Encoder): + def forward(self, x): + position_ids = torch.arange(x.shape[-1], dtype=torch.long, device=x.device) + position_ids = repeat(position_ids, 'l -> b l', b=x.shape[0]) + return x, { 'position_ids': position_ids } + +# Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py +class PositionalEncoder(Encoder): + r"""Inject some information about the relative or absolute position of the tokens + in the sequence. The positional encodings have the same dimension as + the embeddings, so that the two can be summed. Here, we use sine and cosine + functions of different frequencies. + .. math:: + \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) + \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) + \text{where pos is the word position and i is the embed idx) + Args: + d_model: the embed dim (required). + dropout: the dropout value (default=0.1). + max_len: the max. length of the incoming sequence (default=5000). 
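+        pe_init: if not None, the positional table is a learnable parameter
+            initialized from N(0, pe_init) instead of the fixed sinusoidal buffer
+            (default=None).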
+ Examples: + >>> pos_encoder = PositionalEncoder(d_model) + """ + + def __init__(self, d_model, dropout=0.1, max_len=16384, pe_init=None): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + if pe_init is not None: + self.pe = nn.Parameter(torch.empty(max_len, 1, d_model)) + nn.init.normal_(self.pe, 0, pe_init) + # self.pe = pe.unsqueeze(1) + else: + pe = torch.zeros(max_len, d_model) + position = torch.arange(0.0, max_len).unsqueeze(1) + div_term = torch.exp( + -math.log(10000.0) * torch.arange(0.0, d_model, 2.0) / d_model + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + self.attn_mask = None + + def forward(self, x): + r"""Inputs of forward function + Args: + x: the sequence fed to the positional encoder model (required). + lens: actual lengths of sequences + Shape: + x: [l_sequence, n_batch, d_model] + Returns: [l_sequence, n_batch, d_model] + attn_mask: [l_sequence, l_sequence] + padding_mask: + """ + + x = x + self.pe[: x.size(-2)] + return self.dropout(x) + + +class ClassEmbedding(Encoder): + # Should also be able to define this by subclassing Embedding + def __init__(self, n_classes, d_model): + super().__init__() + self.embedding = nn.Embedding(n_classes, d_model) + + def forward(self, x, y): + x = x + self.embedding(y).unsqueeze(-2) # (B, L, D) + return x + + +class Conv1DEncoder(Encoder): + def __init__(self, d_input, d_model, kernel_size=25, stride=1, padding='same'): + super().__init__() + self.conv = nn.Conv1d( + in_channels=d_input, + out_channels=d_model, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + def forward(self, x): + # BLD -> BLD + x = self.conv(x.transpose(1, 2)).transpose(1, 2) + return x + +class LayerEncoder(Encoder): + """Use an arbitary SequenceModule layer""" + + def __init__(self, d_model, prenorm=False, norm='layer', layer=None): + super().__init__() + + # Simple stack of blocks + layer["transposed"] = False + self.layer = SequenceResidualBlock( + d_input=d_model, + prenorm=prenorm, + layer=layer, + residual='R', + norm=norm, + pool=None, + ) + + def forward(self, x): + x, _ = self.layer(x) # Discard state + return x + + +class TimestampEmbeddingEncoder(Encoder): + """ + General time encoder for Pandas Timestamp objects (encoded as torch tensors). + See MonashDataset for an example of how to return time features as 'z's. 
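+    Each timestamp attribute is either looked up in a per-attribute Embedding
+    (table=True) or rescaled to [-1, 1] and passed through a Linear layer
+    (table=False); the resulting encodings are added onto x.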
+ """ + + cardinalities = { + 'day': (1, 31), + 'hour': (0, 23), + 'minute': (0, 59), + 'second': (0, 59), + 'month': (1, 12), + 'year': (1950, 2010), # (1800, 3000) used to be (1970, datetime.datetime.now().year + 1) but was not enough for all datasets in monash + 'dayofweek': (0, 6), + 'dayofyear': (1, 366), + 'quarter': (1, 4), + 'week': (1, 53), + 'is_month_start': (0, 1), + 'is_month_end': (0, 1), + 'is_quarter_start': (0, 1), + 'is_quarter_end': (0, 1), + 'is_year_start': (0, 1), + 'is_year_end': (0, 1), + 'is_leap_year': (0, 1), + } + + def __init__(self, d_model, table=False, features=None): + super().__init__() + self.table = table + self.ranges = {k: max_val - min_val + 2 for k, (min_val, max_val) in self.cardinalities.items()} # padding for null included + + if features is None: + pass + else: + self.cardinalities = {k: v for k, v in self.cardinalities.items() if k in features} + + if table: + self.embedding = nn.ModuleDict({ + attr: nn.Embedding(maxval - minval + 2, d_model, padding_idx=0) + for attr, (minval, maxval) in self.cardinalities.items() + }) + else: + self.embedding = nn.ModuleDict({ + attr: nn.Linear(1, d_model) + for attr in self.cardinalities + }) + + + + def forward(self, x, timestamps=None): + for attr in timestamps: + mask = timestamps[attr] == -1 + timestamps[attr] = timestamps[attr] - self.cardinalities[attr][0] + timestamps[attr][mask] = 0 + if self.table: + x = x + self.embedding[attr](timestamps[attr].to(torch.long)) + else: + x = x + self.embedding[attr]((2 * timestamps[attr] / self.ranges[attr] - 1).unsqueeze(-1)) + + #x = x + self.embedding(timestamps[attr].to(torch.float)).unsqueeze(1) + return x + + +class TimeEncoder(Encoder): + def __init__(self, n_tokens_time, d_model, timeenc=0): + super().__init__() + + self.timeenc = timeenc + if self.timeenc == 0: + self.encoders = nn.ModuleList( + [nn.Embedding(v, d_model) for v in n_tokens_time] + ) + else: + self.encoders = nn.Linear(len(n_tokens_time), d_model) + self.mask_embed = nn.Embedding(2, d_model) + + def forward(self, x, mark=None, mask=None): + assert mark is not None and mask is not None, "Extra arguments should be returned by collate function" + if self.timeenc == 0: + assert mark.size(-1) == len(self.encoders) + embeddings = [ + embed(z) for embed, z in zip(self.encoders, torch.unbind(mark, dim=-1)) + ] + time_encode = torch.sum(torch.stack(embeddings), dim=0) + else: + time_encode = self.encoders(mark) + mask_encode = self.mask_embed(mask.squeeze(-1)) + return x + time_encode + mask_encode # (B, L, d_model) + + +class PackedEncoder(Encoder): + def forward(self, x, len_batch=None): + assert len_batch is not None + x = nn.utils.rnn.pack_padded_sequence( + x, len_batch.cpu(), enforce_sorted=False, batch_first=True, + ) + return x + + +class OneHotEncoder(Encoder): + def __init__(self, n_tokens, d_model): + super().__init__() + assert n_tokens <= d_model + self.d_model = d_model + + def forward(self, x): + return F.one_hot(x.squeeze(-1), self.d_model).float() + + +class Conv2DPatchEncoder(Encoder): + + """ + For encoding images into a sequence of patches. 
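+    Uses a Conv2d whose stride equals its kernel size, so patches are non-overlapping
+    and each patch becomes one token of dimension d_model (e.g. a 32x32 image with
+    filter_sizes=(4, 4) yields a sequence of 64 tokens).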
+ """ + + def __init__(self, d_input, d_model, filter_sizes, flat=False): + + """ + d_input: dim of encoder input (data dimension) + d_model: dim of encoder output (model dimension) + filter_sizes: tuple with fh, fw + flat: if image is flattened from dataloader (like in cifar), + then we need to reshape back to 2D before conv + """ + + fh, fw = filter_sizes + self.flat = flat + + super().__init__() + assert len(filter_sizes) == 2 + + self.encoder = nn.Conv2d(d_input, d_model, kernel_size=(fh, fw), stride=(fh, fw)) + + def forward(self, x): + + """ + x shape expected = [b, h, w, c] + returns tuple with x, with new shape = [b, seq_len, c_out] + + """ + + x = rearrange(x, 'b h w c -> b c h w') + x = self.encoder(x) + x = rearrange(x, 'b c h w -> b (h w) c') + return x + + +# For every type of encoder/decoder, specify: +# - constructor class +# - list of attributes to grab from dataset +# - list of attributes to grab from model + +registry = { + "stop": Encoder, + "id": nn.Identity, + "embedding": nn.Embedding, + "linear": nn.Linear, + "position": PositionalEncoder, + "position_id": PositionalIDEncoder, + "class": ClassEmbedding, + "pack": PackedEncoder, + "time": TimeEncoder, + "onehot": OneHotEncoder, + "conv1d": Conv1DEncoder, + "patch2d": Conv2DPatchEncoder, + "timestamp_embedding": TimestampEmbeddingEncoder, + "layer": LayerEncoder, +} +dataset_attrs = { + "embedding": ["n_tokens"], + "linear": ["d_input"], # TODO make this d_data? + "class": ["n_classes"], + "time": ["n_tokens_time"], + "onehot": ["n_tokens"], + "conv1d": ["d_input"], + "patch2d": ["d_input"], +} +model_attrs = { + "embedding": ["d_model"], + "linear": ["d_model"], + "position": ["d_model"], + "class": ["d_model"], + "time": ["d_model"], + "onehot": ["d_model"], + "conv1d": ["d_model"], + "patch2d": ["d_model"], + "timestamp_embedding": ["d_model"], + "layer": ["d_model"], +} + + +def _instantiate(encoder, dataset=None, model=None): + """Instantiate a single encoder""" + if encoder is None: + return None + if isinstance(encoder, str): + name = encoder + else: + name = encoder["_name_"] + + # Extract dataset/model arguments from attribute names + dataset_args = utils.config.extract_attrs_from_obj( + dataset, *dataset_attrs.get(name, []) + ) + model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, [])) + + # Instantiate encoder + obj = utils.instantiate(registry, encoder, *dataset_args, *model_args) + return obj + + +def instantiate(encoder, dataset=None, model=None): + encoder = utils.to_list(encoder) + return U.PassthroughSequential( + *[_instantiate(e, dataset=dataset, model=model) for e in encoder] + ) diff --git a/src/clm/src/tasks/metrics.py b/src/clm/src/tasks/metrics.py new file mode 100644 index 00000000..234547e0 --- /dev/null +++ b/src/clm/src/tasks/metrics.py @@ -0,0 +1,225 @@ +import math +import torch +import torch.nn.functional as F +from sklearn.metrics import f1_score, roc_auc_score +from functools import partial +import torchmetrics.functional as tm_f + +def _student_t_map(mu, sigma, nu): + sigma = F.softplus(sigma) + nu = 2.0 + F.softplus(nu) + return mu.squeeze(axis=-1), sigma.squeeze(axis=-1), nu.squeeze(axis=-1) + +def student_t_loss(outs, y): + mu, sigma, nu = outs[..., 0], outs[..., 1], outs[..., 2] + mu, sigma, nu = _student_t_map(mu, sigma, nu) + y = y.squeeze(axis=-1) + + nup1_half = (nu + 1.0) / 2.0 + part1 = 1.0 / nu * torch.square((y - mu) / sigma) + Z = ( + torch.lgamma(nup1_half) + - torch.lgamma(nu / 2.0) + - 0.5 * torch.log(math.pi * nu) + - torch.log(sigma) + ) + + 
ll = Z - nup1_half * torch.log1p(part1) + return -ll.mean() + +def gaussian_ll_loss(outs, y): + mu, sigma = outs[..., 0], outs[..., 1] + y = y.squeeze(axis=-1) + sigma = F.softplus(sigma) + ll = -1.0 * ( + torch.log(sigma) + + 0.5 * math.log(2 * math.pi) + + 0.5 * torch.square((y - mu) / sigma) + ) + return -ll.mean() + +def binary_cross_entropy(logits, y): + # BCE loss requires squeezing last dimension of logits so it has the same shape as y + # requires y to be float, since it's overloaded to represent a probability + return F.binary_cross_entropy_with_logits(logits.squeeze(-1), y.float()) + + +def binary_accuracy(logits, y): + return torch.eq(logits.squeeze(-1) >= 0, y).float().mean() + + +def cross_entropy(logits, y): + logits = logits.view(-1, logits.shape[-1]) + y = y.view(-1) + return F.cross_entropy(logits, y) + + +def soft_cross_entropy(logits, y, label_smoothing=0.0): + logits = logits.view(-1, logits.shape[-1]) + # target is now 2d (no target flattening) + return F.cross_entropy(logits, y, label_smoothing=label_smoothing) + + +def accuracy(logits, y): + logits = logits.view(-1, logits.shape[-1]) + if y.numel() > logits.shape[0]: + # Mixup leads to this case: use argmax class + y = y.argmax(dim=-1) + y = y.view(-1) + return torch.eq(torch.argmax(logits, dim=-1), y).float().mean() + +def accuracy_ignore_index(logits, y, ignore_index=-100): + num_classes = logits.shape[-1] + preds = torch.argmax(logits, dim=-1) + logits = logits.view(-1, logits.shape[-1]) + y = y.view(-1) + return tm_f.classification.accuracy(preds, y, 'multiclass', num_classes=num_classes, ignore_index=ignore_index, average='micro') + + +def accuracy_at_k(logits, y, k=1): + logits = logits.view(-1, logits.shape[-1]) + if y.numel() > logits.shape[0]: + # Mixup leads to this case: use argmax class + y = y.argmax(dim=-1) + y = y.view(-1) + return torch.topk(logits, k, dim=-1)[1].eq(y.unsqueeze(-1)).any(dim=-1).float().mean() + + +def f1_binary(logits, y): + logits = logits.view(-1, logits.shape[-1]) + y = y.view(-1) + y_hat = torch.argmax(logits, dim=-1) + return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="binary") + + +def f1_macro(logits, y): + logits = logits.view(-1, logits.shape[-1]) + y = y.view(-1) + y_hat = torch.argmax(logits, dim=-1) + return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="macro") + + +def f1_micro(logits, y): + logits = logits.view(-1, logits.shape[-1]) + y = y.view(-1) + y_hat = torch.argmax(logits, dim=-1) + return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="micro") + + +def roc_auc_macro(logits, y): + logits = logits.view( + -1, logits.shape[-1] + ).detach() # KS: had to add detach to eval while training + y = y.view(-1) + return roc_auc_score( + y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average="macro" + ) + + +def roc_auc_micro(logits, y): + logits = logits.view(-1, logits.shape[-1]) + y = y.view(-1) + return roc_auc_score( + y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average="micro" + ) + + +def mse(outs, y, len_batch=None): + # assert outs.shape[:-1] == y.shape and outs.shape[-1] == 1 + # outs = outs.squeeze(-1) + if len(y.shape) < len(outs.shape): + assert outs.shape[-1] == 1 + outs = outs.squeeze(-1) + if len_batch is None: + return F.mse_loss(outs, y) + else: + # Computes the loss of the first `lens` items in the batches + # TODO document the use case of this + mask = torch.zeros_like(outs, dtype=torch.bool) + for i, l in enumerate(len_batch): + mask[i, :l, :] = 1 + outs_masked = torch.masked_select(outs, 
mask) + y_masked = torch.masked_select(y, mask) + return F.mse_loss(outs_masked, y_masked) + +def forecast_rmse(outs, y, len_batch=None): + # TODO: generalize, currently for Monash dataset + return torch.sqrt(F.mse_loss(outs, y, reduction='none').mean(1)).mean() + +def mae(outs, y, len_batch=None): + # assert outs.shape[:-1] == y.shape and outs.shape[-1] == 1 + # outs = outs.squeeze(-1) + if len(y.shape) < len(outs.shape): + assert outs.shape[-1] == 1 + outs = outs.squeeze(-1) + if len_batch is None: + return F.l1_loss(outs, y) + else: + # Computes the loss of the first `lens` items in the batches + mask = torch.zeros_like(outs, dtype=torch.bool) + for i, l in enumerate(len_batch): + mask[i, :l, :] = 1 + outs_masked = torch.masked_select(outs, mask) + y_masked = torch.masked_select(y, mask) + return F.l1_loss(outs_masked, y_masked) + + +# Metrics that can depend on the loss +def loss(x, y, loss_fn): + """ This metric may be useful because the training loss may add extra regularization (e.g. weight decay implemented as L2 penalty), while adding this as a metric skips the additional losses """ + return loss_fn(x, y) + + +def bpb(x, y, loss_fn): + """ bits per byte (image density estimation, speech generation, char LM) """ + return loss_fn(x, y) / math.log(2) + + +def ppl(x, y, loss_fn): + return torch.exp(loss_fn(x, y)) + + +# should have a better way to do this +output_metric_fns = { + "binary_cross_entropy": binary_cross_entropy, + "cross_entropy": cross_entropy, + "binary_accuracy": binary_accuracy, + "accuracy": accuracy, + "accuracy_ignore_index": accuracy_ignore_index, + 'accuracy@3': partial(accuracy_at_k, k=3), + 'accuracy@5': partial(accuracy_at_k, k=5), + 'accuracy@10': partial(accuracy_at_k, k=10), + "eval_loss": loss, + "mse": mse, + "mae": mae, + "forecast_rmse": forecast_rmse, + "f1_binary": f1_binary, + "f1_macro": f1_macro, + "f1_micro": f1_micro, + "roc_auc_macro": roc_auc_macro, + "roc_auc_micro": roc_auc_micro, + "soft_cross_entropy": soft_cross_entropy, # only for pytorch 1.10+ + "student_t": student_t_loss, + "gaussian_ll": gaussian_ll_loss, +} + +try: + from segmentation_models_pytorch.utils.functional import iou + from segmentation_models_pytorch.losses.focal import focal_loss_with_logits + + def iou_with_logits(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): + return iou(pr.sigmoid(), gt, eps=eps, threshold=threshold, ignore_channels=ignore_channels) + + output_metric_fns["iou"] = partial(iou, threshold=0.5) + output_metric_fns["iou_with_logits"] = partial(iou_with_logits, threshold=0.5) + output_metric_fns["focal_loss"] = focal_loss_with_logits +except ImportError: + pass + +loss_metric_fns = { + "loss": loss, + "bpb": bpb, + "ppl": ppl, +} +metric_fns = {**output_metric_fns, **loss_metric_fns} # TODO py3.9 + diff --git a/src/clm/src/tasks/tasks.py b/src/clm/src/tasks/tasks.py new file mode 100644 index 00000000..8f0be6ea --- /dev/null +++ b/src/clm/src/tasks/tasks.py @@ -0,0 +1,371 @@ +from typing import Optional, List, Tuple +import math +import functools +import collections +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from omegaconf import ListConfig +from clm.src.models.nn.components import ReversibleInstanceNorm1dInput, ReversibleInstanceNorm1dOutput, \ + TSNormalization, TSInverseNormalization + +from clm.src.models.nn.adaptive_softmax import AdaptiveEmbedding, ProjectedAdaptiveLogSoftmax +import clm.src.tasks.metrics as M +from clm.src.tasks.torchmetrics import torchmetric_fns as tm_mine +import 
clm.src.models.nn.utils as U +import torchmetrics as tm +from clm.src.utils.config import to_list, instantiate +from torchmetrics import MetricCollection + +class BaseTask: + """ Abstract class that takes care of: + - loss function + - arbitrary metrics + - forward pass + - (optional) encoder module that interfaces with dataset (inputs) and model + - (optional) decoder module that interfaces with dataset (targets) and model + """ + encoder = None + decoder = None + + def __init__(self, dataset=None, model=None, loss=None, loss_val=None, metrics=None, torchmetrics=None): + """ This class is allowed to grab attributes directly off a constructed dataset and model object """ + self.dataset = dataset + self.model = model + if metrics is None: metrics = [] + self.metric_names = to_list(metrics) + + if torchmetrics is None: torchmetrics = [] + self.torchmetric_names = to_list(torchmetrics) + self._tracked_torchmetrics = {} + + # The decoder might pass through arguments that the loss needs (e.g. sequence lengths) + # but might also pass through extraneous arguments (e.g. sampling rate) + # Wrap loss and metrics so that they accept kwargs and + + # Create loss function + self.loss = instantiate(M.output_metric_fns, loss, partial=True) + self.loss = U.discard_kwargs(self.loss) + if loss_val is not None: + self.loss_val = instantiate(M.output_metric_fns, loss_val, partial=True) + self.loss_val = U.discard_kwargs(self.loss_val) + torchmetrics = MetricCollection(self._init_torchmetrics()) + self.train_torchmetrics = torchmetrics.clone(prefix='train/') + self.val_torchmetrics = torchmetrics.clone(prefix='val/') + self.test_torchmetrics = torchmetrics.clone(prefix='test/') + + def _init_torchmetrics(self): + """ + Instantiate torchmetrics. + """ + tracked_torchmetrics = {} + + for name in self.torchmetric_names: + if name in tm_mine: + tracked_torchmetrics[name] = tm_mine[name]().to('cuda') + elif name in ['AUROC', 'StatScores', 'Precision', 'Recall', 'F1', 'F1Score']: + tracked_torchmetrics[name] = getattr(tm, name)(average='macro', num_classes=self.dataset.d_output, compute_on_step=False).to('cuda') + elif '@' in name: + k = int(name.split('@')[1]) + mname = name.split('@')[0] + tracked_torchmetrics[name] = getattr(tm, mname)(average='macro', num_classes=self.dataset.d_output, compute_on_step=False, top_k=k).to('cuda') + else: + tracked_torchmetrics[name] = getattr(tm, name)(compute_on_step=False).to('cuda') + + return tracked_torchmetrics + + def _reset_torchmetrics(self, prefix=None): + """ + Reset torchmetrics for a prefix + associated with a particular dataloader (e.g. train, val, test). + + Generally do this at the start of an epoch. + """ + all_prefixes = [prefix] if prefix is not None else self._tracked_torchmetrics + + for prefix in all_prefixes: + if prefix in self._tracked_torchmetrics: + self._tracked_torchmetrics[prefix].reset() + + def get_torchmetrics(self, prefix): + """ + Compute torchmetrics for a prefix associated with + a particular dataloader (e.g. train, val, test). + + Generally do this at the end of an epoch. + """ + return {name: self._tracked_torchmetrics[prefix][name].compute() for name in self.torchmetric_names} + + def torchmetrics(self, x, y, prefix, loss=None): + """ + Update torchmetrics with new x, y . + Prefix corresponds to a particular dataloader (e.g. train, val, test). + + Generally call this every batch. 
+ """ + if prefix not in self._tracked_torchmetrics: + self._init_torchmetrics(prefix) + self._tracked_torchmetrics[prefix](x, y, loss=loss) + + # for name in self.torchmetric_names: + # if name.startswith('Accuracy'): + # if len(x.shape) > 2: + # # Multi-dimensional, multi-class + # self._tracked_torchmetrics[prefix][name].update(x.transpose(1, 2), y.squeeze()) + # continue + # self._tracked_torchmetrics[prefix][name].update(x, y) + + def get_torchmetrics(self, prefix): + return self._tracked_torchmetrics[prefix] + + def metrics(self, x, y, **kwargs): + """ + Metrics are just functions + output metrics are a function of output and target + loss metrics are a function of loss (e.g. perplexity) + """ + output_metrics = { + name: U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs) + for name in self.metric_names if name in M.output_metric_fns + } + loss_metrics = { + name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs) + for name in self.metric_names if name in M.loss_metric_fns + } + return {**output_metrics, **loss_metrics} + + def forward(self, batch, encoder, model, decoder, _state): + """Passes a batch through the encoder, backbone, and decoder""" + # z holds arguments such as sequence length + x, y, *z = batch # z holds extra dataloader info such as resolution + if len(z) == 0: + z = {} + else: + assert len(z) == 1 and isinstance(z[0], dict), "Dataloader must return dictionary of extra arguments" + z = z[0] + + x, w = encoder(x, **z) # w can model-specific constructions such as key_padding_mask for transformers or state for RNNs + x, state = model(x, **w, state=_state) + self._state = state + x, w = decoder(x, state=state, **z) + return x, y, w + + +class Scalar(nn.Module): + def __init__(self, c=1): + super().__init__() + self.c = c + def forward(self, x): + return x * self.c + +class LMTask(BaseTask): + def forward(self, batch, encoder, model, decoder, _state): + """Passes a batch through the encoder, backbone, and decoder""" + # z holds arguments such as sequence length + x, y, *z = batch # z holds extra dataloader info such as resolution + if len(z) == 0: + z = {} + else: + assert len(z) == 1 and isinstance(z[0], dict), "Dataloader must return dictionary of extra arguments" + z = z[0] + x, w = encoder(x, **z) # w can model-specific constructions such as key_padding_mask for transformers or state for RNNs + x, state = model(x, **w, state=_state) + self._state = state + x, w = decoder(x, state=state, **z) + + x = x.logits + x = rearrange(x, '... C -> (...) C') + y = rearrange(y, '... 
-> (...)') + + return x, y, w + +class ForecastingTask(BaseTask): + + class DummyModule(nn.Module): + + def forward(self, *args): + return args + + def __init__(self, norm='mean', **kwargs): + super().__init__(**kwargs) + + if norm == 'revnorm': + self.encoder = ReversibleInstanceNorm1dInput(self.dataset.d_input, transposed=False) + self.decoder = ReversibleInstanceNorm1dOutput(self.encoder) + elif norm == 'mean': + self.encoder = TSNormalization(method='mean', horizon=self.dataset.dataset_train.forecast_horizon) + self.decoder = TSInverseNormalization(method='mean', normalizer=self.encoder) + elif norm == 'last': + self.encoder = TSNormalization(method='last', horizon=self.dataset.dataset_train.forecast_horizon) + self.decoder = TSInverseNormalization(method='last', normalizer=self.encoder) + else: + self.encoder = None + self.decoder = None + + try: + if hasattr(self.dataset.dataset_train, 'mean'): + self.mean = torch.tensor(self.dataset.dataset_train.mean) + self.std = torch.tensor(self.dataset.dataset_train.std) + elif hasattr(self.dataset.dataset_train, 'standardization'): + self.mean = torch.tensor(self.dataset.dataset_train.standardization['means']) + self.std = torch.tensor(self.dataset.dataset_train.standardization['stds']) + else: + self.mean = None + self.std = None + except AttributeError: + raise AttributeError('Dataset does not have mean/std attributes') + self.mean = torch.tensor(self.dataset.dataset_train.standardization['means']) + self.std = torch.tensor(self.dataset.dataset_train.standardization['stds']) + + if hasattr(self.dataset.dataset_train, 'log_transform'): + self.log_transform = self.dataset.dataset_train.log_transform + else: + self.log_transform = False + print("Log Transform", self.log_transform) + + def metrics(self, x, y, state=None, timestamps=None, ids=None): # Explicit about which arguments the decoder might pass through, but can future-proof with **kwargs + if self.mean is not None: + means = self.mean[ids].to(x.device) + stds = self.std[ids].to(x.device) + x_ = x * stds[:, None, None] + means[:, None, None] + y_ = y * stds[:, None, None] + means[:, None, None] + else: + x_ = x + y_ = y + + if self.log_transform: + x_ = torch.exp(x_) + y_ = torch.exp(y_) + + return super().metrics(x_, y_) + +class VideoTask(BaseTask): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + # self._y_to_logits = {} + self._vid_to_logits = {} + self._vid_to_label = {} + + # TODO needed to extract the first element of y, which includes the video idea; there should be a cleaner pattern to this + import copy + loss_fn = copy.deepcopy(self.loss) + self.loss = lambda x, y: loss_fn(x, y[0]) + if hasattr(self, 'loss_val'): + loss_val_fn = copy.deepcopy(self.loss_val) + self.loss_val = lambda x, y: loss_val_fn(x, y[0]) + + def metrics(self, logits, y, **kwargs): + labels, vids = y + return super().metrics(logits, labels, **kwargs) + + def torchmetrics(self, logits, y, prefix): + """ + logits: (batch, n_classes) + y = tuple of labels and video ids + labels: (batch) + vids: (batch) + """ + for _logits, _label, _vid in zip(logits, y[0], y[1]): + _vid = _vid.item() + # Check that labels are consistent per video id + assert self._vid_to_label[prefix].get(_vid, _label) == _label + self._vid_to_label[prefix][_vid] = _label + + self._vid_to_logits[prefix][_vid].append(_logits) + + def _reset_torchmetrics(self, prefix): + self._vid_to_logits[prefix] = collections.defaultdict(list) + self._vid_to_label[prefix] = {} + + def get_torchmetrics(self, prefix): + vid_to_average_logits = 
{vid: torch.mean(torch.stack(logits, dim=0), dim=0) for vid, logits in self._vid_to_logits[prefix].items()} + # y is (label, vid) pair + all_labels = torch.stack(list(self._vid_to_label[prefix].values()), dim=0) # (n_videos) + all_logits = torch.stack(list(vid_to_average_logits.values()), dim=0) # (n_videos, n_classes) + m = M.accuracy(all_logits, all_labels) + return {'aggregate_accuracy': m} + + +class AdaptiveLMTask(BaseTask): + def __init__( + self, + div_val, + cutoffs : List[int], + tie_weights : bool, + tie_projs : List[bool], + init_scale=1.0, + bias_scale=0.0, + dropemb=0.0, + dropsoft=0.0, + **kwargs, + ): + super().__init__(**kwargs) + n_tokens = self.dataset.n_tokens + d_model = self.model.d_model + d_output = self.model.d_output + + encoder = AdaptiveEmbedding( + n_tokens, + d_model, + d_model, + cutoffs=cutoffs, + div_val=div_val, + init_scale=init_scale, + dropout=dropemb, + ) + + if tie_weights: + assert d_model == d_output + emb_layers = [i.weight for i in encoder.emb_layers] + else: + emb_layers = None + + # Construct decoder/loss + emb_projs = encoder.emb_projs + loss = ProjectedAdaptiveLogSoftmax( + n_tokens, d_output, d_output, + cutoffs, div_val=div_val, + tie_projs=tie_projs, + out_projs=emb_projs, + out_layers_weights=emb_layers, + bias_scale=bias_scale, + dropout=dropsoft, + ) + + self.encoder = encoder + self.loss = loss + + +class ImageNetTask(BaseTask): + """ + Imagenet training uses mixup augmentations, which require a separate loss for train and val, + which we overide the base task here. + """ + + def __init__(self, **kwargs): + import hydra + + super().__init__( + dataset=kwargs.get("dataset", None), + model=kwargs.get("model", None), + loss=kwargs.get("loss", None), # we still create the base loss here, but will overide below + metrics=kwargs.get("metrics", None), + torchmetrics=kwargs.get("torchmetrics", None) + ) + + # if using mixup, overide loss (train) and loss_val, otherwise + # we have just one loss from the base task above + if "loss_val" in kwargs and "loss_train" in kwargs: + self.loss = hydra.utils.instantiate(kwargs.get("loss_train")) + self.loss_val = hydra.utils.instantiate(kwargs.get('loss_val')) + + +registry = { + 'base': BaseTask, + 'lm': LMTask, + 'imagenet': ImageNetTask, + 'forecasting': ForecastingTask, + 'video': VideoTask, +} diff --git a/src/clm/src/tasks/torchmetrics.py b/src/clm/src/tasks/torchmetrics.py new file mode 100644 index 00000000..f3580b43 --- /dev/null +++ b/src/clm/src/tasks/torchmetrics.py @@ -0,0 +1,120 @@ +# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py +# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll)) +# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py +# But we pass in the loss to avoid recomputation + +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import Tensor +from torchmetrics import Metric + +try: + from flash_attn.losses.cross_entropy import CrossEntropyLoss +except ImportError: + CrossEntropyLoss = torch.nn.CrossEntropyLoss + +try: + from apex.transformer import parallel_state +except ImportError: + parallel_state = None + + +class Perplexity(Metric): + r""" + Perplexity measures how well a language model predicts a text sample. It's calculated as the average number of bits + per word a model needs to represent the sample. + Args: + kwargs: + Additional keyword arguments, see :ref:`Metric kwargs` for more info. 
+ Examples: + >>> import torch + >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22)) + >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22)) + >>> target[0, 6:] = -100 + >>> metric = Perplexity(ignore_index=-100) + >>> metric(preds, target) + tensor(5.2545) + """ + is_differentiable = True + higher_is_better = False + full_state_update = False + total_log_probs: Tensor + count: Tensor + + def __init__(self, **kwargs: Dict[str, Any]): + super().__init__(**kwargs) + self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64), + dist_reduce_fx="sum") + self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") + + self.loss_fn = CrossEntropyLoss() + + def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore + """Compute and store intermediate statistics for Perplexity. + Args: + preds: + Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size]. + target: + Ground truth values with a shape [batch_size, seq_len]. + """ + count = target.numel() + if loss is None: + loss = self.loss_fn(preds, target) + self.total_log_probs += loss.double() * count + self.count += count + + def compute(self) -> Tensor: + """Compute the Perplexity. + Returns: + Perplexity + """ + return torch.exp(self.total_log_probs / self.count) + +class NumTokens(Metric): + """Keep track of how many tokens we've seen. + """ + # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch + # of the next epoch. + # Right now the hack is that we override reset(), which would mess up the forward method. + # We then override forward to do the right thing. + + is_differentiable = False + higher_is_better = False + full_state_update = False + count: Tensor + + def __init__(self, **kwargs: Dict[str, Any]): + super().__init__(**kwargs) + self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum", + persistent=True) # We want the count to be saved to state-dict + if parallel_state is not None and not parallel_state.is_unitialized(): + self.tensor_parallel_world_size = parallel_state.get_tensor_model_parallel_world_size() + else: + self.tensor_parallel_world_size = 1 + + def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore + self.count += target.numel() // self.tensor_parallel_world_size + + def compute(self) -> Tensor: + return self.count + + def reset(self): + count = self.count + super().reset() + self.count = count + + # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py + def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: + """forward computation using single call to `update` to calculate the metric value on the current batch and + accumulate global state. + This can be done when the global metric state is a sinple reduction of batch states. 
+ """ + self.update(*args, **kwargs) + return self.compute() + +torchmetric_fns = { + "perplexity": Perplexity, + "num_tokens": NumTokens, +} \ No newline at end of file diff --git a/src/clm/src/utils/__init__.py b/src/clm/src/utils/__init__.py new file mode 100644 index 00000000..960c2b9a --- /dev/null +++ b/src/clm/src/utils/__init__.py @@ -0,0 +1 @@ +from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate diff --git a/src/clm/src/utils/config.py b/src/clm/src/utils/config.py new file mode 100644 index 00000000..9020037c --- /dev/null +++ b/src/clm/src/utils/config.py @@ -0,0 +1,124 @@ +""" Utilities for dealing with collection objects (lists, dicts) and configs """ +from typing import Sequence, Mapping, Optional, Callable +import functools +import hydra +from omegaconf import ListConfig, DictConfig + +# TODO this is usually used in a pattern where it's turned into a list, so can just do that here +def is_list(x): + return isinstance(x, Sequence) and not isinstance(x, str) + + +def is_dict(x): + return isinstance(x, Mapping) + + +def to_dict(x, recursive=True): + """Convert Sequence or Mapping object to dict + + lists get converted to {0: x[0], 1: x[1], ...} + """ + if is_list(x): + x = {i: v for i, v in enumerate(x)} + if is_dict(x): + if recursive: + return {k: to_dict(v, recursive=recursive) for k, v in x.items()} + else: + return dict(x) + else: + return x + + +def to_list(x, recursive=False): + """Convert an object to list. + + If Sequence (e.g. list, tuple, Listconfig): just return it + + Special case: If non-recursive and not a list, wrap in list + """ + if is_list(x): + if recursive: + return [to_list(_x) for _x in x] + else: + return list(x) + else: + if recursive: + return x + else: + return [x] + + +def extract_attrs_from_obj(obj, *attrs): + if obj is None: + assert len(attrs) == 0 + return [] + return [getattr(obj, attr, None) for attr in attrs] + + +def auto_assign_attrs(cls, **kwargs): + for k, v in kwargs.items(): + setattr(cls, k, v) + + +def instantiate(registry, config, *args, partial=False, wrap=None, **kwargs): + """ + registry: Dictionary mapping names to functions or target paths (e.g. {'model': 'models.SequenceModel'}) + config: Dictionary with a '_name_' key indicating which element of the registry to grab, and kwargs to be passed into the target constructor + wrap: wrap the target class (e.g. ema optimizer or tasks.wrap) + *args, **kwargs: additional arguments to override the config to pass into the target constructor + """ + + # Case 1: no config + if config is None: + return None + # Case 2a: string means _name_ was overloaded + if isinstance(config, str): + _name_ = None + _target_ = registry[config] + config = {} + # Case 2b: grab the desired callable from name + else: + _name_ = config.pop("_name_") + _target_ = registry[_name_] + + # Retrieve the right constructor automatically based on type + if isinstance(_target_, str): + fn = hydra.utils.get_method(path=_target_) + elif isinstance(_target_, Callable): + fn = _target_ + else: + raise NotImplementedError("instantiate target must be string or callable") + + # Instantiate object + if wrap is not None: + fn = wrap(fn) + obj = functools.partial(fn, *args, **config, **kwargs) + + # Restore _name_ + if _name_ is not None: + config["_name_"] = _name_ + + if partial: + return obj + else: + return obj() + + +def get_class(registry, _name_): + return hydra.utils.get_class(path=registry[_name_]) + + +def omegaconf_filter_keys(d, fn=None): + """Only keep keys where fn(key) is True. 
Support nested DictConfig. + # TODO can make this inplace? + """ + if fn is None: + fn = lambda _: True + if is_list(d): + return ListConfig([omegaconf_filter_keys(v, fn) for v in d]) + elif is_dict(d): + return DictConfig( + {k: omegaconf_filter_keys(v, fn) for k, v in d.items() if fn(k)} + ) + else: + return d diff --git a/src/clm/src/utils/distributed.py b/src/clm/src/utils/distributed.py new file mode 100644 index 00000000..77d7ecf2 --- /dev/null +++ b/src/clm/src/utils/distributed.py @@ -0,0 +1,144 @@ +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from contextlib import contextmanager + +import torch + + +def init_distributed(cuda): + """ + Initializes distributed backend. + + :param cuda: (bool) if True initializes nccl backend, if False initializes + gloo backend + """ + world_size = int(os.environ.get('WORLD_SIZE', 1)) + distributed = (world_size > 1) + if distributed: + backend = 'nccl' if cuda else 'gloo' + torch.distributed.init_process_group(backend=backend, + init_method='env://') + assert torch.distributed.is_initialized() + return distributed + + +def barrier(): + """ + Call torch.distributed.barrier() if distritubed is in use + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + torch.distributed.barrier() + + +def get_rank(): + """ + Gets distributed rank or returns zero if distributed is not initialized. + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + rank = torch.distributed.get_rank() + else: + rank = 0 + return rank + + +def get_world_size(): + """ + Gets total number of distributed workers or returns one if distributed is + not initialized. 
+ """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + world_size = torch.distributed.get_world_size() + else: + world_size = 1 + return world_size + + +def all_reduce_item(value, op='sum'): + """ + All-reduces single scalar value if distributed is in use + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if op == 'sum' or op == 'mean': + dop = torch.distributed.ReduceOp.SUM + elif op == 'min': + dop = torch.distributed.ReduceOp.MIN + elif op == 'max': + dop = torch.distributed.ReduceOp.MAX + elif op == 'product': + dop = torch.distributed.ReduceOp.PRODUCT + else: + raise RuntimeError('Unsupported reduce op') + + backend = torch.distributed.get_backend() + if backend == torch.distributed.Backend.NCCL: + device = torch.device('cuda') + elif backend == torch.distributed.Backend.GLOO: + device = torch.device('cpu') + else: + raise RuntimeError('Unsupported distributed backend') + + tensor = torch.tensor(value, device=device) + torch.distributed.all_reduce(tensor, dop) + if op == 'mean': + tensor /= get_world_size() + ret = tensor.item() + else: + ret = value + return ret + + +def all_reduce_tensor(value, op='sum'): + """ + All-reduces single scalar value if distributed is in use + """ + if torch.distributed.is_available() and torch.distributed.is_initialized(): + if op == 'sum' or op == 'mean': + dop = torch.distributed.ReduceOp.SUM + elif op == 'min': + dop = torch.distributed.ReduceOp.MIN + elif op == 'max': + dop = torch.distributed.ReduceOp.MAX + elif op == 'product': + dop = torch.distributed.ReduceOp.PRODUCT + else: + raise RuntimeError('Unsupported reduce op') + + backend = torch.distributed.get_backend() + if backend == torch.distributed.Backend.NCCL: + device = torch.device('cuda') + elif backend == torch.distributed.Backend.GLOO: + device = torch.device('cpu') + else: + raise RuntimeError('Unsupported distributed backend') + + tensor = value + torch.distributed.all_reduce(tensor, dop) + if op == 'mean': + tensor /= get_world_size() + ret = tensor + else: + ret = value + return ret + + +@contextmanager +def sync_workers(): + """ + Yields distributed rank and synchronizes all workers on exit. + """ + rank = get_rank() + yield rank + barrier() diff --git a/src/clm/src/utils/optim/lamb.py b/src/clm/src/utils/optim/lamb.py new file mode 100644 index 00000000..8bbdf3a2 --- /dev/null +++ b/src/clm/src/utils/optim/lamb.py @@ -0,0 +1,251 @@ +# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# MIT License +# +# Copyright (c) 2019 cybertronai +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Lamb optimizer.""" + +import torch +from torch.optim import Optimizer + + +class Lamb(Optimizer): + r"""Implements Lamb algorithm. + + It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam (bool, optional): always use trust ratio = 1, which turns this into + Adam. Useful for comparison purposes. + + .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, + weight_decay=0, adam=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay) + self.adam = adam + super(Lamb, self).__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. 
+ """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients.') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + + # Decay the first and second moment running average coefficient + # m_t + exp_avg.mul_(beta1).add_(1 - beta1, grad) + # v_t + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + + # Paper v3 does not use debiasing. + # bias_correction1 = 1 - beta1 ** state['step'] + # bias_correction2 = 1 - beta2 ** state['step'] + # Apply bias to lr to avoid broadcast. + step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 + + weight_norm = p.data.norm(p=2).clamp_(0, 10) + + adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) + if group['weight_decay'] != 0: + adam_step.add_(group['weight_decay'], p.data) + + adam_norm = adam_step.norm(p=2) + + if weight_norm == 0.0 or adam_norm == 0.0: + trust_ratio = 1 + else: + trust_ratio = weight_norm / (adam_norm + group['eps']) + + state['weight_norm'] = weight_norm + state['adam_norm'] = adam_norm + state['trust_ratio'] = trust_ratio + if self.adam: + trust_ratio = 1 + + p.data.add_(-step_size * trust_ratio, adam_step) + + return loss + + +@torch.jit.script +def lamb_kernel(param, grad, exp_avg, exp_avg_sq, beta1: float, + beta2: float, step_size: float, eps: float, weight_decay: float): + exp_avg = exp_avg * beta1 + (1 - beta1) * grad + exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2) * (grad * grad) + + adam_step = exp_avg / (exp_avg_sq.sqrt() + eps) + adam_step = adam_step + weight_decay * param + + weight_norm = param.norm(p=2).clamp(0, 10) + adam_norm = adam_step.norm(p=2) + + trust_ratio = weight_norm / (adam_norm + eps) + trust_ratio = (weight_norm == 0.0) * 1.0 + (weight_norm != 0.0) * trust_ratio + trust_ratio = (adam_norm == 0.0) * 1.0 + (adam_norm != 0.0) * trust_ratio + trust_ratio = trust_ratio.float() + + param = param - step_size * trust_ratio * adam_step + return param, exp_avg, exp_avg_sq + + +class JITLamb(Optimizer): + r"""Implements Lamb algorithm. + + It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. + + Arguments: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + adam (bool, optional): always use trust ratio = 1, which turns this into + Adam. Useful for comparison purposes. + + .. 
_Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: + https://arxiv.org/abs/1904.00962 + """ + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, + weight_decay=0, adam=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + defaults = dict(lr=lr, betas=betas, eps=eps, + weight_decay=weight_decay) + self.adam = adam + super().__init__(params, defaults) + + def step(self, closure=None): + """Performs a single optimization step. + + Arguments: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError('Lamb does not support sparse gradients.') + + state = self.state[p] + + # State initialization + if len(state) == 0: + state['step'] = 0 + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state['exp_avg_sq'] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + state['step'] += 1 + step_size = group['lr'] + + param, exp_avg, exp_avg_sq = lamb_kernel(p.data, grad, exp_avg, + exp_avg_sq, beta1, + beta2, step_size, + group['eps'], + group['weight_decay'], + ) + state['exp_avg'] = exp_avg + state['exp_avg_sq'] = exp_avg_sq + p.data = param + + return loss diff --git a/src/clm/src/utils/optim/schedulers.py b/src/clm/src/utils/optim/schedulers.py new file mode 100644 index 00000000..35e6d877 --- /dev/null +++ b/src/clm/src/utils/optim/schedulers.py @@ -0,0 +1,87 @@ +"""Custom learning rate schedulers""" + +import math +import warnings +import torch + +from timm.scheduler import CosineLRScheduler + + +# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html +class CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR): + + def __init__(self, optimizer, T_max, eta_min=0, warmup_step=0, **kwargs): + self.warmup_step = warmup_step + super().__init__(optimizer, T_max - warmup_step, eta_min, *kwargs) + + # Copied from CosineAnnealingLR, but adding warmup and changing self.last_epoch to + # self.last_epoch - self.warmup_step. 
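+    # Example (illustrative): with base_lrs=[1e-3] and warmup_step=5, get_lr below
+    # ramps linearly through 2e-4, 4e-4, 6e-4, 8e-4, 1e-3 over epochs 0-4, and from
+    # epoch 5 onward follows the usual cosine annealing recurrence over the remaining
+    # T_max - warmup_step epochs.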
+ def get_lr(self): + if not self._get_lr_called_within_step: + warnings.warn("To get the last learning rate computed by the scheduler, " + "please use `get_last_lr()`.", UserWarning) + + if self.last_epoch == self.warmup_step: # also covers the case where both are 0 + return self.base_lrs + elif self.last_epoch < self.warmup_step: + return [base_lr * (self.last_epoch + 1) / self.warmup_step for base_lr in self.base_lrs] + elif (self.last_epoch - self.warmup_step - 1 - self.T_max) % (2 * self.T_max) == 0: + return [group['lr'] + (base_lr - self.eta_min) * + (1 - math.cos(math.pi / self.T_max)) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)] + return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_step) / self.T_max)) / + (1 + math.cos(math.pi * (self.last_epoch - self.warmup_step - 1) / self.T_max)) * + (group['lr'] - self.eta_min) + self.eta_min + for group in self.optimizer.param_groups] + + _get_closed_form_lr = None + + +def InvSqrt(optimizer, warmup_step): + """ Originally used for Transformer (in Attention is all you need) + """ + + def lr_lambda(step): + # return a multiplier instead of a learning rate + if step == warmup_step: # also covers the case where both are 0 + return 1. + else: + return 1. / (step ** 0.5) if step > warmup_step else (step + 1) / (warmup_step ** 1.5) + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) + + +def Constant(optimizer, warmup_step): + + def lr_lambda(step): + if step == warmup_step: # also covers the case where both are 0 + return 1. + else: + return 1. if step > warmup_step else (step + 1) / warmup_step + + return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) + + +class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): + """ Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch. + It supports resuming as well. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._last_epoch = -1 + self.step(epoch=0) + + def step(self, epoch=None): + if epoch is None: + self._last_epoch += 1 + else: + self._last_epoch = epoch + # We call either step or step_update, depending on whether we're using the scheduler every + # epoch or every step. + # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set + # scheduler interval to "step", then the learning rate update will be wrong. + if self.t_in_epochs: + super().step(epoch=self._last_epoch) + else: + super().step_update(num_updates=self._last_epoch) diff --git a/src/clm/src/utils/optim_groups.py b/src/clm/src/utils/optim_groups.py new file mode 100644 index 00000000..b935a8f3 --- /dev/null +++ b/src/clm/src/utils/optim_groups.py @@ -0,0 +1,144 @@ +"""Utilities for special optimizer hyperparameters. 
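+
+The helpers below rely on parameters optionally carrying a per-tensor `_optim` dict
+(e.g. {"weight_decay": 0.0} or {"lr": 1e-3}), as set for instance by
+OptimModule.register in utils/train.py, which overrides the optimizer defaults for
+that tensor.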
+ +group_parameters_for_optimizer is a modification of timm's optimizer logic, which is currently unused +add_optimizer_hooks is an improved version that uses this codebase's _optim dictionary +""" + +import inspect + +import torch.nn as nn + +import hydra + + +def add_optimizer_hooks( + model, + bias_weight_decay=False, + normalization_weight_decay=False, +): + """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with + attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for + normalization parameters if normalization_weight_decay==False + """ + + # Separate out all parameters to those that will and won't experience regularizing weight decay + blacklist_weight_modules = (nn.Embedding, ) + if not normalization_weight_decay: + blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + # Not compatible with Pytorch 1.8.1 + # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, + nn.GroupNorm, nn.SyncBatchNorm, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.LayerNorm, nn.LocalResponseNorm) + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(): + if (not bias_weight_decay and pn.endswith('bias')) \ + or getattr(p, '_no_weight_decay', False) \ + or isinstance(m, blacklist_weight_modules): + setattr(p, "_optim", {"weight_decay": 0.0}) + + +def group_parameters_for_optimizer( + model, + optimizer_cfg, + bias_weight_decay=False, + normalization_weight_decay=False, +): + """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with + attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for + normalization parameters if normalization_weight_decay==False + """ + # Get the weight decay from the config, or from the default value of the optimizer constructor + # if it's not specified in the config. + if 'weight_decay' in optimizer_cfg: + weight_decay = optimizer_cfg.weight_decay + else: + # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value + signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_)) + if 'weight_decay' in signature.parameters: + weight_decay = signature.parameters['weight_decay'].default + if weight_decay is inspect.Parameter.empty: + weight_decay = 0.0 + else: + weight_decay = 0.0 + + # If none of the parameters have weight decay anyway, and there are no parameters with special + # optimization params + if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()): + return model.parameters() + + skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set() + skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords') + else set()) + + # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134 + """ + This long function is unfortunately doing something very simple and is being very defensive: + We are separating out all parameters of the model into two buckets: those that will experience + weight decay for regularization and those that won't (biases, and layernorm/embedding weights). + We are then returning the PyTorch optimizer object. 
+ """ + + # separate out all parameters to those that will and won't experience regularizing weight decay + decay = set() + no_decay = set() + special = set() + whitelist_weight_modules = (nn.Linear, ) + blacklist_weight_modules = (nn.Embedding, ) + if not normalization_weight_decay: + blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, + # Not compatible with Pytorch 1.8.1 + # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, + nn.GroupNorm, nn.SyncBatchNorm, + nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, + nn.LayerNorm, nn.LocalResponseNorm) + for mn, m in model.named_modules(): + for pn, p in m.named_parameters(): + fpn = '%s.%s' % (mn, pn) if mn else pn # full param name + if not p.requires_grad: + continue # frozen weights + if hasattr(p, '_optim'): + special.add(fpn) + elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords): + no_decay.add(fpn) + elif getattr(p, '_no_weight_decay', False): + no_decay.add(fpn) + elif not bias_weight_decay and pn.endswith('bias'): + no_decay.add(fpn) + elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): + # weights of whitelist modules will be weight decayed + decay.add(fpn) + elif isinstance(m, blacklist_weight_modules): + # weights of blacklist modules will NOT be weight decayed + no_decay.add(fpn) + + param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} + # special case the position embedding parameter in the root GPT module as not decayed + if 'pos_emb' in param_dict: + no_decay.add('pos_emb') + + # In case of parameter sharing, some parameters show up in decay but are not in param_dict.keys() + decay &= param_dict.keys() + decay |= (param_dict.keys() - no_decay - special) + # validate that we considered every parameter + inter_params = decay & no_decay + union_params = decay | no_decay + assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both decay/no_decay sets!" + assert len(param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" 
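+
+    # Build the final parameter groups: a decayed and a non-decayed bucket (collapsed
+    # into a single group when weight_decay is 0.0 or nothing is excluded from decay),
+    # plus one extra group per unique `_optim` hyperparameter dict found on the
+    # "special" parameters.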
+ + if weight_decay == 0.0 or not no_decay: + param_groups = [{"params": [param_dict[pn] for pn in sorted(list(no_decay | decay))], + "weight_decay": weight_decay}] + else: + param_groups = [ + {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, + {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + ] + # Add parameters with special hyperparameters + # Unique dicts + hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)] + for hp in hps: + params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp] + param_groups.append({"params": params, **hp}) + + return param_groups diff --git a/src/clm/src/utils/permutations.py b/src/clm/src/utils/permutations.py new file mode 100644 index 00000000..b8f6a0d7 --- /dev/null +++ b/src/clm/src/utils/permutations.py @@ -0,0 +1,180 @@ +import math +import numpy as np +import torch + + +### Bit reversal permutation + +def bitreversal_po2(n): + m = int(math.log(n)/math.log(2)) + perm = np.arange(n).reshape(n,1) + for i in range(m): + n1 = perm.shape[0]//2 + perm = np.hstack((perm[:n1],perm[n1:])) + return perm.squeeze(0) + +def bitreversal_permutation(n): + m = int(math.ceil(math.log(n)/math.log(2))) + N = 1 << m + perm = bitreversal_po2(N) + return np.extract(perm < n, perm) + +def transpose_permutation(h, w): + indices = np.arange(h*w) + indices = indices.reshape((h, w)) + indices = indices.T + indices = indices.reshape(h*w) + return indices + +def snake_permutation(h, w): + indices = np.arange(h*w) + indices = indices.reshape((h, w)) + indices[1::2, :] = indices[1::2, ::-1] + indices = indices.reshape(h*w) + return indices + +def hilbert_permutation(n): + m = int(math.log2(n)) + assert n == 2**m + inds = decode(list(range(n*n)), 2, m) + ind_x, ind_y = inds.T + indices = np.arange(n*n).reshape((n, n)) + indices = indices[ind_x, ind_y] + return(indices) + +""" Hilbert curve utilities taken from https://github.com/PrincetonLIPS/numpy-hilbert-curve """ +def decode(hilberts, num_dims, num_bits): + ''' Decode an array of Hilbert integers into locations in a hypercube. + This is a vectorized-ish version of the Hilbert curve implementation by John + Skilling as described in: + Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference + Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. + Params: + ------- + hilberts - An ndarray of Hilbert integers. Must be an integer dtype and + cannot have fewer bits than num_dims * num_bits. + num_dims - The dimensionality of the hypercube. Integer. + num_bits - The number of bits for each dimension. Integer. + Returns: + -------- + The output is an ndarray of unsigned integers with the same shape as hilberts + but with an additional dimension of size num_dims. + ''' + + if num_dims*num_bits > 64: + raise ValueError( + ''' + num_dims=%d and num_bits=%d for %d bits total, which can't be encoded + into a uint64. Are you sure you need that many points on your Hilbert + curve? + ''' % (num_dims, num_bits) + ) + + # Handle the case where we got handed a naked integer. + hilberts = np.atleast_1d(hilberts) + + # Keep around the shape for later. + orig_shape = hilberts.shape + + # Treat each of the hilberts as a sequence of eight uint8. + # This treats all of the inputs as uint64 and makes things uniform. 
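+    # ('>u8' is a big-endian unsigned 64-bit dtype, so viewing the array as uint8
+    #  yields the most significant byte first, which is what the bit unpacking
+    #  below expects.)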
+ hh_uint8 = np.reshape(hilberts.ravel().astype('>u8').view(np.uint8), (-1, 8)) + + # Turn these lists of uints into lists of bits and then truncate to the size + # we actually need for using Skilling's procedure. + hh_bits = np.unpackbits(hh_uint8, axis=1)[:,-num_dims*num_bits:] + + # Take the sequence of bits and Gray-code it. + gray = binary2gray(hh_bits) + + # There has got to be a better way to do this. + # I could index them differently, but the eventual packbits likes it this way. + gray = np.swapaxes( + np.reshape(gray, (-1, num_bits, num_dims)), + axis1=1, axis2=2, + ) + + # Iterate backwards through the bits. + for bit in range(num_bits-1, -1, -1): + + # Iterate backwards through the dimensions. + for dim in range(num_dims-1, -1, -1): + + # Identify which ones have this bit active. + mask = gray[:,dim,bit] + + # Where this bit is on, invert the 0 dimension for lower bits. + gray[:,0,bit+1:] = np.logical_xor(gray[:,0,bit+1:], mask[:,np.newaxis]) + + # Where the bit is off, exchange the lower bits with the 0 dimension. + to_flip = np.logical_and( + np.logical_not(mask[:,np.newaxis]), + np.logical_xor(gray[:,0,bit+1:], gray[:,dim,bit+1:]) + ) + gray[:,dim,bit+1:] = np.logical_xor(gray[:,dim,bit+1:], to_flip) + gray[:,0,bit+1:] = np.logical_xor(gray[:,0,bit+1:], to_flip) + + # Pad back out to 64 bits. + extra_dims = 64 - num_bits + padded = np.pad(gray, ((0,0), (0,0), (extra_dims,0)), + mode='constant', constant_values=0) + + # Now chop these up into blocks of 8. + locs_chopped = np.reshape(padded[:,:,::-1], (-1, num_dims, 8, 8)) + + # Take those blocks and turn them unto uint8s. + locs_uint8 = np.squeeze(np.packbits(locs_chopped, bitorder='little', axis=3)) + + # Finally, treat these as uint64s. + flat_locs = locs_uint8.view(np.uint64) + + # Return them in the expected shape. + return np.reshape(flat_locs, (*orig_shape, num_dims)) + +def right_shift(binary, k=1, axis=-1): + ''' Right shift an array of binary values. + Parameters: + ----------- + binary: An ndarray of binary values. + k: The number of bits to shift. Default 1. + axis: The axis along which to shift. Default -1. + Returns: + -------- + Returns an ndarray with zero prepended and the ends truncated, along + whatever axis was specified. +''' + + # If we're shifting the whole thing, just return zeros. + if binary.shape[axis] <= k: + return np.zeros_like(binary) + + # Determine the padding pattern. + padding = [(0,0)] * len(binary.shape) + padding[axis] = (k,0) + + # Determine the slicing pattern to eliminate just the last one. + slicing = [slice(None)] * len(binary.shape) + slicing[axis] = slice(None, -k) + + shifted = np.pad(binary[tuple(slicing)], padding, + mode='constant', constant_values=0) + + return shifted + +def binary2gray(binary, axis=-1): + ''' Convert an array of binary values into Gray codes. + This uses the classic X ^ (X >> 1) trick to compute the Gray code. + Parameters: + ----------- + binary: An ndarray of binary values. + axis: The axis along which to compute the gray code. Default=-1. + Returns: + -------- + Returns an ndarray of Gray codes. + ''' + shifted = right_shift(binary, axis=axis) + + # Do the X ^ (X >> 1) trick. 
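+    # (On 0/1 bit arrays, logical_xor is equivalent to the bitwise ^ in the classic trick.)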
+ gray = np.logical_xor(binary, shifted) + + return gray diff --git a/src/clm/src/utils/registry.py b/src/clm/src/utils/registry.py new file mode 100644 index 00000000..7943bdcc --- /dev/null +++ b/src/clm/src/utils/registry.py @@ -0,0 +1,53 @@ +optimizer = { + "adam": "torch.optim.Adam", + "adamw": "torch.optim.AdamW", + "rmsprop": "torch.optim.RMSprop", + "sgd": "torch.optim.SGD", + "lamb": "src.utils.optim.lamb.JITLamb", +} + +scheduler = { + "constant": "transformers.get_constant_schedule", + "plateau": "torch.optim.lr_scheduler.ReduceLROnPlateau", + "step": "torch.optim.lr_scheduler.StepLR", + "multistep": "torch.optim.lr_scheduler.MultiStepLR", + "cosine": "torch.optim.lr_scheduler.CosineAnnealingLR", + "constant_warmup": "transformers.get_constant_schedule_with_warmup", + "linear_warmup": "transformers.get_linear_schedule_with_warmup", + "cosine_warmup": "transformers.get_cosine_schedule_with_warmup", + "cosine_warmup_timm": "src.utils.optim.schedulers.TimmCosineLRScheduler", +} + +model = { + # Backbones from this repo + "model": "src.models.sequence.SequenceModel", + "lm": "src.models.sequence.long_conv_lm.ConvLMHeadModel", + "lm_simple": "src.models.sequence.simple_lm.SimpleLMHeadModel", + "vit_b_16": "src.models.baselines.vit_all.vit_base_patch16_224", +} + +layer = { + "id": "src.models.sequence.base.SequenceIdentity", + "ff": "src.models.sequence.ff.FF", + "mha": "src.models.sequence.mha.MultiheadAttention", + "s4d": "src.models.sequence.ssm.s4d.S4D", + "s4_simple": "src.models.sequence.ssm.s4_simple.SimpleS4Wrapper", + "long-conv": "src.models.sequence.long_conv.LongConv", + "h3": "src.models.sequence.h3.H3", + "h3-conv": "src.models.sequence.h3_conv.H3Conv", + "hyena": "src.models.sequence.hyena.HyenaOperator", + "hyena-filter": "src.models.sequence.hyena.HyenaFilter", + "vit": "src.models.sequence.mha.VitAttention", +} + +callbacks = { + "timer": "src.callbacks.timer.Timer", + "params": "src.callbacks.params.ParamsLog", + "learning_rate_monitor": "pytorch_lightning.callbacks.LearningRateMonitor", + "model_checkpoint": "pytorch_lightning.callbacks.ModelCheckpoint", + "early_stopping": "pytorch_lightning.callbacks.EarlyStopping", + "swa": "pytorch_lightning.callbacks.StochasticWeightAveraging", + "rich_model_summary": "pytorch_lightning.callbacks.RichModelSummary", + "rich_progress_bar": "pytorch_lightning.callbacks.RichProgressBar", + "progressive_resizing": "src.callbacks.progressive_resizing.ProgressiveResizing", +} diff --git a/src/clm/src/utils/train.py b/src/clm/src/utils/train.py new file mode 100644 index 00000000..12e5dbb4 --- /dev/null +++ b/src/clm/src/utils/train.py @@ -0,0 +1,156 @@ +""" Utils for the training loop. 
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """ +import logging +import os +import warnings +from typing import List, Sequence + +import torch.nn as nn +import pytorch_lightning as pl +import rich.syntax +import rich.tree +from omegaconf import DictConfig, OmegaConf +from pytorch_lightning.utilities import rank_zero_only + +from clm.src.utils.config import omegaconf_filter_keys + + +# Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging +class LoggingContext: + def __init__(self, logger, level=None, handler=None, close=True): + self.logger = logger + self.level = level + self.handler = handler + self.close = close + + def __enter__(self): + if self.level is not None: + self.old_level = self.logger.level + self.logger.setLevel(self.level) + if self.handler: + self.logger.addHandler(self.handler) + + def __exit__(self, et, ev, tb): + if self.level is not None: + self.logger.setLevel(self.old_level) + if self.handler: + self.logger.removeHandler(self.handler) + if self.handler and self.close: + self.handler.close() + # implicit return of None => don't swallow exceptions + + +def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: + """Initializes multi-GPU-friendly python logger.""" + + logger = logging.getLogger(name) + logger.setLevel(level) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger + + +def process_config(config: DictConfig) -> DictConfig: # TODO because of filter_keys, this is no longer in place + """A couple of optional utilities, controlled by main config file: + - disabling warnings + - easier access to debug mode + - forcing debug friendly configuration + Modifies DictConfig in place. + Args: + config (DictConfig): Configuration composed by Hydra. + """ + log = get_logger() + + # Filter out keys that were used just for interpolation + # config = dictconfig_filter_keys(config, lambda k: not k.startswith('__')) + config = omegaconf_filter_keys(config, lambda k: not k.startswith('__')) + + # enable adding new keys to config + OmegaConf.set_struct(config, False) + + # disable python warnings if + if config.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + if config.get("debug"): + log.info("Running in debug mode! ") + config.trainer.fast_dev_run = True + + # force debugger friendly configuration + log.info("Forcing debugger friendly configuration! ") + # Debuggers don't like GPUs or multiprocessing + if config.trainer.get("gpus"): + config.trainer.gpus = 0 + if config.loader.get("pin_memory"): + config.loader.pin_memory = False + if config.loader.get("num_workers"): + config.loader.num_workers = 0 + + # disable adding new keys to config + # OmegaConf.set_struct(config, True) # [21-09-17 AG] I need this for .pop(_name_) pattern among other things + + return config + +@rank_zero_only +def print_config( + config: DictConfig, + resolve: bool = True, + save_cfg=True, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure. + Args: + config (DictConfig): Configuration composed by Hydra. + fields (Sequence[str], optional): Determines which main fields from config will + be printed and in what order. 
+ resolve (bool, optional): Whether to resolve reference fields of DictConfig. + """ + + style = "dim" + tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) + + fields = config.keys() + for field in fields: + branch = tree.add(field, style=style, guide_style=style) + + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, DictConfig): + branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + rich.print(tree) + + if save_cfg: + with open("config_tree.txt", "w") as fp: + rich.print(tree, file=fp) + +def log_optimizer(logger, optimizer, keys): + """ Log values of particular keys from the optimizer's param groups """ + keys = sorted(keys) + for i, g in enumerate(optimizer.param_groups): + group_hps = {k: g.get(k, None) for k in keys} + logger.info(' | '.join([ + f"Optimizer group {i}", + f"{len(g['params'])} tensors", + ] + [f"{k} {v}" for k, v in group_hps.items()])) + +class OptimModule(nn.Module): + """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """ + + def register(self, name, tensor, lr=None, wd=0.0): + """Register a tensor with a configurable learning rate and 0 weight decay""" + + if lr == 0.0: + self.register_buffer(name, tensor) + else: + self.register_parameter(name, nn.Parameter(tensor)) + + optim = {} + if lr is not None: optim["lr"] = lr + if wd is not None: optim["weight_decay"] = wd + setattr(getattr(self, name), "_optim", optim) \ No newline at end of file diff --git a/workflow/config/config-spectraverse-allv1-s4_cv.yaml b/workflow/config/config-spectraverse-allv1-s4_cv.yaml new file mode 100644 index 00000000..cfa201b5 --- /dev/null +++ b/workflow/config/config-spectraverse-allv1-s4_cv.yaml @@ -0,0 +1,196 @@ +# Molecular sequence representations of chemical species for training and sampling. +# Determines how the molecules are encoded internally. +# The only avaiable option for now is 'SMILES'. +representations: + - SMILES + +# The number of cross-validation folds. +# The dataset is split into train/test set for each fold, and models are trained/tested on each fold. +folds: 10 + +# Seeds used to initialize random number generators for the training runs. +# Each seed corresponds to a separate training run. +# Each fold trains 'train_seeds' number of models on the training set for that fold. +train_seeds: + - 0 + +# Seeds used when sampling molecules from the trained models. +# The number of 'sample_seeds' values specifies how many times the 'sample_molecules_RNN' step is executed, +# each time using the same trained model but with different random seed values. +sample_seeds: + - 0 + +# Specifies by how many times the input data is augmented (or enumerated) before training. +# Augmentation refers to the fact that a single molecule can have multiple SMILES representation. +# For example: +# - A value of 0 means no augmentation, leaving the input data unchanged. +# - A value of 100 means each molecule can have up to 100 different SMILES representations in the training set. +# Note: Both 0 and 1 indicate no augmentation, but with 1, the representations are updated to be different +# than those provided in the original dataset. +enum_factors: + - 0 + +# Limits the maximum number of input SMILES to be read from the original dataset. 0 means there's no limit. +max_input_smiles: 0 + +# A dictionary defining the arguments to be passed to the preprocess command. 
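+# (Editor's note, illustrative only: with min_heavy_atoms: 3 below, a SMILES
+#  such as "CCO", which has three heavy atoms, would just pass the filter,
+#  assuming the threshold is inclusive; anything smaller would be dropped.)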
+preprocess: + # Specifies the minimum number of heavy atoms that a valid molecule should have. + min_heavy_atoms: 3 + # Defines the set of elements required for a molecule to be considered valid. + # Any SMILES containing elements outside this set will be considered invalid and excluded from the training set. + valid_atoms: [Br, C, Cl, F, H, I, N, O, P, S, Se, Si, B, As] + # Specifies whether the charges in the training SMILES should be neutralized. + neutralise: false + # Specifies whether to remove training SMILES representing molecules with tokens found in less than 0.01% of samples or fewer than 10 molecules. + remove_rare: false + # Specifies whether to remove duplicate SMILES from the training set, identifying duplicates by inchikey. + keep_duplicates: false + +# Parameters that define the neural network model and training process. +model_params: + # Type of Recurrent Neural Network (RNN) to use. + # Available options are 'LSTM' and 'GRU' + rnn_type: S4 + embedding_size: 128 # Size of the embedding vectors that represent each token in the input sequence. + hidden_size: 256 # Size of the hidden state of the RNN. + n_layers: 2 # Number of stacked RNN layers in the model. + dropout: 0 # Dropout rate applied to the RNN layer for regularization. + batch_size: 64 # Number of samples processed before the models internal parameters are updated. + learning_rate: 0.001 # Used by the optimizer to update model parameters. + max_epochs: 999999 # Maximum number of training epochs (complete passes through the training dataset). + patience: 50000 # Number of steps with no improvement in the validation loss after which early stopping is triggered. + + # An RNN model conditioned on input descriptors (experimentally obtained properties of the input SMILES). + # Note that rnn_type and other RNN architecture parameters are still applicable in this case. + conditional: + # Is the conditional model enabled? + enabled: false + + # Note: Both emb and emb_l below cannot be true at the same time. + # Concatenate the descriptors directly to the token embeddings at each step in the sequence? + emb: false + # Concatenate the descriptors to the token embeddings, but by first passing them through a + # linear layer to obtain embeddings of dimensionality equal to that of the token embeddings? + emb_l: true + + # Note: Both dec and dec_l below cannot be true at the same time. + # Concatenate the descriptors directly to the output of the RNN layers + # (prior to the decoder layer)? + dec: false + # Concatenate the descriptors to the output of the RNN layers + # (prior to the decoder layer), but by first passing them through a + # linear layer to obtain embeddings of dimensionality equal to that of the token embeddings? + dec_l: true + + # Instantiate the hidden states based on learned transformations of the descriptors + # (with a single linear layer), as in Kotsias et al? + h: false + + # Frequency of logging training progress in terms of steps (batches). + log_every_steps: 100 + # Frequency of logging training progress in terms of epochs. + log_every_epochs: 1 + # Number of molecules to sample from the trained model after training. + sample_mols: 1000000 + +# When looking at sampled molecules across all folds, what metric(s) do we +# use for aggregating frequencies? +metrics: + # With what frequency (across all folds) was each valid molecule produced? + # - freq-sum + # With what average frequency (across all folds) was each valid molecule produced? 
+ - freq-avg + # With what average frequency (across all folds) was each valid molecule produced, + # as a fraction of total sampling frequency (x 10e3 to avoid ~0 values) + # - fp10k + +# Minimum Tanimoto coefficient threshold to filter out molecules from training set. +# This allows for only similar SMILES to be considered from the preprocessed dataset +# for the creation of training/ testing folds, (with or without augmentation). +# 0 (default) means no filtering based on Tanimoto similarity. +min_tc: 0 + +# Number of top candidate molecules to consider when evaluating correctness. +# Here, correctness is defined as an exact mass match within a specified error range. +# Example: +# A value of 30 means that the 30 most frequently generated molecules with a mass +# matching the target molecule's mass within the allowed error margin are considered +# for further evaluation. +top_k: 30 + +# Error tolerance in parts per million for mass-matching to consider a molecule "correct". +# Used in rules that evaluate the correctness of generated or sampled molecules against +# known test molecules based on mass. +err_ppm: 10 + +# Specifies minimum frequency thresholds for inclusion. +# Each value represents the minimum number of times a molecule must be generated +# across all folds to be considered for further evaluation. +structural_prior_min_freq: + - 1 + +# seed used as a global random seed for steps not covered by 'train_seeds' or 'sample_seeds'. +random_seed: 5831 + +# A dictionary that defines various input and output file paths, incorporating wildcards. +paths: + # Modify these paths to match your system. + + # Base directory for outputs + output_dir: '/Genomics/argo/users/vg8892/git/CLM/workflow/data_spectraverse_allv1_s4_cv' + # The input dataset file. + dataset: "/Genomics/argo/users/vg8892/git/CLM/data/spectraverse_allv1.txt" + # File containing data from PubChem. + pubchem_tsv_file: "../tests/test_data/PubChem_truncated.tsv" + + # The following paths can be modified, as long as all wildcards are preserved in each case. + + # Processed dataset before augmentation/training. + preprocess_output: "{output_dir}/prior/raw/{dataset}.txt" + # Training file for each cross-validation fold. + train_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_{fold}.smi" + # Vocabulary file for the tokenized sequences. + vocab_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_{fold}.vocabulary" + # Trained RNN model checkpoint file. + model_file: "{output_dir}/{enum_factor}/prior/models/{dataset}_{repr}_{fold}_{train_seed}_model.pt" + # Sampled dataset for each fold. + input_file: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples.csv.gz" + # Unaugmented training dataset for each cross-validation fold. + train0_file: "{output_dir}/{enum_factor}/prior/inputs/train0_{dataset}_{repr}_{fold}.smi" + # Unaugmented test dataset for each cross-validation fold. + test0_file: "{output_dir}/{enum_factor}/prior/inputs/test0_{dataset}_{repr}_{fold}.smi" + # A file generated by add_carbon rule, inserting carbon symbols at random spots in training SMILES. + # carbon_file: "{output_dir}/{enum_factor}/prior/inputs/train0_{dataset}_{repr}_{fold}_carbon.csv.gz" + # Complete training dataset aggregated across all folds. + train_all_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_all.smi" + # Complete testing dataset aggregated across all folds. 
+ test_all_file: "{output_dir}/{enum_factor}/prior/inputs/test_{dataset}_{repr}_all.smi" + # Complete aggregated SMILES from add_carbon rule across all folds. + # carbon_all_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_carbon_all.csv.gz" + # Top-n candidate SMILES based on matching by exact mass for each cross-validation fold. + cv_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_ranks_structure.csv.gz" + # Top-n candidate SMILES based on matching mass including Tanimoto coefficient for each cross-validation fold. + cv_tc_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_tc.csv.gz" + # Top-n candidate SMILES (correctness based on formula rather than structure) for each cross-validation fold. + formula_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_ranks_formula.csv.gz" + # Sampled SMILES aggregated across all folds. + process_tabulated_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_processed_min{min_freq}_{metric}.csv.gz" + # Loss curves from model training. + loss_file: "{output_dir}/{enum_factor}/prior/models/{dataset}_{repr}_{fold}_{train_seed}_loss.csv.gz" + # Novel SMILES generated by each trained model, along with inchikey, mass and formula. + tabulate_molecules_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" + # Aggregated sampled SMILES from all the trained models in a fold. + collect_tabulated_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_unique_masses.csv.gz" + # Top-n candidate SMILES based on matching mass across all folds. + overall_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_min{min_freq}_all_{metric}_CV_ranks_structure.csv.gz" + # Top-n candidate SMILES based on matching mass including tc per fold. + overall_tc_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_min{min_freq}_all_{metric}_CV_tc.csv.gz" + # Sampled molecules per trained model that appear in training set. + known_smiles_file: "{output_dir}/{enum_factor}/prior/samples/known_{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" + # Invalid SMILES sampled per trained model. + invalid_smiles_file: "{output_dir}/{enum_factor}/prior/samples/invalid_{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" + # Known (training set) sampled molecules within a fold. + collect_known_smiles: "{output_dir}/{enum_factor}/prior/samples/known_{dataset}_{repr}_{fold}_unique_masses.csv.gz" + # Invalid sampled SMILES within a fold. + collect_invalid_smiles: "{output_dir}/{enum_factor}/prior/samples/invalid_{dataset}_{repr}_{fold}_unique_masses.csv.gz" diff --git a/workflow/config/config-spectraverse-allv1-transformer_cv.yaml b/workflow/config/config-spectraverse-allv1-transformer_cv.yaml new file mode 100644 index 00000000..d96c6c9d --- /dev/null +++ b/workflow/config/config-spectraverse-allv1-transformer_cv.yaml @@ -0,0 +1,196 @@ +# Molecular sequence representations of chemical species for training and sampling. +# Determines how the molecules are encoded internally. +# The only avaiable option for now is 'SMILES'. +representations: + - SMILES + +# The number of cross-validation folds. +# The dataset is split into train/test set for each fold, and models are trained/tested on each fold. 
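+# (Editor's note, illustrative only: with folds: 10 and the single train_seed
+#  and sample_seed listed below, this config trains 10 models and runs 10
+#  sampling jobs, one per fold.)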
+folds: 10 + +# Seeds used to initialize random number generators for the training runs. +# Each seed corresponds to a separate training run. +# Each fold trains 'train_seeds' number of models on the training set for that fold. +train_seeds: + - 0 + +# Seeds used when sampling molecules from the trained models. +# The number of 'sample_seeds' values specifies how many times the 'sample_molecules_RNN' step is executed, +# each time using the same trained model but with different random seed values. +sample_seeds: + - 0 + +# Specifies by how many times the input data is augmented (or enumerated) before training. +# Augmentation refers to the fact that a single molecule can have multiple SMILES representation. +# For example: +# - A value of 0 means no augmentation, leaving the input data unchanged. +# - A value of 100 means each molecule can have up to 100 different SMILES representations in the training set. +# Note: Both 0 and 1 indicate no augmentation, but with 1, the representations are updated to be different +# than those provided in the original dataset. +enum_factors: + - 0 + +# Limits the maximum number of input SMILES to be read from the original dataset. 0 means there's no limit. +max_input_smiles: 0 + +# A dictionary defining the arguments to be passed to the preprocess command. +preprocess: + # Specifies the minimum number of heavy atoms that a valid molecule should have. + min_heavy_atoms: 3 + # Defines the set of elements required for a molecule to be considered valid. + # Any SMILES containing elements outside this set will be considered invalid and excluded from the training set. + valid_atoms: [Br, C, Cl, F, H, I, N, O, P, S, Se, Si, B, As] + # Specifies whether the charges in the training SMILES should be neutralized. + neutralise: false + # Specifies whether to remove training SMILES representing molecules with tokens found in less than 0.01% of samples or fewer than 10 molecules. + remove_rare: false + # Specifies whether to remove duplicate SMILES from the training set, identifying duplicates by inchikey. + keep_duplicates: false + +# Parameters that define the neural network model and training process. +model_params: + # Type of Recurrent Neural Network (RNN) to use. + # Available options are 'LSTM' and 'GRU' + rnn_type: Transformer + embedding_size: 128 # Size of the embedding vectors that represent each token in the input sequence. + hidden_size: 256 # Size of the hidden state of the RNN. + n_layers: 2 # Number of stacked RNN layers in the model. + dropout: 0 # Dropout rate applied to the RNN layer for regularization. + batch_size: 64 # Number of samples processed before the models internal parameters are updated. + learning_rate: 0.001 # Used by the optimizer to update model parameters. + max_epochs: 999999 # Maximum number of training epochs (complete passes through the training dataset). + patience: 50000 # Number of steps with no improvement in the validation loss after which early stopping is triggered. + + # An RNN model conditioned on input descriptors (experimentally obtained properties of the input SMILES). + # Note that rnn_type and other RNN architecture parameters are still applicable in this case. + conditional: + # Is the conditional model enabled? + enabled: false + + # Note: Both emb and emb_l below cannot be true at the same time. + # Concatenate the descriptors directly to the token embeddings at each step in the sequence? 
+ emb: false + # Concatenate the descriptors to the token embeddings, but by first passing them through a + # linear layer to obtain embeddings of dimensionality equal to that of the token embeddings? + emb_l: true + + # Note: Both dec and dec_l below cannot be true at the same time. + # Concatenate the descriptors directly to the output of the RNN layers + # (prior to the decoder layer)? + dec: false + # Concatenate the descriptors to the output of the RNN layers + # (prior to the decoder layer), but by first passing them through a + # linear layer to obtain embeddings of dimensionality equal to that of the token embeddings? + dec_l: true + + # Instantiate the hidden states based on learned transformations of the descriptors + # (with a single linear layer), as in Kotsias et al? + h: false + + # Frequency of logging training progress in terms of steps (batches). + log_every_steps: 100 + # Frequency of logging training progress in terms of epochs. + log_every_epochs: 1 + # Number of molecules to sample from the trained model after training. + sample_mols: 1000000 + +# When looking at sampled molecules across all folds, what metric(s) do we +# use for aggregating frequencies? +metrics: + # With what frequency (across all folds) was each valid molecule produced? + # - freq-sum + # With what average frequency (across all folds) was each valid molecule produced? + - freq-avg + # With what average frequency (across all folds) was each valid molecule produced, + # as a fraction of total sampling frequency (x 10e3 to avoid ~0 values) + # - fp10k + +# Minimum Tanimoto coefficient threshold to filter out molecules from training set. +# This allows for only similar SMILES to be considered from the preprocessed dataset +# for the creation of training/ testing folds, (with or without augmentation). +# 0 (default) means no filtering based on Tanimoto similarity. +min_tc: 0 + +# Number of top candidate molecules to consider when evaluating correctness. +# Here, correctness is defined as an exact mass match within a specified error range. +# Example: +# A value of 30 means that the 30 most frequently generated molecules with a mass +# matching the target molecule's mass within the allowed error margin are considered +# for further evaluation. +top_k: 30 + +# Error tolerance in parts per million for mass-matching to consider a molecule "correct". +# Used in rules that evaluate the correctness of generated or sampled molecules against +# known test molecules based on mass. +err_ppm: 10 + +# Specifies minimum frequency thresholds for inclusion. +# Each value represents the minimum number of times a molecule must be generated +# across all folds to be considered for further evaluation. +structural_prior_min_freq: + - 1 + +# seed used as a global random seed for steps not covered by 'train_seeds' or 'sample_seeds'. +random_seed: 5831 + +# A dictionary that defines various input and output file paths, incorporating wildcards. +paths: + # Modify these paths to match your system. + + # Base directory for outputs + output_dir: '/Genomics/argo/users/vg8892/git/CLM/workflow/data_spectraverse_allv1_transformer_cv' + # The input dataset file. + dataset: "/Genomics/argo/users/vg8892/git/CLM/data/spectraverse_allv1.txt" + # File containing data from PubChem. + pubchem_tsv_file: "../tests/test_data/PubChem_truncated.tsv" + + # The following paths can be modified, as long as all wildcards are preserved in each case. + + # Processed dataset before augmentation/training. 
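+  # (Editor's note, illustrative only: the {wildcards} in the templates below
+  #  are filled in by the workflow for each combination of fold, seed and
+  #  enumeration factor, so every combination gets its own concrete file.)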
+ preprocess_output: "{output_dir}/prior/raw/{dataset}.txt" + # Training file for each cross-validation fold. + train_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_{fold}.smi" + # Vocabulary file for the tokenized sequences. + vocab_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_{fold}.vocabulary" + # Trained RNN model checkpoint file. + model_file: "{output_dir}/{enum_factor}/prior/models/{dataset}_{repr}_{fold}_{train_seed}_model.pt" + # Sampled dataset for each fold. + input_file: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples.csv.gz" + # Unaugmented training dataset for each cross-validation fold. + train0_file: "{output_dir}/{enum_factor}/prior/inputs/train0_{dataset}_{repr}_{fold}.smi" + # Unaugmented test dataset for each cross-validation fold. + test0_file: "{output_dir}/{enum_factor}/prior/inputs/test0_{dataset}_{repr}_{fold}.smi" + # A file generated by add_carbon rule, inserting carbon symbols at random spots in training SMILES. + # carbon_file: "{output_dir}/{enum_factor}/prior/inputs/train0_{dataset}_{repr}_{fold}_carbon.csv.gz" + # Complete training dataset aggregated across all folds. + train_all_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_all.smi" + # Complete testing dataset aggregated across all folds. + test_all_file: "{output_dir}/{enum_factor}/prior/inputs/test_{dataset}_{repr}_all.smi" + # Complete aggregated SMILES from add_carbon rule across all folds. + # carbon_all_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_carbon_all.csv.gz" + # Top-n candidate SMILES based on matching by exact mass for each cross-validation fold. + cv_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_ranks_structure.csv.gz" + # Top-n candidate SMILES based on matching mass including Tanimoto coefficient for each cross-validation fold. + cv_tc_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_tc.csv.gz" + # Top-n candidate SMILES (correctness based on formula rather than structure) for each cross-validation fold. + formula_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_ranks_formula.csv.gz" + # Sampled SMILES aggregated across all folds. + process_tabulated_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_processed_min{min_freq}_{metric}.csv.gz" + # Loss curves from model training. + loss_file: "{output_dir}/{enum_factor}/prior/models/{dataset}_{repr}_{fold}_{train_seed}_loss.csv.gz" + # Novel SMILES generated by each trained model, along with inchikey, mass and formula. + tabulate_molecules_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" + # Aggregated sampled SMILES from all the trained models in a fold. + collect_tabulated_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_unique_masses.csv.gz" + # Top-n candidate SMILES based on matching mass across all folds. + overall_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_min{min_freq}_all_{metric}_CV_ranks_structure.csv.gz" + # Top-n candidate SMILES based on matching mass including tc per fold. + overall_tc_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_min{min_freq}_all_{metric}_CV_tc.csv.gz" + # Sampled molecules per trained model that appear in training set. 
+ known_smiles_file: "{output_dir}/{enum_factor}/prior/samples/known_{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" + # Invalid SMILES sampled per trained model. + invalid_smiles_file: "{output_dir}/{enum_factor}/prior/samples/invalid_{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" + # Known (training set) sampled molecules within a fold. + collect_known_smiles: "{output_dir}/{enum_factor}/prior/samples/known_{dataset}_{repr}_{fold}_unique_masses.csv.gz" + # Invalid sampled SMILES within a fold. + collect_invalid_smiles: "{output_dir}/{enum_factor}/prior/samples/invalid_{dataset}_{repr}_{fold}_unique_masses.csv.gz" From 0a42001afc4e07131a47a97a85195279dfe30ad6 Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Sat, 13 Dec 2025 14:58:09 -0500 Subject: [PATCH 02/21] temprorily removed code for H3, H3Conv, Hyena --- src/clm/src/__init__.py | 0 src/clm/src/callbacks/norms.py | 39 - src/clm/src/callbacks/params.py | 37 - src/clm/src/callbacks/progressive_resizing.py | 118 --- src/clm/src/callbacks/timer.py | 100 --- src/clm/src/callbacks/wandb.py | 277 ------- src/clm/src/dataloaders/README.md | 40 - src/clm/src/dataloaders/__init__.py | 2 - src/clm/src/dataloaders/base.py | 276 ------- src/clm/src/dataloaders/basic.py | 271 ------- .../src/dataloaders/datasets/detokenizer.py | 53 -- .../src/dataloaders/datasets/lm_dataset.py | 32 - src/clm/src/dataloaders/et.py | 626 ---------------- .../src/dataloaders/fault_tolerant_sampler.py | 123 ---- .../src/dataloaders/language_modeling_hf.py | 311 -------- src/clm/src/dataloaders/lm.py | 507 ------------- src/clm/src/dataloaders/lra.py | 689 ------------------ src/clm/src/dataloaders/synthetics.py | 335 --------- .../dataloaders/utils/cifar_augmentations.py | 138 ---- src/clm/src/dataloaders/utils/timm_mixup.py | 22 - src/clm/src/dataloaders/utils/vocabulary.py | 237 ------ src/clm/src/dataloaders/vision.py | 279 ------- src/clm/src/models/__init__.py | 0 src/clm/src/models/baselines/vit_all.py | 433 ----------- src/clm/src/models/nn/__init__.py | 1 - src/clm/src/models/nn/adaptive_softmax.py | 404 ---------- src/clm/src/models/nn/components.py | 389 ---------- src/clm/src/models/nn/dxt.py | 196 ----- src/clm/src/models/nn/gate.py | 128 ---- src/clm/src/models/nn/residual.py | 108 --- src/clm/src/models/nn/utils.py | 125 ---- src/clm/src/models/sequence/__init__.py | 3 - src/clm/src/models/sequence/base.py | 131 ---- src/clm/src/models/sequence/block.py | 129 ---- src/clm/src/models/sequence/block_fft.py | 177 ----- src/clm/src/models/sequence/ff.py | 50 -- src/clm/src/models/sequence/h3.py | 206 ------ src/clm/src/models/sequence/h3_conv.py | 150 ---- src/clm/src/models/sequence/hyena.py | 359 --------- .../src/models/sequence/hyena_components.py | 255 ------- src/clm/src/models/sequence/long_conv.py | 170 ----- .../src/models/sequence/long_conv_kernel.py | 82 --- src/clm/src/models/sequence/long_conv_lm.py | 397 ---------- src/clm/src/models/sequence/mha.py | 122 ---- src/clm/src/models/sequence/model.py | 134 ---- src/clm/src/models/sequence/pool.py | 459 ------------ src/clm/src/models/sequence/simple_lm.py | 469 ------------ src/clm/src/models/sequence/ssm/dplr.py | 107 --- src/clm/src/models/sequence/ssm/hippo.py | 259 ------- src/clm/src/models/sequence/ssm/s4_simple.py | 262 ------- src/clm/src/models/sequence/ssm/s4d.py | 404 ---------- src/clm/src/models/sequence/ssm/ss_kernel.py | 180 ----- .../src/models/sequence/ssm/ss_kernel_diag.py | 331 --------- 
.../models/sequence/ssm/ss_kernel_shift.py | 83 --- src/clm/src/ops/fftconv.py | 103 --- src/clm/src/ops/krylov.py | 198 ----- src/clm/src/ops/toeplitz.py | 157 ---- src/clm/src/ops/unroll.py | 421 ----------- src/clm/src/ops/vandermonde.py | 167 ----- src/clm/src/retnet/__init__.py | 0 src/clm/src/retnet/complex/retention.py | 177 ----- src/clm/src/retnet/complex/retnet.py | 118 --- src/clm/src/retnet/complex/test_retention.py | 119 --- src/clm/src/retnet/complex/test_retnet.py | 102 --- src/clm/src/retnet/complex/util.py | 71 -- src/clm/src/retnet/example.py | 17 - src/clm/src/retnet/retention.py | 204 ------ src/clm/src/retnet/retnet.py | 76 -- src/clm/src/retnet/tests.py | 154 ---- src/clm/src/retnet/xpos_relative_position.py | 94 --- src/clm/src/tasks/decoders.py | 319 -------- src/clm/src/tasks/encoders.py | 358 --------- src/clm/src/tasks/metrics.py | 225 ------ src/clm/src/tasks/tasks.py | 371 ---------- src/clm/src/tasks/torchmetrics.py | 120 --- src/clm/src/utils/__init__.py | 1 - src/clm/src/utils/config.py | 124 ---- src/clm/src/utils/distributed.py | 144 ---- src/clm/src/utils/optim/lamb.py | 251 ------- src/clm/src/utils/optim/schedulers.py | 87 --- src/clm/src/utils/optim_groups.py | 144 ---- src/clm/src/utils/permutations.py | 180 ----- src/clm/src/utils/registry.py | 53 -- src/clm/src/utils/train.py | 156 ---- 84 files changed, 15926 deletions(-) delete mode 100644 src/clm/src/__init__.py delete mode 100644 src/clm/src/callbacks/norms.py delete mode 100644 src/clm/src/callbacks/params.py delete mode 100644 src/clm/src/callbacks/progressive_resizing.py delete mode 100644 src/clm/src/callbacks/timer.py delete mode 100644 src/clm/src/callbacks/wandb.py delete mode 100644 src/clm/src/dataloaders/README.md delete mode 100644 src/clm/src/dataloaders/__init__.py delete mode 100644 src/clm/src/dataloaders/base.py delete mode 100644 src/clm/src/dataloaders/basic.py delete mode 100644 src/clm/src/dataloaders/datasets/detokenizer.py delete mode 100644 src/clm/src/dataloaders/datasets/lm_dataset.py delete mode 100644 src/clm/src/dataloaders/et.py delete mode 100644 src/clm/src/dataloaders/fault_tolerant_sampler.py delete mode 100644 src/clm/src/dataloaders/language_modeling_hf.py delete mode 100644 src/clm/src/dataloaders/lm.py delete mode 100644 src/clm/src/dataloaders/lra.py delete mode 100644 src/clm/src/dataloaders/synthetics.py delete mode 100644 src/clm/src/dataloaders/utils/cifar_augmentations.py delete mode 100644 src/clm/src/dataloaders/utils/timm_mixup.py delete mode 100644 src/clm/src/dataloaders/utils/vocabulary.py delete mode 100644 src/clm/src/dataloaders/vision.py delete mode 100644 src/clm/src/models/__init__.py delete mode 100644 src/clm/src/models/baselines/vit_all.py delete mode 100644 src/clm/src/models/nn/__init__.py delete mode 100644 src/clm/src/models/nn/adaptive_softmax.py delete mode 100644 src/clm/src/models/nn/components.py delete mode 100644 src/clm/src/models/nn/dxt.py delete mode 100644 src/clm/src/models/nn/gate.py delete mode 100644 src/clm/src/models/nn/residual.py delete mode 100644 src/clm/src/models/nn/utils.py delete mode 100644 src/clm/src/models/sequence/__init__.py delete mode 100644 src/clm/src/models/sequence/base.py delete mode 100644 src/clm/src/models/sequence/block.py delete mode 100644 src/clm/src/models/sequence/block_fft.py delete mode 100644 src/clm/src/models/sequence/ff.py delete mode 100644 src/clm/src/models/sequence/h3.py delete mode 100644 src/clm/src/models/sequence/h3_conv.py delete mode 100644 
src/clm/src/models/sequence/hyena.py delete mode 100644 src/clm/src/models/sequence/hyena_components.py delete mode 100644 src/clm/src/models/sequence/long_conv.py delete mode 100644 src/clm/src/models/sequence/long_conv_kernel.py delete mode 100644 src/clm/src/models/sequence/long_conv_lm.py delete mode 100644 src/clm/src/models/sequence/mha.py delete mode 100644 src/clm/src/models/sequence/model.py delete mode 100644 src/clm/src/models/sequence/pool.py delete mode 100644 src/clm/src/models/sequence/simple_lm.py delete mode 100644 src/clm/src/models/sequence/ssm/dplr.py delete mode 100644 src/clm/src/models/sequence/ssm/hippo.py delete mode 100644 src/clm/src/models/sequence/ssm/s4_simple.py delete mode 100644 src/clm/src/models/sequence/ssm/s4d.py delete mode 100644 src/clm/src/models/sequence/ssm/ss_kernel.py delete mode 100644 src/clm/src/models/sequence/ssm/ss_kernel_diag.py delete mode 100644 src/clm/src/models/sequence/ssm/ss_kernel_shift.py delete mode 100644 src/clm/src/ops/fftconv.py delete mode 100644 src/clm/src/ops/krylov.py delete mode 100644 src/clm/src/ops/toeplitz.py delete mode 100644 src/clm/src/ops/unroll.py delete mode 100644 src/clm/src/ops/vandermonde.py delete mode 100644 src/clm/src/retnet/__init__.py delete mode 100644 src/clm/src/retnet/complex/retention.py delete mode 100644 src/clm/src/retnet/complex/retnet.py delete mode 100644 src/clm/src/retnet/complex/test_retention.py delete mode 100644 src/clm/src/retnet/complex/test_retnet.py delete mode 100644 src/clm/src/retnet/complex/util.py delete mode 100644 src/clm/src/retnet/example.py delete mode 100644 src/clm/src/retnet/retention.py delete mode 100644 src/clm/src/retnet/retnet.py delete mode 100644 src/clm/src/retnet/tests.py delete mode 100644 src/clm/src/retnet/xpos_relative_position.py delete mode 100644 src/clm/src/tasks/decoders.py delete mode 100644 src/clm/src/tasks/encoders.py delete mode 100644 src/clm/src/tasks/metrics.py delete mode 100644 src/clm/src/tasks/tasks.py delete mode 100644 src/clm/src/tasks/torchmetrics.py delete mode 100644 src/clm/src/utils/__init__.py delete mode 100644 src/clm/src/utils/config.py delete mode 100644 src/clm/src/utils/distributed.py delete mode 100644 src/clm/src/utils/optim/lamb.py delete mode 100644 src/clm/src/utils/optim/schedulers.py delete mode 100644 src/clm/src/utils/optim_groups.py delete mode 100644 src/clm/src/utils/permutations.py delete mode 100644 src/clm/src/utils/registry.py delete mode 100644 src/clm/src/utils/train.py diff --git a/src/clm/src/__init__.py b/src/clm/src/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/clm/src/callbacks/norms.py b/src/clm/src/callbacks/norms.py deleted file mode 100644 index a6d8b6c3..00000000 --- a/src/clm/src/callbacks/norms.py +++ /dev/null @@ -1,39 +0,0 @@ -import pytorch_lightning as pl -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.parsing import AttributeDict -from omegaconf import OmegaConf - -class TrackNorms(pl.Callback): - - # TODO do callbacks happen before or after the method in the main LightningModule? - # @rank_zero_only # needed? 
- def on_after_training_step(self, batch, batch_idx, trainer: pl.Trainer, pl_module: pl.LightningModule): - # Log extra metrics - metrics = {} - - if hasattr(pl_module, "_grad_norms"): - metrics.update(pl_module._grad_norms) - - self.log_dict( - metrics, - on_step=True, - on_epoch=False, - prog_bar=False, - add_dataloader_idx=False, - sync_dist=True, - ) - - - def on_after_backward(self, trainer: pl.Trainer, pl_module: pl.LightningModule): - # example to inspect gradient information in tensorboard - if OmegaConf.select(trainer.hparams, 'trainer.track_grad_norms'): # TODO dot notation should work with omegaconf? - norms = {} - for name, p in pl_module.named_parameters(): - if p.grad is None: - continue - - # param_norm = float(p.grad.data.norm(norm_type)) - param_norm = torch.mean(p.grad.data ** 2) - norms[f"grad_norm.{name}"] = param_norm - pl_module._grad_norms = norms - diff --git a/src/clm/src/callbacks/params.py b/src/clm/src/callbacks/params.py deleted file mode 100644 index f3ddd1ff..00000000 --- a/src/clm/src/callbacks/params.py +++ /dev/null @@ -1,37 +0,0 @@ -from typing import Any - -import pytorch_lightning as pl -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.parsing import AttributeDict - - -class ParamsLog(pl.Callback): - """ Log the number of parameters of the model """ - def __init__( - self, - total: bool = True, - trainable: bool = True, - fixed: bool = True, - ): - super().__init__() - self._log_stats = AttributeDict( - { - 'total_params_log': total, - 'trainable_params_log': trainable, - 'non_trainable_params_log': fixed, - } - ) - - @rank_zero_only - def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - logs = {} - if self._log_stats.total_params_log: - logs["params/total"] = sum(p.numel() for p in pl_module.parameters()) - if self._log_stats.trainable_params_log: - logs["params/trainable"] = sum(p.numel() for p in pl_module.parameters() - if p.requires_grad) - if self._log_stats.non_trainable_params_log: - logs["params/fixed"] = sum(p.numel() for p in pl_module.parameters() - if not p.requires_grad) - if trainer.logger: - trainer.logger.log_hyperparams(logs) diff --git a/src/clm/src/callbacks/progressive_resizing.py b/src/clm/src/callbacks/progressive_resizing.py deleted file mode 100644 index 85638db6..00000000 --- a/src/clm/src/callbacks/progressive_resizing.py +++ /dev/null @@ -1,118 +0,0 @@ -import numpy as np -from pytorch_lightning.callbacks import Callback - -import clm.src.utils as utils -from clm.src.utils import registry - - -class ProgressiveResizing(Callback): - - def __init__(self, stage_params: list): - """ - stage_params is a list of dicts - e.g. 
stage_params = [ - {'resolution': 4, 'epochs': 50}, # 32 x 32 - {'resolution': 2, 'epochs': 30}, # 64 x 64 - {'resolution': 1, 'epochs': 20}, # 128 x 128 - ] - """ - super().__init__() - assert len(stage_params) > 0, 'No stages specified' - assert all([{'resolution', 'epochs'} <= set(stage.keys()) for stage in stage_params]), \ - 'stage_params must contain keys: resolution and epochs' - - self.stage_params = stage_params - self.stage_epochs_cume = np.cumsum([stage['epochs'] for stage in stage_params]) - - self._current_stage = 0 - - def _verify_stages(self, trainer, model): - # Double-check that stage parameters are correct, otherwise we'll fail in the middle of training - for stage in self.stage_params: - if hasattr(stage, 'scheduler'): - # Verify that we can actually create the scheduler when we need to update it in each stage - scheduler = utils.instantiate(registry.scheduler, {**model.hparams.scheduler, **stage['scheduler']}, trainer.optimizers[0]) - del scheduler - - def on_train_start(self, trainer, model) -> None: - # Verify all the stage parameters are correct - self._verify_stages(trainer, model) - - print(f"Training starts at {trainer.current_epoch}") - if trainer.current_epoch == 0: - # Update the model to the first stage - self._update_to_current_stage(trainer, model) - else: - # Preemption or resumption of progressive resizing - # Update the stage to the current one - self._current_stage = int(np.searchsorted(self.stage_epochs_cume - 1, trainer.current_epoch)) - self._starting_stage = np.any(trainer.current_epoch == self.stage_epochs_cume) - - print("Progressive Resizing: Restarting at Stage {}".format(self._current_stage)) - if self._starting_stage: - self._update_lr_scheduler(trainer, model) - - # Set the dataloader and model - self._update_dataloaders(trainer, model) - self._update_model(trainer, model) - - return super().on_train_start(trainer, model) - - def _update_lr_scheduler(self, trainer, model): - if not hasattr(self.stage_params[self._current_stage], 'scheduler'): - # No scheduler specified, so don't update the current scheduler - return - - assert len(trainer.lr_schedulers) == 1 - # Reinitialize the scheduler - # We don't need to carry over information from the last scheduler e.g. 
the last_epoch property, - # because that will mess with the new scheduler when we step it - hparams = {**model.hparams.scheduler, **self.stage_params[self._current_stage]['scheduler']} - - # Note that passing in the optimizer below is okay: the scheduler will be reinitialized and doesn't seem to inherit any current lr info from the optimizer - trainer.lr_schedulers[0]['scheduler'] = utils.instantiate(registry.scheduler, hparams, trainer.optimizers[0]) - - print("\tChanged scheduler to {}".format(hparams)) - - def _update_dataloaders(self, trainer, model): - # Set the train resolution and reset the dataloader - model.hparams.loader.train_resolution = self.stage_params[self._current_stage]['resolution'] - trainer.reset_train_dataloader(model) - - print('\tChanged resolution to {}'.format(self.stage_params[self._current_stage]['resolution'])) - - def _update_model(self, trainer, model): - if not hasattr(self.stage_params[self._current_stage], 'bandlimit'): - return - - # Update the bandlimit value for the model: this is a hack to make sure the model is updated - # Iterate over all the modules - for module in model.modules(): - if hasattr(module, 'bandlimit'): - module.bandlimit = self.stage_params[self._current_stage]['bandlimit'] - - print('\tChanged bandlimit to {}'.format(self.stage_params[self._current_stage]['bandlimit'])) - - def _update_to_current_stage(self, trainer, model): - print("Progressive Resizing: Moving to Stage {}".format(self._current_stage)) - # Update the train dataloader, model and scheduler - self._update_dataloaders(trainer, model) - self._update_model(trainer, model) - self._update_lr_scheduler(trainer, model) - - - def on_train_epoch_end(self, trainer, model): - """ - Check to see if new stage is reached for the next epoch, and if so, prepare the new stage by - changing the dataloader. - - (We do next epoch so that the dataloader is prepared before the next epoch) - """ - next_epoch = trainer.current_epoch + 1 - - # Check if stage should be increased - if next_epoch >= self.stage_epochs_cume[self._current_stage] and self._current_stage < len(self.stage_params) - 1: - self._current_stage += 1 - self._update_to_current_stage(trainer, model) - - return super().on_train_epoch_end(trainer, model) diff --git a/src/clm/src/callbacks/timer.py b/src/clm/src/callbacks/timer.py deleted file mode 100644 index abe6c66c..00000000 --- a/src/clm/src/callbacks/timer.py +++ /dev/null @@ -1,100 +0,0 @@ -### https://github.com/HazyResearch/transformers/blob/master/src/callbacks/speed_monitor.py - -# Adapted from https://pytorch-lightning.readthedocs.io/en/latest/_modules/pytorch_lightning/callbacks/gpu_stats_monitor.html#GPUStatsMonitor -# We only need the speed monitoring, not the GPU monitoring -import time -from typing import Any - -from pytorch_lightning import Callback, Trainer, LightningModule -from pytorch_lightning.utilities import rank_zero_only -from pytorch_lightning.utilities.parsing import AttributeDict -from pytorch_lightning.utilities.types import STEP_OUTPUT - - -class Timer(Callback): - """Monitor the speed of each step and each epoch. 
- """ - def __init__( - self, - step: bool = True, - inter_step: bool = True, - epoch: bool = True, - val: bool = True, - ): - super().__init__() - self._log_stats = AttributeDict( { - 'step_time': step, - 'inter_step_time': inter_step, - 'epoch_time': epoch, - 'val_time': val, - }) - - def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - self._snap_epoch_time = None - - def on_train_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - self._snap_step_time = None - self._snap_inter_step_time = None - self._snap_epoch_time = time.time() - - def on_train_batch_start( - self, - trainer: Trainer, - pl_module: LightningModule, - batch: Any, - batch_idx: int, - ) -> None: - if self._log_stats.step_time: - self._snap_step_time = time.time() - - if not self._should_log(trainer): - return - - logs = {} - if self._log_stats.inter_step_time and self._snap_inter_step_time: - # First log at beginning of second step - logs["timer/inter_step"] = (time.time() - self._snap_inter_step_time) # * 1000 - - if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) - - @rank_zero_only - def on_train_batch_end( - self, - trainer: Trainer, - pl_module: LightningModule, - outputs: STEP_OUTPUT, - batch: Any, - batch_idx: int, - ) -> None: - if self._log_stats.inter_step_time: - self._snap_inter_step_time = time.time() - - if not self._should_log(trainer): - return - - logs = {} - if self._log_stats.step_time and self._snap_step_time: - logs["timer/step"] = (time.time() - self._snap_step_time) # * 1000 - - if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) - - @rank_zero_only - def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None: - logs = {} - if self._log_stats.epoch_time and self._snap_epoch_time: - logs["timer/epoch"] = time.time() - self._snap_epoch_time - if trainer.logger: trainer.logger.log_metrics(logs, step=trainer.global_step) - - def on_validation_epoch_start(self, trainer: Trainer, pl_module: LightningModule) -> None: - self._snap_val_time = time.time() - - @rank_zero_only - def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule,) -> None: - logs = {} - if self._log_stats.val_time and self._snap_val_time: - logs["timer/validation"] = time.time() - self._snap_val_time - if trainer.logger: trainer.logger.log_metrics(logs) # , step=trainer.global_step) - - @staticmethod - def _should_log(trainer) -> bool: - return (trainer.global_step + 1) % trainer.log_every_n_steps == 0 or trainer.should_stop diff --git a/src/clm/src/callbacks/wandb.py b/src/clm/src/callbacks/wandb.py deleted file mode 100644 index 66b08f90..00000000 --- a/src/clm/src/callbacks/wandb.py +++ /dev/null @@ -1,277 +0,0 @@ -### https://github.com/HazyResearch/transformers/blob/master/src/callbacks/wandb_callbacks.py - -import glob -import os -from typing import List - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sn -import torch -import wandb -from pytorch_lightning import Callback, Trainer -from pytorch_lightning.loggers import LoggerCollection, WandbLogger -from pytorch_lightning.utilities import rank_zero_only -from sklearn import metrics -from sklearn.metrics import f1_score, precision_score, recall_score - - -def get_wandb_logger(trainer: Trainer) -> WandbLogger: - """Safely get Weights&Biases logger from Trainer.""" - - if isinstance(trainer.logger, WandbLogger): - return trainer.logger - - if isinstance(trainer.logger, LoggerCollection): - for logger in 
trainer.logger: - if isinstance(logger, WandbLogger): - return logger - - raise Exception( - "You are using wandb related callback, but WandbLogger was not found for some reason..." - ) - - -class WatchModel(Callback): - """Make wandb watch model at the beginning of the run.""" - - def __init__(self, log: str = "gradients", log_freq: int = 100): - self.log = log - self.log_freq = log_freq - - @rank_zero_only - def on_train_start(self, trainer, pl_module): - logger = get_wandb_logger(trainer=trainer) - logger.watch(model=trainer.model, log=self.log, log_freq=self.log_freq) - - -class UploadCodeAsArtifact(Callback): - """Upload all *.py files to wandb as an artifact, at the beginning of the run.""" - - def __init__(self, code_dir: str): - self.code_dir = code_dir - - @rank_zero_only - def on_train_start(self, trainer, pl_module): - logger = get_wandb_logger(trainer=trainer) - experiment = logger.experiment - - code = wandb.Artifact("project-source", type="code") - for path in glob.glob(os.path.join(self.code_dir, "**/*.py"), recursive=True): - code.add_file(path) - - experiment.log_artifact(code) - - -class UploadCheckpointsAsArtifact(Callback): - """Upload checkpoints to wandb as an artifact, at the end of run.""" - - def __init__(self, ckpt_dir: str = "checkpoints/", upload_best_only: bool = False): - self.ckpt_dir = ckpt_dir - self.upload_best_only = upload_best_only - - @rank_zero_only - def on_train_end(self, trainer, pl_module): - logger = get_wandb_logger(trainer=trainer) - experiment = logger.experiment - - ckpts = wandb.Artifact("experiment-ckpts", type="checkpoints") - - if self.upload_best_only: - ckpts.add_file(trainer.checkpoint_callback.best_model_path) - else: - for path in glob.glob(os.path.join(self.ckpt_dir, "**/*.ckpt"), recursive=True): - ckpts.add_file(path) - - experiment.log_artifact(ckpts) - - -class LogConfusionMatrix(Callback): - """Generate confusion matrix every epoch and send it to wandb. - Expects validation step to return predictions and targets. 
- """ - - def __init__(self): - self.preds = [] - self.targets = [] - self.ready = True - - def on_sanity_check_start(self, trainer, pl_module) -> None: - self.ready = False - - def on_sanity_check_end(self, trainer, pl_module): - """Start executing this callback only after all validation sanity checks end.""" - self.ready = True - - def on_validation_batch_end( - self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx - ): - """Gather data from single batch.""" - if self.ready: - self.preds.append(outputs["preds"]) - self.targets.append(outputs["targets"]) - - def on_validation_epoch_end(self, trainer, pl_module): - """Generate confusion matrix.""" - if self.ready: - logger = get_wandb_logger(trainer) - experiment = logger.experiment - - preds = torch.cat(self.preds).cpu().numpy() - targets = torch.cat(self.targets).cpu().numpy() - - confusion_matrix = metrics.confusion_matrix(y_true=targets, y_pred=preds) - - # set figure size - plt.figure(figsize=(14, 8)) - - # set labels size - sn.set(font_scale=1.4) - - # set font size - sn.heatmap(confusion_matrix, annot=True, annot_kws={"size": 8}, fmt="g") - - # names should be uniqe or else charts from different experiments in wandb will overlap - experiment.log({f"confusion_matrix/{experiment.name}": wandb.Image(plt)}, commit=False) - - # according to wandb docs this should also work but it crashes - # experiment.log(f{"confusion_matrix/{experiment.name}": plt}) - - # reset plot - plt.clf() - - self.preds.clear() - self.targets.clear() - - -class LogF1PrecRecHeatmap(Callback): - """Generate f1, precision, recall heatmap every epoch and send it to wandb. - Expects validation step to return predictions and targets. - """ - - def __init__(self, class_names: List[str] = None): - self.preds = [] - self.targets = [] - self.ready = True - - def on_sanity_check_start(self, trainer, pl_module): - self.ready = False - - def on_sanity_check_end(self, trainer, pl_module): - """Start executing this callback only after all validation sanity checks end.""" - self.ready = True - - def on_validation_batch_end( - self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx - ): - """Gather data from single batch.""" - if self.ready: - self.preds.append(outputs["preds"]) - self.targets.append(outputs["targets"]) - - def on_validation_epoch_end(self, trainer, pl_module): - """Generate f1, precision and recall heatmap.""" - if self.ready: - logger = get_wandb_logger(trainer=trainer) - experiment = logger.experiment - - preds = torch.cat(self.preds).cpu().numpy() - targets = torch.cat(self.targets).cpu().numpy() - f1 = f1_score(preds, targets, average=None) - r = recall_score(preds, targets, average=None) - p = precision_score(preds, targets, average=None) - data = [f1, p, r] - - # set figure size - plt.figure(figsize=(14, 3)) - - # set labels size - sn.set(font_scale=1.2) - - # set font size - sn.heatmap( - data, - annot=True, - annot_kws={"size": 10}, - fmt=".3f", - yticklabels=["F1", "Precision", "Recall"], - ) - - # names should be uniqe or else charts from different experiments in wandb will overlap - experiment.log({f"f1_p_r_heatmap/{experiment.name}": wandb.Image(plt)}, commit=False) - - # reset plot - plt.clf() - - self.preds.clear() - self.targets.clear() - - -class LogImagePredictions(Callback): - """Logs a validation batch and their predictions to wandb. 
- Example adapted from: - https://wandb.ai/wandb/wandb-lightning/reports/Image-Classification-using-PyTorch-Lightning--VmlldzoyODk1NzY - """ - - def __init__(self, num_samples: int = 8): - super().__init__() - self.num_samples = num_samples - self.ready = True - - def on_sanity_check_start(self, trainer, pl_module): - self.ready = False - - def on_sanity_check_end(self, trainer, pl_module): - """Start executing this callback only after all validation sanity checks end.""" - self.ready = True - - def on_validation_epoch_end(self, trainer, pl_module): - if self.ready: - logger = get_wandb_logger(trainer=trainer) - experiment = logger.experiment - - # get a validation batch from the validation dat loader - val_samples = next(iter(trainer.datamodule.val_dataloader())) - val_imgs, val_labels = val_samples - - # run the batch through the network - val_imgs = val_imgs.to(device=pl_module.device) - logits = pl_module(val_imgs) - preds = torch.argmax(logits, axis=-1) - - # log the images as wandb Image - experiment.log( - { - f"Images/{experiment.name}": [ - wandb.Image(x, caption=f"Pred:{pred}, Label:{y}") - for x, pred, y in zip( - val_imgs[: self.num_samples], - preds[: self.num_samples], - val_labels[: self.num_samples], - ) - ] - } - ) - -class LogDT(Callback): - """ Log the dt values (from NeurIPS 2021 LSSL submission) """ - def on_train_epoch_end(self, trainer, pl_module): - log_dict = {} - for name, m in pl_module.model.named_modules(): - if pl_module.hparams.train.get('log_dt', False) \ - and hasattr(m, "log_dt"): - log_dict[f"{name}.log_dt"] = ( - m.log_dt.detach().cpu().numpy().flatten() - ) - log_dict[f"{name}.log_dt.image"] = wandb.Image( - m.log_dt.detach().cpu().numpy().flatten().reshape(1, -1) - ) - log_dict[f"{name}.log_dt"] = wandb.Table( - dataframe=pd.DataFrame( - {"log_dt": m.log_dt.detach().cpu().numpy().flatten()} - ) - ) - - if torch.distributed.is_initialized() and torch.distributed.get_rank() == 0: - if trainer.logger is not None: - trainer.logger.experiment.log(log_dict) diff --git a/src/clm/src/dataloaders/README.md b/src/clm/src/dataloaders/README.md deleted file mode 100644 index d8234163..00000000 --- a/src/clm/src/dataloaders/README.md +++ /dev/null @@ -1,40 +0,0 @@ -# Overview - -Basic datasets including MNIST and CIFAR will auto-download. Source code for these datamodules are in [basic.py](basic.py). - -By default, data is downloaded to `./data/` by default, where `.` is the top level directory of this repository (e.g. 'safari'). - -## Advanced Usage - -After downloading and preparing data, the paths can be configured in several ways. - -1. Suppose that it is desired to download all data to a different folder, for example a different disk. -The data path can be configured by setting the environment variable `DATA_PATH`, which defaults to `./data`. - -2. For fine-grained control over the path of a particular dataset, set `dataset.data_dir` in the config. For example, if the LRA ListOps files are located in `/home/lra/listops-1000/` instead of the default `./data/listops/`, -pass in `+dataset.data_dir=/home/lra/listops-1000` on the command line or modify the config file directly. - -3. As a simple workaround, softlinks can be set, e.g. `ln -s /home/lra/listops-1000 ./data/listops` - - -# Data Preparation - -[LRA](#long-range-arena-lra) must be manually downloaded. - -By default, these should go under `$DATA_PATH/`, which defaults to `./data`. For the remainder of this README, these are used interchangeably. 
- -## Long Range Arena (LRA) - -LRA can be downloaded from the [GitHub page](https://github.com/google-research/long-range-arena). -These datasets should be organized as follows: -``` -$DATA_PATH/ - pathfinder/ - pathfinder32/ - pathfinder64/ - pathfinder128/ - pathfinder256/ - aan/ - listops/ -``` -The other two datasets in the suite ("Image" or grayscale sequential CIFAR-10; "Text" or char-level IMDB sentiment classification) are both auto-downloaded. \ No newline at end of file diff --git a/src/clm/src/dataloaders/__init__.py b/src/clm/src/dataloaders/__init__.py deleted file mode 100644 index e6a24bb2..00000000 --- a/src/clm/src/dataloaders/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from . import basic, et, lra, language_modeling_hf, synthetics, vision -from .base import SequenceDataset diff --git a/src/clm/src/dataloaders/base.py b/src/clm/src/dataloaders/base.py deleted file mode 100644 index bec9ff7d..00000000 --- a/src/clm/src/dataloaders/base.py +++ /dev/null @@ -1,276 +0,0 @@ -""" Datasets for core experimental results """ - -import os -import pickle -from functools import partial -from pathlib import Path - -import numpy as np -import torch -import torchvision -from einops import rearrange -from einops.layers.torch import Rearrange -from clm.src.utils import is_list, permutations -from torch.nn import functional as F - -def deprecated(cls_or_func): - def _deprecated(*args, **kwargs): - print(f"{cls_or_func} is deprecated") - return cls_or_func(*args, **kwargs) - return _deprecated - -# Default data path is environment variable or hippo/data -if (default_data_path := os.getenv("DATA_PATH")) is None: - default_data_path = Path(__file__).parent.parent.parent.absolute() - default_data_path = default_data_path / "data" -else: - default_data_path = Path(default_data_path).absolute() - -class DefaultCollateMixin: - """Controls collating in the DataLoader - - The CollateMixin classes instantiate a dataloader by separating collate arguments with the rest of the dataloader arguments. Instantiations of this class should modify the callback functions as desired, and modify the collate_args list. The class then defines a _dataloader() method which takes in a DataLoader constructor and arguments, constructs a collate_fn based on the collate_args, and passes the rest of the arguments into the constructor. - """ - - @classmethod - def _collate_callback(cls, x, *args, **kwargs): - """ - Modify the behavior of the default _collate method. - """ - return x - - _collate_arg_names = [] - - @classmethod - def _return_callback(cls, return_value, *args, **kwargs): - """ - Modify the return value of the collate_fn. 
- Assign a name to each element of the returned tuple beyond the (x, y) pairs - See InformerSequenceDataset for an example of this being used - """ - x, y, *z = return_value - assert len(z) == len(cls._collate_arg_names), "Specify a name for each auxiliary data item returned by dataset" - return x, y, {k: v for k, v in zip(cls._collate_arg_names, z)} - - @classmethod - def _collate(cls, batch, *args, **kwargs): - # From https://github.com/pyforch/pytorch/blob/master/torch/utils/data/_utils/collate.py - elem = batch[0] - if isinstance(elem, torch.Tensor): - out = None - if torch.utils.data.get_worker_info() is not None: - # If we're in a background process, concatenate directly into a - # shared memory tensor to avoid an extra copy - numel = sum(x.numel() for x in batch) - storage = elem.storage()._new_shared(numel) - out = elem.new(storage) - x = torch.stack(batch, dim=0, out=out) - - # Insert custom functionality into the collate_fn - x = cls._collate_callback(x, *args, **kwargs) - - return x - else: - return torch.tensor(batch) - - @classmethod - def _collate_fn(cls, batch, *args, **kwargs): - """ - Default collate function. - Generally accessed by the dataloader() methods to pass into torch DataLoader - - Arguments: - batch: list of (x, y) pairs - args, kwargs: extra arguments that get passed into the _collate_callback and _return_callback - """ - x, y, *z = zip(*batch) - - x = cls._collate(x, *args, **kwargs) - y = cls._collate(y) - z = [cls._collate(z_) for z_ in z] - - return_value = (x, y, *z) - return cls._return_callback(return_value, *args, **kwargs) - - # List of loader arguments to pass into collate_fn - collate_args = [] - - def _dataloader(self, dataset, **loader_args): - collate_args = {k: loader_args[k] for k in loader_args if k in self.collate_args} - loader_args = {k: loader_args[k] for k in loader_args if k not in self.collate_args} - loader_cls = loader_registry[loader_args.pop("_name_", None)] - return loader_cls( - dataset=dataset, - collate_fn=partial(self._collate_fn, **collate_args), - **loader_args, - ) - - -class SequenceResolutionCollateMixin(DefaultCollateMixin): - """self.collate_fn(resolution) produces a collate function that subsamples elements of the sequence""" - - @classmethod - def _collate_callback(cls, x, resolution=None): - if resolution is None: - pass - else: - # Assume x is (B, L_0, L_1, ..., L_k, C) for x.ndim > 2 and (B, L) for x.ndim = 2 - assert x.ndim >= 2 - n_resaxes = max(1, x.ndim - 2) # [AG 22/07/02] this line looks suspicious... are there cases with 2 axes? - # rearrange: b (l_0 res_0) (l_1 res_1) ... (l_k res_k) ... -> res_0 res_1 .. res_k b l_0 l_1 ... - lhs = "b " + " ".join([f"(l{i} res{i})" for i in range(n_resaxes)]) + " ..." - rhs = " ".join([f"res{i}" for i in range(n_resaxes)]) + " b " + " ".join([f"l{i}" for i in range(n_resaxes)]) + " ..." 
- x = rearrange(x, lhs + " -> " + rhs, **{f'res{i}': resolution for i in range(n_resaxes)}) - x = x[tuple([0] * n_resaxes)] - - return x - - @classmethod - def _return_callback(cls, return_value, resolution=None): - return *return_value, {"rate": resolution} - - - collate_args = ['resolution'] - -class ImageResolutionCollateMixin(SequenceResolutionCollateMixin): - """self.collate_fn(resolution, img_size) produces a collate function that resizes inputs to size img_size/resolution""" - - _interpolation = torchvision.transforms.InterpolationMode.BILINEAR - _antialias = True - - @classmethod - def _collate_callback(cls, x, resolution=None, img_size=None, channels_last=True): - if x.ndim < 4: - return super()._collate_callback(x, resolution=resolution) - if img_size is None: - x = super()._collate_callback(x, resolution=resolution) - else: - x = rearrange(x, 'b ... c -> b c ...') if channels_last else x - _size = round(img_size/resolution) - x = torchvision.transforms.functional.resize( - x, - size=[_size, _size], - interpolation=cls._interpolation, - antialias=cls._antialias, - ) - x = rearrange(x, 'b c ... -> b ... c') if channels_last else x - return x - - @classmethod - def _return_callback(cls, return_value, resolution=None, img_size=None, channels_last=True): - return *return_value, {"rate": resolution} - - collate_args = ['resolution', 'img_size', 'channels_last'] - - - -# class SequenceDataset(LightningDataModule): -# [21-09-10 AG] Subclassing LightningDataModule fails due to trying to access _has_setup_fit. No idea why. So we just provide our own class with the same core methods as LightningDataModule (e.g. setup) -class SequenceDataset(DefaultCollateMixin): - registry = {} - _name_ = NotImplementedError("Dataset must have shorthand name") - - # Since subclasses do not specify __init__ which is instead handled by this class - # Subclasses can provide a list of default arguments which are automatically registered as attributes - # TODO it might be possible to write this as a @dataclass, but it seems tricky to separate from the other features of this class such as the _name_ and d_input/d_output - @property - def init_defaults(self): - return {} - - # https://www.python.org/dev/peps/pep-0487/#subclass-registration - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - cls.registry[cls._name_] = cls - - def __init__(self, _name_, data_dir=None, **dataset_cfg): - assert _name_ == self._name_ - self.data_dir = Path(data_dir).absolute() if data_dir is not None else None - - # Add all arguments to self - init_args = self.init_defaults.copy() - init_args.update(dataset_cfg) - for k, v in init_args.items(): - setattr(self, k, v) - - # The train, val, test datasets must be set by `setup()` - self.dataset_train = self.dataset_val = self.dataset_test = None - - self.init() - - def init(self): - """Hook called at end of __init__, override this instead of __init__""" - pass - - def setup(self): - """This method should set self.dataset_train, self.dataset_val, and self.dataset_test.""" - raise NotImplementedError - - def split_train_val(self, val_split): - """ - Randomly split self.dataset_train into a new (self.dataset_train, self.dataset_val) pair. 
- """ - train_len = int(len(self.dataset_train) * (1.0 - val_split)) - self.dataset_train, self.dataset_val = torch.utils.data.random_split( - self.dataset_train, - (train_len, len(self.dataset_train) - train_len), - generator=torch.Generator().manual_seed( - getattr(self, "seed", 42) - ), # PL is supposed to have a way to handle seeds properly, but doesn't seem to work for us - ) - - def train_dataloader(self, **kwargs): - return self._train_dataloader(self.dataset_train, **kwargs) - - def _train_dataloader(self, dataset, **kwargs): - if dataset is None: return - kwargs['shuffle'] = 'sampler' not in kwargs # shuffle cant be True if we have custom sampler - return self._dataloader(dataset, **kwargs) - - def val_dataloader(self, **kwargs): - return self._eval_dataloader(self.dataset_val, **kwargs) - - def test_dataloader(self, **kwargs): - return self._eval_dataloader(self.dataset_test, **kwargs) - - def _eval_dataloader(self, dataset, **kwargs): - if dataset is None: return - # Note that shuffle=False by default - return self._dataloader(dataset, **kwargs) - - def __str__(self): - return self._name_ - -class ResolutionSequenceDataset(SequenceDataset, SequenceResolutionCollateMixin): - - def _train_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs): - if train_resolution is None: train_resolution = [1] - if not is_list(train_resolution): train_resolution = [train_resolution] - assert len(train_resolution) == 1, "Only one train resolution supported for now." - return super()._train_dataloader(dataset, resolution=train_resolution[0], **kwargs) - - def _eval_dataloader(self, dataset, train_resolution=None, eval_resolutions=None, **kwargs): - if dataset is None: return - if eval_resolutions is None: eval_resolutions = [1] - if not is_list(eval_resolutions): eval_resolutions = [eval_resolutions] - - dataloaders = [] - for resolution in eval_resolutions: - dataloaders.append(super()._eval_dataloader(dataset, resolution=resolution, **kwargs)) - - return ( - { - None if res == 1 else str(res): dl - for res, dl in zip(eval_resolutions, dataloaders) - } - if dataloaders is not None else None - ) - -class ImageResolutionSequenceDataset(ResolutionSequenceDataset, ImageResolutionCollateMixin): - pass - - - -# Registry for dataloader class -loader_registry = { - None: torch.utils.data.DataLoader, # default case -} diff --git a/src/clm/src/dataloaders/basic.py b/src/clm/src/dataloaders/basic.py deleted file mode 100644 index 938450e8..00000000 --- a/src/clm/src/dataloaders/basic.py +++ /dev/null @@ -1,271 +0,0 @@ -"""Implementation of basic benchmark datasets used in S4 experiments: MNIST, CIFAR10 and Speech Commands.""" -import numpy as np -import torch -import torchvision -from einops.layers.torch import Rearrange -from clm.src.utils import permutations - -from clm.src.dataloaders.base import default_data_path, ImageResolutionSequenceDataset, ResolutionSequenceDataset, SequenceDataset - - -class MNIST(SequenceDataset): - _name_ = "mnist" - d_input = 1 - d_output = 10 - l_output = 0 - L = 784 - - @property - def init_defaults(self): - return { - "permute": True, - "val_split": 0.1, - "seed": 42, # For train/val split - } - - def setup(self): - self.data_dir = self.data_dir or default_data_path / self._name_ - - transform_list = [ - torchvision.transforms.ToTensor(), - torchvision.transforms.Lambda(lambda x: x.view(self.d_input, self.L).t()), - ] # (L, d_input) - if self.permute: - # below is another permutation that other works have used - # permute = 
np.random.RandomState(92916) - # permutation = torch.LongTensor(permute.permutation(784)) - permutation = permutations.bitreversal_permutation(self.L) - transform_list.append( - torchvision.transforms.Lambda(lambda x: x[permutation]) - ) - # TODO does MNIST need normalization? - # torchvision.transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs - transform = torchvision.transforms.Compose(transform_list) - self.dataset_train = torchvision.datasets.MNIST( - self.data_dir, - train=True, - download=True, - transform=transform, - ) - self.dataset_test = torchvision.datasets.MNIST( - self.data_dir, - train=False, - transform=transform, - ) - self.split_train_val(self.val_split) - - def __str__(self): - return f"{'p' if self.permute else 's'}{self._name_}" - - -class CIFAR10(ImageResolutionSequenceDataset): - _name_ = "cifar" - d_output = 10 - l_output = 0 - - @property - def init_defaults(self): - return { - "permute": None, - "grayscale": False, - "tokenize": False, # if grayscale, tokenize into discrete byte inputs - "augment": False, - "cutout": False, - "rescale": None, - "random_erasing": False, - "val_split": 0.1, - "seed": 42, # For validation split - } - - @property - def d_input(self): - if self.grayscale: - if self.tokenize: - return 256 - else: - return 1 - else: - assert not self.tokenize - return 3 - - def setup(self): - img_size = 32 - if self.rescale: - img_size //= self.rescale - - if self.grayscale: - preprocessors = [ - torchvision.transforms.Grayscale(), - torchvision.transforms.ToTensor(), - ] - permutations_list = [ - torchvision.transforms.Lambda( - lambda x: x.view(1, img_size * img_size).t() - ) # (L, d_input) - ] - - if self.tokenize: - preprocessors.append( - torchvision.transforms.Lambda(lambda x: (x * 255).long()) - ) - permutations_list.append(Rearrange("l 1 -> l")) - else: - preprocessors.append( - torchvision.transforms.Normalize( - mean=122.6 / 255.0, std=61.0 / 255.0 - ) - ) - else: - preprocessors = [ - torchvision.transforms.ToTensor(), - torchvision.transforms.Normalize( - (0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261) - ), - ] - permutations_list = [ - torchvision.transforms.Lambda( - Rearrange("z h w -> (h w) z", z=3, h=img_size, w=img_size) - ) # (L, d_input) - ] - - # Permutations and reshaping - if self.permute == "br": - permutation = permutations.bitreversal_permutation(img_size * img_size) - print("bit reversal", permutation) - permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) - elif self.permute == "snake": - permutation = permutations.snake_permutation(img_size, img_size) - print("snake", permutation) - permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) - elif self.permute == "hilbert": - permutation = permutations.hilbert_permutation(img_size) - print("hilbert", permutation) - permutations_list.append(torchvision.transforms.Lambda(lambda x: x[permutation])) - elif self.permute == "transpose": - permutation = permutations.transpose_permutation(img_size, img_size) - transform = torchvision.transforms.Lambda( - lambda x: torch.cat([x, x[permutation]], dim=-1) - ) - permutations_list.append(transform) - elif self.permute == "2d": # h, w, c - permutation = torchvision.transforms.Lambda( - Rearrange("(h w) c -> h w c", h=img_size, w=img_size) - ) - permutations_list.append(permutation) - elif self.permute == "2d_transpose": # c, h, w - permutation = torchvision.transforms.Lambda( - Rearrange("(h w) c -> c h w", h=img_size, w=img_size) - ) - permutations_list.append(permutation) - - 
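For illustration, a minimal, self-contained sketch of what the `"br"` (bit-reversal) option does to a flattened pixel sequence; `bitreversal_permutation` below is a toy stand-in for the repository's `permutations` helper and assumes the sequence length is a power of two.

```python
import torch

def bitreversal_permutation(n: int) -> torch.Tensor:
    # Toy stand-in: reverse the binary representation of each index (n must be a power of two).
    bits = n.bit_length() - 1
    perm = [int(format(i, f"0{bits}b")[::-1], 2) for i in range(n)]
    return torch.tensor(perm)

x = torch.arange(16).unsqueeze(-1)    # (L, d_input) flattened "image" with L = 16 pixels
perm = bitreversal_permutation(16)
print(perm)                # tensor([ 0,  8,  4, 12,  2, 10,  6, 14,  1,  9,  5, 13,  3, 11,  7, 15])
print(x[perm].squeeze(-1))  # the same pixels visited in bit-reversed order
```

The other options ("snake", "hilbert", "transpose", "2d") reorder or reshape the flattened sequence in the same way, just with a different index permutation.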
# Augmentation - if self.augment: - augmentations = [ - torchvision.transforms.RandomCrop( - img_size, padding=4, padding_mode="symmetric" - ), - torchvision.transforms.RandomHorizontalFlip(), - ] - - post_augmentations = [] - if self.cutout: - post_augmentations.append(Cutout(1, img_size // 2)) - pass - if self.random_erasing: - # augmentations.append(RandomErasing()) - pass - else: - augmentations, post_augmentations = [], [] - transforms_train = ( - augmentations + preprocessors + post_augmentations + permutations_list - ) - transforms_eval = preprocessors + permutations_list - - transform_train = torchvision.transforms.Compose(transforms_train) - transform_eval = torchvision.transforms.Compose(transforms_eval) - self.dataset_train = torchvision.datasets.CIFAR10( - f"{default_data_path}/{self._name_}", - train=True, - download=True, - transform=transform_train, - ) - self.dataset_test = torchvision.datasets.CIFAR10( - f"{default_data_path}/{self._name_}", train=False, transform=transform_eval - ) - - if self.rescale: - print(f"Resizing all images to {img_size} x {img_size}.") - self.dataset_train.data = self.dataset_train.data.reshape((self.dataset_train.data.shape[0], 32 // self.rescale, self.rescale, 32 // self.rescale, self.rescale, 3)).max(4).max(2).astype(np.uint8) - self.dataset_test.data = self.dataset_test.data.reshape((self.dataset_test.data.shape[0], 32 // self.rescale, self.rescale, 32 // self.rescale, self.rescale, 3)).max(4).max(2).astype(np.uint8) - - self.split_train_val(self.val_split) - - def __str__(self): - return f"{'p' if self.permute else 's'}{self._name_}" - -class SpeechCommands(ResolutionSequenceDataset): - _name_ = "sc" - - @property - def init_defaults(self): - return { - "mfcc": False, - "dropped_rate": 0.0, - "length": 16000, - "all_classes": False, - } - - @property - def d_input(self): - _d_input = 20 if self.mfcc else 1 - _d_input += 1 if self.dropped_rate > 0.0 else 0 - return _d_input - - @property - def d_output(self): - return 10 if not self.all_classes else 35 - - @property - def l_output(self): - return 0 - - @property - def L(self): - return 161 if self.mfcc else self.length - - - def setup(self): - self.data_dir = self.data_dir or default_data_path # TODO make same logic as other classes - - from clm.src.dataloaders.datasets.sc import _SpeechCommands - - # TODO refactor with data_dir argument - self.dataset_train = _SpeechCommands( - partition="train", - length=self.L, - mfcc=self.mfcc, - sr=1, - dropped_rate=self.dropped_rate, - path=self.data_dir, - all_classes=self.all_classes, - ) - - self.dataset_val = _SpeechCommands( - partition="val", - length=self.L, - mfcc=self.mfcc, - sr=1, - dropped_rate=self.dropped_rate, - path=self.data_dir, - all_classes=self.all_classes, - ) - - self.dataset_test = _SpeechCommands( - partition="test", - length=self.L, - mfcc=self.mfcc, - sr=1, - dropped_rate=self.dropped_rate, - path=self.data_dir, - all_classes=self.all_classes, - ) diff --git a/src/clm/src/dataloaders/datasets/detokenizer.py b/src/clm/src/dataloaders/datasets/detokenizer.py deleted file mode 100644 index c42266be..00000000 --- a/src/clm/src/dataloaders/datasets/detokenizer.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copied from https://github.com/stanford-crfm/mistral/blob/main/src/corpora/detokenization.py -# Which was originally from https://github.com/NVIDIA/Megatron-LM/blob/aed2f75e209e525c842aec7c044af7acae2a4614/tasks/zeroshot_gpt/detokenizer.py - -""" -Handle detokenization for different dataset for zero-shot LM evaluation. 
-""" -import re - - -def wikitext_detokenize(string: str) -> str: - """ - Wikitext is whitespace tokenized and we remove these whitespaces. - Taken from https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt2/detokenizer.py - """ - # Contractions - string = string.replace("s '", "s'") - string = re.sub(r"/' [0-9]/", r"/'[0-9]/", string) - - # Number Separators - string = string.replace(" @-@ ", "-") - string = string.replace(" @,@ ", ",") - string = string.replace(" @.@ ", ".") - - # Punctuation - string = string.replace(" : ", ": ") - string = string.replace(" ; ", "; ") - string = string.replace(" . ", ". ") - string = string.replace(" ! ", "! ") - string = string.replace(" ? ", "? ") - string = string.replace(" , ", ", ") - - # Double Brackets - string = re.sub(r"\(\s*([^\)]*?)\s*\)", r"(\1)", string) - string = re.sub(r"\[\s*([^\]]*?)\s*\]", r"[\1]", string) - string = re.sub(r"{\s*([^}]*?)\s*}", r"{\1}", string) - string = re.sub(r"\"\s*([^\"]*?)\s*\"", r'"\1"', string) - string = re.sub(r"'\s*([^']*?)\s*'", r"'\1'", string) - - # Miscellaneous - string = string.replace("= = = =", "====") - string = string.replace("= = =", "===") - string = string.replace("= =", "==") - string = string.replace(" " + chr(176) + " ", chr(176)) - string = string.replace(" \n", "\n") - string = string.replace("\n ", "\n") - string = string.replace(" N ", " 1 ") - string = string.replace(" 's", "'s") - - return string - - -# Set Registry for Various Datasets -DATASET_TOKENIZATION_REGISTRY = {"wikitext": wikitext_detokenize} \ No newline at end of file diff --git a/src/clm/src/dataloaders/datasets/lm_dataset.py b/src/clm/src/dataloaders/datasets/lm_dataset.py deleted file mode 100644 index d32353a8..00000000 --- a/src/clm/src/dataloaders/datasets/lm_dataset.py +++ /dev/null @@ -1,32 +0,0 @@ -# Inspired by https://github.com/NVIDIA/Megatron-LM/blob/main/tasks/zeroshot_gpt/datasets.py -# Except we don't pad the last block and don't use overlapping eval -# And we return both the input and the target -import math -import numpy as np - -import torch - - -class LMDataset(torch.utils.data.Dataset): - - def __init__(self, tokens, seq_len, drop_last=True): - """tokens should be a numpy array - """ - self.seq_len = seq_len - ntokens = len(tokens) - if drop_last: - ntokens = ((ntokens - 1) // seq_len) * seq_len + 1 - self.ntokens = ntokens - # We're careful not to slice tokens, since it could be a memmap'ed array or H5 dataset, - # and slicing would load it to memory. - self.tokens = tokens - self.total_sequences = math.ceil((self.ntokens - 1) / self.seq_len) - - def __len__(self): - return self.total_sequences - - def __getitem__(self, idx): - start_idx = idx * self.seq_len - seq_len = min(self.seq_len, self.ntokens - 1 - start_idx) - data = torch.as_tensor(self.tokens[start_idx:(start_idx + seq_len + 1)].astype(np.int64)) - return data[:-1], data[1:].clone() \ No newline at end of file diff --git a/src/clm/src/dataloaders/et.py b/src/clm/src/dataloaders/et.py deleted file mode 100644 index 455d0a2d..00000000 --- a/src/clm/src/dataloaders/et.py +++ /dev/null @@ -1,626 +0,0 @@ -""" -ET Dataset from Informer Paper. 
-Dataset: https://github.com/zhouhaoyi/ETDataset -Dataloader: https://github.com/zhouhaoyi/Informer2020 -""" - -from typing import List -import os -import numpy as np -import pandas as pd -from pandas.tseries import offsets -from pandas.tseries.frequencies import to_offset -import torch -from torch.utils import data -from torch.utils.data import Dataset, DataLoader - -import warnings -warnings.filterwarnings("ignore") - -from clm.src.dataloaders.base import SequenceDataset, default_data_path - - -class TimeFeature: - def __init__(self): - pass - - def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: - pass - - def __repr__(self): - return self.__class__.__name__ + "()" - - -class SecondOfMinute(TimeFeature): - """Minute of hour encoded as value between [-0.5, 0.5]""" - - def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: - return index.second / 59.0 - 0.5 - - -class MinuteOfHour(TimeFeature): - """Minute of hour encoded as value between [-0.5, 0.5]""" - - def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: - return index.minute / 59.0 - 0.5 - - -class HourOfDay(TimeFeature): - """Hour of day encoded as value between [-0.5, 0.5]""" - - def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: - return index.hour / 23.0 - 0.5 - - -class DayOfWeek(TimeFeature): - """Hour of day encoded as value between [-0.5, 0.5]""" - - def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: - return index.dayofweek / 6.0 - 0.5 - - -class DayOfMonth(TimeFeature): - """Day of month encoded as value between [-0.5, 0.5]""" - - def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: - return (index.day - 1) / 30.0 - 0.5 - - -class DayOfYear(TimeFeature): - """Day of year encoded as value between [-0.5, 0.5]""" - - def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: - return (index.dayofyear - 1) / 365.0 - 0.5 - - -class MonthOfYear(TimeFeature): - """Month of year encoded as value between [-0.5, 0.5]""" - - def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: - return (index.month - 1) / 11.0 - 0.5 - - -class WeekOfYear(TimeFeature): - """Week of year encoded as value between [-0.5, 0.5]""" - - def __call__(self, index: pd.DatetimeIndex) -> np.ndarray: - return (index.isocalendar().week - 1) / 52.0 - 0.5 - - -def time_features_from_frequency_str(freq_str: str) -> List[TimeFeature]: - """ - Returns a list of time features that will be appropriate for the given frequency string. - Parameters - ---------- - freq_str - Frequency string of the form [multiple][granularity] such as "12H", "5min", "1D" etc. 
- """ - - features_by_offsets = { - offsets.YearEnd: [], - offsets.QuarterEnd: [MonthOfYear], - offsets.MonthEnd: [MonthOfYear], - offsets.Week: [DayOfMonth, WeekOfYear], - offsets.Day: [DayOfWeek, DayOfMonth, DayOfYear], - offsets.BusinessDay: [DayOfWeek, DayOfMonth, DayOfYear], - offsets.Hour: [HourOfDay, DayOfWeek, DayOfMonth, DayOfYear], - offsets.Minute: [ - MinuteOfHour, - HourOfDay, - DayOfWeek, - DayOfMonth, - DayOfYear, - ], - offsets.Second: [ - SecondOfMinute, - MinuteOfHour, - HourOfDay, - DayOfWeek, - DayOfMonth, - DayOfYear, - ], - } - - offset = to_offset(freq_str) - - for offset_type, feature_classes in features_by_offsets.items(): - if isinstance(offset, offset_type): - return [cls() for cls in feature_classes] - - supported_freq_msg = f""" - Unsupported frequency {freq_str} - The following frequencies are supported: - Y - yearly - alias: A - M - monthly - W - weekly - D - daily - B - business days - H - hourly - T - minutely - alias: min - S - secondly - """ - raise RuntimeError(supported_freq_msg) - - -def time_features(dates, timeenc=1, freq="h"): - """ - > `time_features` takes in a `dates` dataframe with a 'dates' column and extracts the date down to `freq` where freq can be any of the following if `timeenc` is 0: - > * m - [month] - > * w - [month] - > * d - [month, day, weekday] - > * b - [month, day, weekday] - > * h - [month, day, weekday, hour] - > * t - [month, day, weekday, hour, *minute] - > - > If `timeenc` is 1, a similar, but different list of `freq` values are supported (all encoded between [-0.5 and 0.5]): - > * Q - [month] - > * M - [month] - > * W - [Day of month, week of year] - > * D - [Day of week, day of month, day of year] - > * B - [Day of week, day of month, day of year] - > * H - [Hour of day, day of week, day of month, day of year] - > * T - [Minute of hour*, hour of day, day of week, day of month, day of year] - > * S - [Second of minute, minute of hour, hour of day, day of week, day of month, day of year] - *minute returns a number from 0-3 corresponding to the 15 minute period it falls into. 
- """ - if timeenc == 0: - dates["month"] = dates.date.apply(lambda row: row.month, 1) - dates["day"] = dates.date.apply(lambda row: row.day, 1) - dates["weekday"] = dates.date.apply(lambda row: row.weekday(), 1) - dates["hour"] = dates.date.apply(lambda row: row.hour, 1) - dates["minute"] = dates.date.apply(lambda row: row.minute, 1) - dates["minute"] = dates.minute.map(lambda x: x // 15) - freq_map = { - "y": [], - "m": ["month"], - "w": ["month"], - "d": ["month", "day", "weekday"], - "b": ["month", "day", "weekday"], - "h": ["month", "day", "weekday", "hour"], - "t": ["month", "day", "weekday", "hour", "minute"], - } - return dates[freq_map[freq.lower()]].values - if timeenc == 1: - dates = pd.to_datetime(dates.date.values) - return np.vstack( - [feat(dates) for feat in time_features_from_frequency_str(freq)] - ).transpose(1, 0) - - -class StandardScaler: - def __init__(self): - self.mean = 0.0 - self.std = 1.0 - - def fit(self, data): - self.mean = data.mean(0) - self.std = data.std(0) - - def transform(self, data): - mean = ( - torch.from_numpy(self.mean).type_as(data).to(data.device) - if torch.is_tensor(data) - else self.mean - ) - std = ( - torch.from_numpy(self.std).type_as(data).to(data.device) - if torch.is_tensor(data) - else self.std - ) - return (data - mean) / std - - def inverse_transform(self, data): - mean = ( - torch.from_numpy(self.mean).type_as(data).to(data.device) - if torch.is_tensor(data) - else self.mean - ) - std = ( - torch.from_numpy(self.std).type_as(data).to(data.device) - if torch.is_tensor(data) - else self.std - ) - return (data * std) + mean - - -class InformerDataset(Dataset): - def __init__( - self, - root_path, - flag="train", - size=None, - features="S", - data_path="ETTh1.csv", - target="OT", - scale=True, - inverse=False, - timeenc=0, - freq="h", - cols=None, - eval_stamp=False, - eval_mask=False, - ): - # size [seq_len, label_len, pred_len] - # info - if size == None: - self.seq_len = 24 * 4 * 4 - self.label_len = 24 * 4 - self.pred_len = 24 * 4 - else: - self.seq_len = size[0] - self.label_len = size[1] - self.pred_len = size[2] - # init - assert flag in ["train", "test", "val"] - type_map = {"train": 0, "val": 1, "test": 2} - self.set_type = type_map[flag] - - self.features = features - self.target = target - self.scale = scale - self.inverse = inverse - self.timeenc = timeenc - self.freq = freq - self.cols = cols - self.eval_stamp = eval_stamp - self.eval_mask = eval_mask - self.forecast_horizon = self.pred_len - - self.root_path = root_path - self.data_path = data_path - self.__read_data__() - - def _borders(self, df_raw): - num_train = int(len(df_raw) * 0.7) - num_test = int(len(df_raw) * 0.2) - num_vali = len(df_raw) - num_train - num_test - border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len] - border2s = [num_train, num_train + num_vali, len(df_raw)] - return border1s, border2s - - def _process_columns(self, df_raw): - if self.cols: - cols = self.cols.copy() - cols.remove(self.target) - else: - cols = list(df_raw.columns) - cols.remove(self.target) - cols.remove("date") - return df_raw[["date"] + cols + [self.target]] - - def __read_data__(self): - self.scaler = StandardScaler() - df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path)) - - df_raw = self._process_columns(df_raw) - - border1s, border2s = self._borders(df_raw) - border1 = border1s[self.set_type] - border2 = border2s[self.set_type] - - if self.features == "M" or self.features == "MS": - cols_data = df_raw.columns[1:] - df_data = 
df_raw[cols_data] - elif self.features == "S": - df_data = df_raw[[self.target]] - - if self.scale: - train_data = df_data[border1s[0] : border2s[0]] - self.scaler.fit(train_data.values) - data = self.scaler.transform(df_data.values) - else: - data = df_data.values - - df_stamp = df_raw[["date"]][border1:border2] - df_stamp["date"] = pd.to_datetime(df_stamp.date) - data_stamp = time_features(df_stamp, timeenc=self.timeenc, freq=self.freq) - - self.data_x = data[border1:border2] - if self.inverse: - self.data_y = df_data.values[border1:border2] - else: - self.data_y = data[border1:border2] - - self.data_stamp = data_stamp - - def __getitem__(self, index): - s_begin = index - s_end = s_begin + self.seq_len - r_begin = s_end - self.label_len - r_end = r_begin + self.label_len + self.pred_len - - seq_x = self.data_x[s_begin:s_end] - seq_x = np.concatenate( - [seq_x, np.zeros((self.pred_len, self.data_x.shape[-1]))], axis=0 - ) - - if self.inverse: - seq_y = np.concatenate( - [ - self.data_x[r_begin : r_begin + self.label_len], - self.data_y[r_begin + self.label_len : r_end], - ], - 0, - ) - raise NotImplementedError - else: - # seq_y = self.data_y[r_begin:r_end] # OLD in Informer codebase - seq_y = self.data_y[s_end:r_end] - - # OLD in Informer codebase - # seq_x_mark = self.data_stamp[s_begin:s_end] - # seq_y_mark = self.data_stamp[r_begin:r_end] - - if self.eval_stamp: - mark = self.data_stamp[s_begin:r_end] - else: - mark = self.data_stamp[s_begin:s_end] - mark = np.concatenate([mark, np.zeros((self.pred_len, mark.shape[-1]))], axis=0) - - if self.eval_mask: - mask = np.concatenate([np.zeros(self.seq_len), np.ones(self.pred_len)], axis=0) - else: - mask = np.concatenate([np.zeros(self.seq_len), np.zeros(self.pred_len)], axis=0) - mask = mask[:, None] - - # Add the mask to the timestamps: # 480, 5 - # mark = np.concatenate([mark, mask[:, np.newaxis]], axis=1) - - seq_x = seq_x.astype(np.float32) - seq_y = seq_y.astype(np.float32) - if self.timeenc == 0: - mark = mark.astype(np.int64) - else: - mark = mark.astype(np.float32) - mask = mask.astype(np.int64) - - return torch.tensor(seq_x), torch.tensor(seq_y), torch.tensor(mark), torch.tensor(mask) - - def __len__(self): - return len(self.data_x) - self.seq_len - self.pred_len + 1 - - def inverse_transform(self, data): - return self.scaler.inverse_transform(data) - - @property - def d_input(self): - return self.data_x.shape[-1] - - @property - def d_output(self): - if self.features in ["M", "S"]: - return self.data_x.shape[-1] - elif self.features == "MS": - return 1 - else: - raise NotImplementedError - - @property - def n_tokens_time(self): - if self.freq == 'h': - return [13, 32, 7, 24] - elif self.freq == 't': - return [13, 32, 7, 24, 4] - else: - raise NotImplementedError - - -class _Dataset_ETT_hour(InformerDataset): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def _borders(self, df_raw): - border1s = [ - 0, - 12 * 30 * 24 - self.seq_len, - 12 * 30 * 24 + 4 * 30 * 24 - self.seq_len, - ] - border2s = [ - 12 * 30 * 24, - 12 * 30 * 24 + 4 * 30 * 24, - 12 * 30 * 24 + 8 * 30 * 24, - ] - return border1s, border2s - - def _process_columns(self, df_raw): - return df_raw - - @property - def n_tokens_time(self): - assert self.freq == "h" - return [13, 32, 7, 24] - - -class _Dataset_ETT_minute(_Dataset_ETT_hour): - def __init__(self, data_path="ETTm1.csv", freq="t", **kwargs): - super().__init__(data_path=data_path, freq=freq, **kwargs) - - def _borders(self, df_raw): - border1s = [ - 0, - 12 * 30 * 24 * 4 - self.seq_len, - 12 * 
30 * 24 * 4 + 4 * 30 * 24 * 4 - self.seq_len, - ] - border2s = [ - 12 * 30 * 24 * 4, - 12 * 30 * 24 * 4 + 4 * 30 * 24 * 4, - 12 * 30 * 24 * 4 + 8 * 30 * 24 * 4, - ] - return border1s, border2s - - @property - def n_tokens_time(self): - assert self.freq == "t" - return [13, 32, 7, 24, 4] - - -class _Dataset_Weather(InformerDataset): - def __init__(self, data_path="WTH.csv", target="WetBulbCelsius", **kwargs): - super().__init__(data_path=data_path, target=target, **kwargs) - -class _Dataset_ECL(InformerDataset): - def __init__(self, data_path="ECL.csv", target="MT_320", **kwargs): - super().__init__(data_path=data_path, target=target, **kwargs) - -class InformerSequenceDataset(SequenceDataset): - - @property - def n_tokens_time(self): - # Shape of the dates: depends on `timeenc` and `freq` - return self.dataset_train.n_tokens_time # data_stamp.shape[-1] - - @property - def d_input(self): - return self.dataset_train.d_input - - @property - def d_output(self): - return self.dataset_train.d_output - - @property - def l_output(self): - return self.dataset_train.pred_len - - def _get_data_filename(self, variant): - return self.variants[variant] - - _collate_arg_names = ["mark", "mask"] # Names of the two extra tensors that the InformerDataset returns - - def setup(self): - self.data_dir = self.data_dir or default_data_path / 'informer' / self._name_ - - self.dataset_train = self._dataset_cls( - root_path=self.data_dir, - flag="train", - size=self.size, - features=self.features, - data_path=self._get_data_filename(self.variant), - target=self.target, - scale=self.scale, - inverse=self.inverse, - timeenc=self.timeenc, - freq=self.freq, - cols=self.cols, - eval_stamp=self.eval_stamp, - eval_mask=self.eval_mask, - ) - - self.dataset_val = self._dataset_cls( - root_path=self.data_dir, - flag="val", - size=self.size, - features=self.features, - data_path=self._get_data_filename(self.variant), - target=self.target, - scale=self.scale, - inverse=self.inverse, - timeenc=self.timeenc, - freq=self.freq, - cols=self.cols, - eval_stamp=self.eval_stamp, - eval_mask=self.eval_mask, - ) - - self.dataset_test = self._dataset_cls( - root_path=self.data_dir, - flag="test", - size=self.size, - features=self.features, - data_path=self._get_data_filename(self.variant), - target=self.target, - scale=self.scale, - inverse=self.inverse, - timeenc=self.timeenc, - freq=self.freq, - cols=self.cols, - eval_stamp=self.eval_stamp, - eval_mask=self.eval_mask, - ) - -class ETTHour(InformerSequenceDataset): - _name_ = "etth" - - _dataset_cls = _Dataset_ETT_hour - - init_defaults = { - "size": None, - "features": "S", - "target": "OT", - "variant": 0, - "scale": True, - "inverse": False, - "timeenc": 0, - "freq": "h", - "cols": None, - } - - variants = { - 0: "ETTh1.csv", - 1: "ETTh2.csv", - } - -class ETTMinute(InformerSequenceDataset): - _name_ = "ettm" - - _dataset_cls = _Dataset_ETT_minute - - init_defaults = { - "size": None, - "features": "S", - "target": "OT", - "variant": 0, - "scale": True, - "inverse": False, - "timeenc": 0, - "freq": "t", - "cols": None, - } - - variants = { - 0: "ETTm1.csv", - 1: "ETTm2.csv", - } - -class Weather(InformerSequenceDataset): - _name_ = "weather" - - _dataset_cls = _Dataset_Weather - - init_defaults = { - "size": None, - "features": "S", - "target": "WetBulbCelsius", - "variant": 0, - "scale": True, - "inverse": False, - "timeenc": 0, - "freq": "h", - "cols": None, - } - - variants = { - 0: "WTH.csv", - } - -class ECL(InformerSequenceDataset): - _name_ = "ecl" - - _dataset_cls = 
_Dataset_ECL - - init_defaults = { - "size": None, - "features": "S", - "target": "MT_320", - "variant": 0, - "scale": True, - "inverse": False, - "timeenc": 0, - "freq": "h", - "cols": None, - } - - variants = { - 0: "ECL.csv", - } diff --git a/src/clm/src/dataloaders/fault_tolerant_sampler.py b/src/clm/src/dataloaders/fault_tolerant_sampler.py deleted file mode 100644 index adab1c7f..00000000 --- a/src/clm/src/dataloaders/fault_tolerant_sampler.py +++ /dev/null @@ -1,123 +0,0 @@ -# Adapted from https://github.com/Lightning-AI/lightning/blob/2845e7565dbe6b765ae32870e7d2bc456529c30a/tests/tests_pytorch/utilities/test_auto_restart.py#L1397 -from typing import Iterator -import math - -import torch -from torch.utils.data import RandomSampler, DistributedSampler - - -class RandomFaultTolerantSampler(RandomSampler): - - def __init__(self, *args, generator=None, **kwargs): - # generator = torch.Generator().manual_seed(seed) - # super().__init__(*args, generator=generator, **kwargs) - # TD [2022-07-17]: We don't force the seed to be zero. We generate random seed, - # which should be reproducible if pl.seed_everything was called before hand. - # This means that changing the seed of the experiment will also change the - # sampling order. - if generator is None: - seed = int(torch.empty((), dtype=torch.int64).random_().item()) - generator = torch.Generator().manual_seed(seed) - super().__init__(*args, generator=generator, **kwargs) - self.counter = 0 - # self.start_counter = 0 - self.restarting = False - - def state_dict(self): - return {"random_state": self.state, "counter": self.counter} - - def load_state_dict(self, state_dict): - self.generator.set_state(state_dict.get("random_state")) - self.counter = state_dict["counter"] - # self.start_counter = self.counter - self.restarting = True - - # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per - # epoch, and subsequent epoch will have very few batches. - # def __len__(self): - # # We need a separate self.start_counter because PL seems to call len repeatedly. - # # If we use len(self.data_source) - self.counter then PL will think the epoch ends - # # when we're only half way through. - # return len(self.data_source) - self.start_counter - - def __iter__(self) -> Iterator[int]: - n = len(self.data_source) - - self.state = self.generator.get_state() - indices = torch.randperm(n, generator=self.generator).tolist() - - if not self.restarting: - self.counter = 0 - else: - indices = indices[self.counter:] - self.restarting = False - # self.start_counter = self.counter - - for index in indices: - self.counter += 1 - yield index - - self.counter = 0 - # self.start_counter = self.counter - - -class FaultTolerantDistributedSampler(DistributedSampler): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.counter = 0 - # self.start_counter = 0 - self.restarting = False - - def state_dict(self): - return {"epoch": self.epoch, "counter": self.counter} - - def load_state_dict(self, state_dict): - self.epoch = state_dict["epoch"] - self.counter = state_dict["counter"] - # self.start_counter = self.counter - self.restarting = True - - # TD [2022-08-28] Setting the len will cause PL to think there are only a few batches left per - # epoch, and subsequent epoch will have very few batches. 
- # def __len__(self) -> int: - # return self.num_samples - self.start_counter - - def __iter__(self): - if self.shuffle: - # deterministically shuffle based on epoch and seed - g = torch.Generator() - g.manual_seed(self.seed + self.epoch) - indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] - else: - indices = list(range(len(self.dataset))) # type: ignore[arg-type] - - if not self.drop_last: - # add extra samples to make it evenly divisible - padding_size = self.total_size - len(indices) - if padding_size <= len(indices): - indices += indices[:padding_size] - else: - indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] - else: - # remove tail of data to make it evenly divisible. - indices = indices[:self.total_size] - assert len(indices) == self.total_size - - # subsample - indices = indices[self.rank:self.total_size:self.num_replicas] - assert len(indices) == self.num_samples - - if not self.restarting: - self.counter = 0 - else: - indices = indices[self.counter:] - self.restarting = False - # self.start_counter = self.counter - - for index in indices: - self.counter += 1 - yield index - - self.counter = 0 - # self.start_counter = self.counter \ No newline at end of file diff --git a/src/clm/src/dataloaders/language_modeling_hf.py b/src/clm/src/dataloaders/language_modeling_hf.py deleted file mode 100644 index c17e66b6..00000000 --- a/src/clm/src/dataloaders/language_modeling_hf.py +++ /dev/null @@ -1,311 +0,0 @@ -# Adapted from https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_clm.py -# Adapted from https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py -from itertools import chain -from pathlib import Path -import pickle -from typing import Any, List, Union -import subprocess -import mmap - -from multiprocessing.shared_memory import SharedMemory - -import numpy as np - -import torch -from torch.utils.data.dataloader import DataLoader, Dataset -from transformers import AutoTokenizer -from datasets import load_dataset - -from clm.src.dataloaders.base import SequenceDataset, default_data_path - -from clm.src.dataloaders.datasets.lm_dataset import LMDataset -from clm.src.dataloaders.fault_tolerant_sampler import RandomFaultTolerantSampler -from clm.src.dataloaders.fault_tolerant_sampler import FaultTolerantDistributedSampler -from clm.src.dataloaders.datasets.detokenizer import DATASET_TOKENIZATION_REGISTRY -from clm.src.utils.train import get_logger -logger = get_logger() - - -# https://github.com/numpy/numpy/issues/18294 -class SHMArray(np.ndarray): #copied from https://numpy.org/doc/stable/user/basics.subclassing.html#slightly-more-realistic-example-attribute-added-to-existing-array - - def __new__(cls, input_array, shm=None): - obj = np.asarray(input_array).view(cls) - obj.shm = shm - return obj - - def __array_finalize__(self, obj): - if obj is None: return - self.shm = getattr(obj, 'shm', None) - - -class LMDataModuleWT103(SequenceDataset): - _name_ = "wt103" - - def __init__(self, dataset_name, tokenizer_name, dataset_config_name=None, max_length=1024, - cache_dir=None, val_ratio=0.0005, val_split_seed=2357, add_eos=True, - detokenize=False, val_only=False, batch_size=32, batch_size_eval=None, num_workers=1, - shuffle=False, pin_memory=False, drop_last=False, fault_tolerant=False, ddp=False, - fast_forward_epochs=None, fast_forward_batches=None, - use_shmem=True, *args, **kwargs): - self.dataset_name = dataset_name - 
self.dataset_config_name = dataset_config_name - self.tokenizer_name = tokenizer_name - self.cache_dir = None if cache_dir is None else Path(cache_dir).expanduser() - self.max_length = max_length - self.val_ratio = val_ratio - self.val_split_seed = val_split_seed - self.val_only = val_only - self.add_eos = add_eos - self.detokenize = detokenize - self.batch_size = batch_size - self.batch_size_eval = batch_size_eval if batch_size_eval is not None else self.batch_size - self.num_workers = num_workers - self.shuffle = shuffle - self.pin_memory = pin_memory - self.drop_last = drop_last - if fault_tolerant: - assert self.shuffle - self.fault_tolerant = fault_tolerant - if ddp: - assert fault_tolerant - self.ddp = ddp - self.fast_forward_epochs = fast_forward_epochs - self.fast_forward_batches = fast_forward_batches - if self.fast_forward_epochs is not None or self.fast_forward_batches is not None: - assert ddp and fault_tolerant - - self.use_shmem = use_shmem - if self.use_shmem: - assert cache_dir is not None - - def prepare_data(self): - if self.cache_dir is None: # Just download the dataset - load_dataset(self.dataset_name, self.dataset_config_name) - else: # Process the dataset and save it - self.process_dataset() - - def setup(self, stage=None): - if stage == 'test' and hasattr(self, 'dataset_test'): - return - concat_ids, self.tokenizer = self.process_dataset() - self.vocab_size = len(self.tokenizer) - # Create all splits - self.dataset_train, self.dataset_val, self.dataset_test = [ - LMDataset(concat_ids[split], seq_len=self.max_length) - for split in ['train', 'validation', 'test'] - ] - - def process_dataset(self): - cache_dir = None if self.cache_dir is None else self.cache_dir / self._cache_dir_name - if cache_dir is not None: - if cache_dir.is_dir(): - return self._load_from_cache(cache_dir) - - raw_datasets = load_dataset(self.dataset_name, self.dataset_config_name) - # https://github.com/stanford-crfm/mistral/blob/main/src/corpora/auto.py - if 'validation' not in raw_datasets: - assert "train" in raw_datasets, "You must have train in raw_datasets to make a validation raw_datasets" - raw_datasets = raw_datasets["train"].train_test_split( - test_size=self.val_ratio, seed=self.val_split_seed, - shuffle=True # Otherwise test will be at the end of the dataset - ) - raw_datasets['validation'] = raw_datasets['test'] - - if self.val_only: # Should only be used for evaluation, not for training - raw_datasets['train'] = raw_datasets['validation'] - - # [2021-12-25] TD: Running the detokenizer on wikitext-103 makes ppl worse - # (GPT2-small val ppl after 10 epochs ~22 -> ~25) - # However, it's useful for zero-shot transfer from Openwebtext, - # as after detokenization it's closer to Openwebtext's format. - # https://github.com/stanford-crfm/mistral/issues/12 - if self.detokenize: - if self.dataset_name in DATASET_TOKENIZATION_REGISTRY: - detokenizer = DATASET_TOKENIZATION_REGISTRY[self.dataset_name] - raw_datasets = raw_datasets.map( - lambda example: {'text': detokenizer(example['text'])}, - num_proc=max(self.num_workers, 1), - desc='Running detokenizer on dataset' - ) - - tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, use_fast=True) - # Preprocessing the datasets. - # First we tokenize all the texts. 
- column_names = raw_datasets["train"].column_names - text_column_name = "text" if "text" in column_names else column_names[0] - # [2021-12-25] TD: For wikitext, don't need to add the EOS since each example already ends - # with '\n', and there are no other '\n' in the examples. - # assert all([t.count('\n') == 1 for t in raw_datasets['train']['text'] if t]) - # Add EOS token to the end of the text if the text is not empty - # https://github.com/stanford-crfm/mistral/issues/91 - # https://github.com/stanford-crfm/mistral/pull/98 - if self.add_eos: - add_eos = lambda seq: (seq + tokenizer.eos_token) if seq else seq - add_eos_batched = lambda seqs: [add_eos(seq) for seq in seqs] - tokenize = lambda example: tokenizer(add_eos_batched(example[text_column_name])) - else: - tokenize = lambda example: tokenizer(example[text_column_name]) - # tokenized_datasets = raw_datasets.map( - # tokenize, - # batched=True, - # num_proc=max(self.num_workers, 1), - # remove_columns=column_names, - # desc="Running tokenizer on dataset", - # ) - dtype = np.uint16 if tokenizer.vocab_size < 64 * 1024 else np.int32 - def tokenize_concat(examples): - # We just need 'input_ids', not 'attention_mask' (since it's all 1) - input_ids = np.fromiter(chain(*tokenize(examples)['input_ids']), dtype=dtype) - # Need to return a list since we're doing batched processing - return {'input_ids': [input_ids], 'len': [len(input_ids)]} - tokenized_datasets = raw_datasets.map( - tokenize_concat, - batched=True, - num_proc=max(self.num_workers, 1), - remove_columns=column_names, - desc="Running tokenizer on dataset", - ) - - if self.use_shmem: - # Concatenate all input_ids into an array in shared memory - def write_ids_to_shm(example, shm_name, array_len): - shm = SharedMemory(name=shm_name) - shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf) - start_idx = example['len_offset'] - len(example['input_ids']) - shm_arr[start_idx:example['len_offset']] = example['input_ids'] - shm.close() - concat_ids = {} - for name, ds in tokenized_datasets.items(): - tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len'])) - array_len = tokenized_datasets[name][-1]['len_offset'] - shm = SharedMemory(create=True, size=array_len * np.dtype(dtype).itemsize) - shm_name = shm.name - tokenized_datasets[name].map( - write_ids_to_shm, - fn_kwargs={'shm_name': shm_name, 'array_len': array_len}, - batched=False, - num_proc=max(self.num_workers, 1), - desc="Concatenating examples", - ) - shm_arr = np.ndarray((array_len,), dtype=dtype, buffer=shm.buf) - # We need to keep a reference to the shared memory, otherwise it gets garbage-collected - # when it goes out of scope, and that memory is gone. 
- # https://github.com/numpy/numpy/issues/18294 - concat_ids[name] = SHMArray(shm_arr, shm=shm) - else: - # Use disk - concat_ids = {} - assert cache_dir is not None - cache_dir.mkdir(parents=True, exist_ok=True) - def write_ids_to_disk(example, filename): - with open(filename, 'r+b') as f: - mm = mmap.mmap(f.fileno(), 0) - start_idx = example['len_offset'] - len(example['input_ids']) - array_len = len(example['input_ids']) - arr = np.ndarray((array_len,), dtype=dtype, buffer=mm, - offset=np.dtype(dtype).itemsize * start_idx) - arr[:] = example['input_ids'] - mm.flush() - for name, ds in tokenized_datasets.items(): - tokenized_datasets[name] = ds.add_column('len_offset', np.cumsum(ds['len'])) - array_len = tokenized_datasets[name][-1]['len_offset'] - filename = cache_dir / f'{name}.bin' - # Need to create the file with this specific size first - # https://ostechnix.com/create-files-certain-size-linux/ - subprocess.run(['truncate', '-s', str(array_len * np.dtype(dtype).itemsize), - str(filename)], check=True) - tokenized_datasets[name].map( - write_ids_to_disk, - fn_kwargs={'filename': filename}, - batched=False, - num_proc=max(self.num_workers, 1), - desc="Concatenating examples", - ) - concat_ids[name] = np.memmap(filename, dtype=dtype, mode='r', shape=(array_len,)) - - if cache_dir is not None: - self._save_to_cache(concat_ids, tokenizer, cache_dir) - if not self.use_shmem: - for name in concat_ids: - Path(cache_dir / f'{name}.bin').unlink() - return concat_ids, tokenizer - - def _save_to_cache(self, concat_ids, tokenizer, cache_dir): - cache_dir.mkdir(parents=True, exist_ok=True) - logger.info(f'Saving to cache at {str(cache_dir)}') - for k, v in concat_ids.items(): - np.save(cache_dir / f'{k}.npy', v) - with open(cache_dir / 'tokenizer.pkl', 'wb') as f: - pickle.dump(tokenizer, f) - - def _load_from_cache(self, cache_dir): - assert cache_dir.is_dir() - logger.info(f'Load from cache at {str(cache_dir)}') - concat_ids = {split: np.load(cache_dir / f'{split}.npy', mmap_mode='r') - for split in ['train', 'validation', 'test']} - with open(cache_dir / 'tokenizer.pkl', 'rb') as f: - tokenizer = pickle.load(f) - return concat_ids, tokenizer - - @property - def _cache_dir_name(self): - return f'tokenizer_name-{self.tokenizer_name}-val_ratio-{self.val_ratio}-val_split_seed-{self.val_split_seed}-add_eos-{self.add_eos}-detokenize-{self.detokenize}' - - def train_dataloader(self, *args: Any, **kwargs: Any) -> DataLoader: - """ The train dataloader """ - if self.shuffle and self.fault_tolerant: - shuffle = False - # TD [2022-12-26]: We need the distributed_sampler_kwargs in case of model parallel: - # In that case the number of replicas and the data parallel rank are more complicated. 
- distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs - sampler = (FaultTolerantDistributedSampler(self.dataset_train, - **self.trainer.distributed_sampler_kwargs) - if self.ddp else RandomFaultTolerantSampler(self.dataset_train)) - # TD [2022-08-06]: Only the DDP sampler supports fast-forwarding for now - # We assume that it's being resumed with the same number of GPUs - if self.ddp and self.fast_forward_epochs is not None and self.fast_forward_batches is not None: - sampler.load_state_dict({ - 'epoch': self.fast_forward_epochs, - 'counter': self.fast_forward_batches * self.batch_size - }) - else: - shuffle = self.shuffle - sampler = None - return self._data_loader(self.dataset_train, batch_size=self.batch_size, - shuffle=shuffle, sampler=sampler) - - def val_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: - """ The val dataloader """ - return self._data_loader(self.dataset_val, batch_size=self.batch_size_eval) - - def test_dataloader(self, *args: Any, **kwargs: Any) -> Union[DataLoader, List[DataLoader]]: - """ The test dataloader """ - return self._data_loader(self.dataset_test, batch_size=self.batch_size_eval) - - def _data_loader(self, dataset: Dataset, batch_size: int, shuffle: bool = False, - sampler=None) -> DataLoader: - return DataLoader( - dataset, - batch_size=batch_size, - num_workers=1, # Data is already in memory, we don't need many workers - shuffle=shuffle, - sampler=sampler, - drop_last=self.drop_last, - pin_memory=self.pin_memory, - # persistent_workers=True - ) - - def load_state_dict(self, checkpoint): - if self.fault_tolerant: - self.fast_forward_epochs = checkpoint['loops']['fit_loop']['epoch_progress']['current']['completed'] - # TD [2022-08-07] ['epoch_loop.batch_progress']['total']['completed'] is 1 iteration - # behind, so we're using the optimizer's progress. This is set correctly in seq.py. - self.fast_forward_batches = checkpoint['loops']['fit_loop']['epoch_loop.batch_progress']['current']['completed'] - # At this point the train loader hasn't been constructed yet - -class LMDataModuleOWT(LMDataModuleWT103): - _name_ = "owt" - -class LMDataModulePile(LMDataModuleWT103): - _name_ = "the_pile" \ No newline at end of file diff --git a/src/clm/src/dataloaders/lm.py b/src/clm/src/dataloaders/lm.py deleted file mode 100644 index 9f1ce486..00000000 --- a/src/clm/src/dataloaders/lm.py +++ /dev/null @@ -1,507 +0,0 @@ -# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import os -import re -import subprocess -from pathlib import Path - -from typing import Optional, List, Tuple -import math -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -import functools -from omegaconf import DictConfig -from pytorch_lightning import LightningDataModule - - -from clm.src.utils import distributed -import clm.src.utils.train -log = clm.src.utils.train.get_logger(__name__) - - -from clm.src.dataloaders.base import SequenceDataset, default_data_path -from clm.src.dataloaders.utils.vocabulary import OpenAIVocab, Vocab -import clm.src.utils as utils - -project_root = Path(__file__).parent.parent.absolute() -data_path = Path(__file__).absolute().parent / 'data' - -import sys - -sys.path.insert(0, str(project_root)) - -class LMOrderedIterator: - def __init__( - self, - data, - batch_size, - l_max, - batch_first=True, - n_context=1, - n_epoch_double=0, - pad_last=False, - roll_seed=None, # roll data based on seed - limit_tokens=1.0, # reduce tokens; useful for debugging last batch edge cases - ): - """ - data -- LongTensor -- the LongTensor is strictly ordered - pad_last: whether to pad the last sequence in the batch so that all sequences - have the same length (l_max). - """ - self.raw_data = data - self.batch_size = batch_size - self.l_max = l_max - self.batch_first = batch_first - self.pad_last = pad_last - self.roll_seed = roll_seed - self.n_context = n_context - self.n_epoch_double = n_epoch_double - - self.epoch = -1 - - # DDP - self.world_size = distributed.get_world_size() - self.rank = distributed.get_rank() - - if limit_tokens is not None and 0.0 < limit_tokens < 1.0: - l_data = int(math.floor(data.size(-1) * limit_tokens)) - self.raw_data = self.raw_data[:l_data] - - self.process() - - def process(self): - """ Process the data. All logic involving sequence length and batch size should go here """ - assert self.l_max % self.n_context == 0 - self.l_inc = self.l_max // self.n_context - - global_batch_size = self.world_size * self.batch_size - - # Work out how cleanly we can divide the dataset into batch_size parts. - n_step = self.raw_data.size(-1) // global_batch_size - - # Trim off any extra elements that wouldn't cleanly fit (remainders). - self.data = self.raw_data[: n_step * global_batch_size] - - # Evenly divide the data across the batches. 
- self.data = self.data.view(global_batch_size, -1).contiguous().pin_memory() # (global_batch_size, length) - - # Partition data for DistributedDataParallel - self.data = self.data.chunk(self.world_size, dim=0)[self.rank] - - # Number of mini-batches - # Need to subtract 1 because target is data shifted by 1 - self.n_batch = (self.data.size(-1) - 1 + self.l_inc - 1) // self.l_inc - - def roll(self, seed): - rng = torch.Generator() - rng.manual_seed(seed) - for i in range(self.data.size(0)): - row = self.data[i, :] - shift = torch.randint(0, self.data.size(-1), (1,), generator=rng) - row = torch.cat((row[shift:], row[:shift])) - self.data[i, :] = row - - def get_batch(self, i): - """ Get batch starting at token index i """ - - end_idx = min(i + self.l_inc, self.data.size(-1)-1) - beg_idx = max(0, i + self.l_inc - self.l_max) - seq_len = end_idx - i - - data = self.data[..., beg_idx:end_idx] - target = self.data[..., i+1 : end_idx+1] - - if self.pad_last and seq_len < self.l_inc: - data = F.pad(data, (0, self.l_inc - seq_len)) # (batch_size, l_inc) - target = F.pad(target, (0, self.l_inc - seq_len)) - seq_len = self.l_inc - - if not self.batch_first: - data = data.transpose(0, 1).contiguous() # (n_batch, l_sequence) - target = target.transpose(0, 1).contiguous() - - return data, target, {"l_output": seq_len} # Return length of desired output - - def get_fixlen_iter(self, start=0): - if start != 0: - start += self.l_max - for i in range(start, self.data.size(-1) - 1, self.l_inc): - self.last_iter = i - yield self.get_batch(i) - - def get_varlen_iter(self, start=0, std=5, min_len=5, max_deviation=3): # NOTE: NOT TESTED - l_max = self.l_max + max_deviation * std - i = start - while True: - l_max = self.l_max if np.random.random() < 0.95 else self.l_max / 2.0 - l_max = min(l_max, max(min_len, int(np.random.normal(l_max, std)))) - data, target, seq_len = self.get_batch(i, l_max) # AG: this doesn't appear to work... 
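# Toy, single-process walk-through of the layout that process() and get_batch()
# implement: the ordered token stream is trimmed, reshaped to (batch_size, length),
# and sliced into l_inc-sized windows whose targets are the inputs shifted by one
# token. All numbers here are made up purely for illustration.
import torch

tokens = torch.arange(23)                       # pretend token ids, strictly ordered
batch_size, l_max, n_context = 2, 4, 1
l_inc = l_max // n_context                      # tokens yielded per step

n_step = tokens.size(-1) // batch_size          # 11; the remainder is dropped
data = tokens[: n_step * batch_size].view(batch_size, -1)    # (2, 11): rows 0-10 and 11-21

i = 0                                           # first window
end = min(i + l_inc, data.size(-1) - 1)         # 4
beg = max(0, i + l_inc - l_max)                 # 0
x = data[..., beg:end]                          # [[0,1,2,3], [11,12,13,14]]
y = data[..., i + 1:end + 1]                    # [[1,2,3,4], [12,13,14,15]]  (shift by one)
print(x.shape, y.shape)                         # torch.Size([2, 4]) torch.Size([2, 4])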
- i += seq_len - yield data, target, seq_len - if i >= self.data.size(-1) - 2: - break - - def __iter__(self): - self.epoch += 1 - if (n := self.n_epoch_double) > 0 and self.epoch > 0 and self.epoch % n == 0: - if self.batch_size > 1: - log.info(f"LM Iterator doubling length from {self.l_max} to {self.l_max*2}") - self.l_max *= 2 - self.batch_size //= 2 - self.process() - - if self.roll_seed is not None: - self.roll(self.roll_seed + self.epoch) - return self.get_fixlen_iter() - - def __len__(self): - return self.n_batch - - -class LMShuffledIterator(object): - # NOTE: Not tested - def __init__( - self, data, batch_size, l_max, device="cpu", ext_len=None, shuffle=False - ): - """ - data -- list[LongTensor] -- there is no order among the LongTensors - """ - self.data = data - - self.batch_size = batch_size - self.l_max = l_max - self.ext_len = ext_len if ext_len is not None else 0 - - self.device = device - self.shuffle = shuffle - - def get_sent_stream(self): - # index iterator - epoch_indices = ( - np.random.permutation(len(self.data)) - if self.shuffle - else np.array(range(len(self.data))) - ) - - # sentence iterator - for idx in epoch_indices: - yield self.data[idx] - - def stream_iterator(self, sent_stream): - # streams for each data in the batch - streams = [None] * self.batch_size - - data = torch.LongTensor(self.l_max, self.batch_size) - target = torch.LongTensor(self.l_max, self.batch_size) - - n_retain = 0 - - while True: - # data : [n_retain+l_max x batch_size] - # target : [l_max x batch_size] - data[n_retain:].fill_(-1) - target.fill_(-1) - - valid_batch = True - - for i in range(self.batch_size): - n_filled = 0 - try: - while n_filled < self.l_max: - if streams[i] is None or len(streams[i]) <= 1: - streams[i] = next(sent_stream) - # number of new tokens to fill in - n_new = min(len(streams[i]) - 1, self.l_max - n_filled) - # first n_retain tokens are retained from last batch - data[ - n_retain + n_filled : n_retain + n_filled + n_new, - i, - ] = streams[i][:n_new] - target[n_filled : n_filled + n_new, i] = streams[i][ - 1 : n_new + 1 - ] - streams[i] = streams[i][n_new:] - n_filled += n_new - except StopIteration: - valid_batch = False - break - - if not valid_batch: - return - - data = data.to(self.device) - target = target.to(self.device) - - yield data, target, self.l_max - - n_retain = min(data.size(0), self.ext_len) - if n_retain > 0: - data[:n_retain] = data[-n_retain:] - data.resize_(n_retain + self.l_max, data.size(1)) - - def __iter__(self): - # sent_stream is an iterator - sent_stream = self.get_sent_stream() - - for batch in self.stream_iterator(sent_stream): - yield batch - - -class LMMultiFileIterator(LMShuffledIterator): - # NOTE: Not tested - def __init__( - self, - paths, - vocab, - batch_size, - l_max, - device="cpu", - ext_len=None, - shuffle=False, - ): - - self.paths = paths - self.vocab = vocab - - self.batch_size = batch_size - self.l_max = l_max - self.ext_len = ext_len if ext_len is not None else 0 - - self.device = device - self.shuffle = shuffle - - def get_sent_stream(self, path): - sents = self.vocab.encode_file(path, add_double_eos=True) - if self.shuffle: - np.random.shuffle(sents) - sent_stream = iter(sents) - - return sent_stream - - def __iter__(self): - if self.shuffle: - np.random.shuffle(self.paths) - - for path in self.paths: - # sent_stream is an iterator - sent_stream = self.get_sent_stream(path) - for batch in self.stream_iterator(sent_stream): - yield batch - - -class WikiText2(SequenceDataset): - _name_ = "wt2" - - # Vocab arguments - 
vocab_kwargs = {"special": [""], "lower_case": False} - encode_kwargs = {"ordered": True} - - init_defaults = { - # Dataset arguments - 'l_max': 512, - 'bpe': False, - 'roll_seed': 42, - 'test_split': True, - } - - @property - def n_tokens(self): - return len(self.vocab) - - def prepare_data(self): - # [21-09-23] probably broken - if not self.data_dir.exists(): - subprocess.run( - [ - str(project_root / "data" / "getdata.sh"), - self._name_, - str(self.data_dir.parent.absolute()), - ], - check=True, - ) - - def setup(self, stage=None): # [21-09-10 AG]: TODO shouldn't this tokenization happen in the prepare_data? since we're caching it it doesn't really matter, but still - if self.data_dir is None: self.data_dir = default_data_path / self._name_ - if self.bpe: - self.vocab = OpenAIVocab() - else: - self.vocab = Vocab(**self.vocab_kwargs) - - # Loader arguments - if not self._load_from_cache(): - logging.info(f"Producing dataset {self._name_}...") - self._vocab_count() - self.vocab.build_vocab() - self.train = self.vocab.encode_file( - str(self.data_dir / "train.txt"), **self.encode_kwargs - ) - self.valid = self.vocab.encode_file( - str(self.data_dir / "valid.txt"), **self.encode_kwargs - ) - self.test = self.vocab.encode_file( - str(self.data_dir / "test.txt"), **self.encode_kwargs - ) - self._save_to_cache() - - # No test set if specified - if not self.test_split: - self.test = None - - # Define task - print("Vocab size:", len(self.vocab)) - - def _vocab_count(self): - self.vocab.count_file(self.data_dir / "train.txt") - self.vocab.count_file(self.data_dir / "valid.txt") - self.vocab.count_file(self.data_dir / "test.txt") - - def _save_to_cache(self): - cache_path = self.data_dir / f"cache.pt" # TODO name could include vocab_kwargs to disambiguate - with distributed.sync_workers() as rank: - if rank == 0: - try: - torch.save( - (self.vocab, self.train, self.valid, self.test), - cache_path, - ) - logging.info(f"Saved dataset to {cache_path}...") - except: - pass - - def _load_from_cache(self): - cache_path = self.data_dir / f"cache.pt" - if cache_path.exists(): - logging.info("Loading cached dataset...") - self.vocab, self.train, self.valid, self.test = torch.load( - cache_path - ) - return True - else: - return False - - def train_dataloader(self, eval=None, **kwargs): - # TODO kwargs absorbs num_workers - return LMOrderedIterator( - self.train, - roll_seed=self.roll_seed, - **kwargs, - ) - - # def val_dataloader(self, batch_size, **kwargs): - def _eval_dataloader(self, dataset, eval=None, **loader_args): - if dataset is None: return None - # Make eval a list of dictionaries - if eval is None: eval = {} - if not utils.is_list(eval): - eval = [eval] - # Each eval setting overrides the train setting - for eval_args in eval: - for k in loader_args: - if eval_args.get(k, None) is None: - eval_args[k] = loader_args[k] - print("eval loader:", eval_args) - loaders = [LMOrderedIterator(dataset, **eval_args) for eval_args in eval] - if len(loaders) == 1: return loaders[0] - return loaders - - def val_dataloader(self, **kwargs): - return self._eval_dataloader(self.valid, **kwargs) - - def test_dataloader(self, **kwargs): - return self._eval_dataloader(self.test, **kwargs) - - -class WikiText103(WikiText2): - _name_ = "wt103" - - def _vocab_count(self): - print(self.data_dir) - self.vocab.count_file(self.data_dir / "train.txt") - - -class PennTreeBank(WikiText2): - - _name_ = "ptb" - vocab_kwargs = {"special": [""], "lower_case": True} - -class EnWik8(WikiText2): - _name_ = "enwik8" - - vocab_kwargs 
= {} - encode_kwargs = {"ordered": True, "add_eos": False} - - -class Text8(EnWik8): - - _name_ = "text8" - - -class LM1B(WikiText2): - # [21-09-08 AG]: this looks very out of date, the __init__ function should be inherited - - _name_ = "lm1b" - vocab_kwargs = {"special": [], "lower_case": False} - cutoffs = [59997, 99997, 639997] - tie_projs = [False] + [False] * len(cutoffs) - - def __init__(self, data_dir, bpe=False, *args, **kwargs): - LightningDataModule.__init__(self) - self.data_dir = Path(data_dir) - # self.vocab_type = vocab - if bpe: - self.vocab = OpenAIVocab() - else: - self.vocab = Vocab( - vocab_file=self.data_dir / "1b_word_vocab.txt", - **self.vocab_kwargs, - ) - - def setup(self, stage=None): - if not self._load_from_cache(): - logging.info(f"Producing dataset {self._name_}...") - # the vocab will load from file when build_vocab() is called - self.vocab.build_vocab() - train_paths = list( - ( - self.data_dir - / "1-billion-word-language-modeling-benchmark-r13output" - / "training-monolingual.tokenized.shuffled" - ).glob("news.en-*") - ) - self.train = train_paths - self.valid = self.vocab.encode_file( - str(self.data_dir / "valid.txt"), - ordered=False, - add_double_eos=True, - ) - self.test = self.vocab.encode_file( - str(self.data_dir / "test.txt"), - ordered=False, - add_double_eos=True, - ) - self._save_to_cache() - - def train_dataloader(self, *args, **kwargs): - kwargs["shuffle"] = True - return LMMultiFileIterator(self.train, self.vocab, *args, **kwargs) - - def val_dataloader(self, *args, **kwargs): - return LMShuffledIterator(self.valid, *args, **kwargs) - - def test_dataloader(self, *args, **kwargs): - return LMShuffledIterator(self.test, *args, **kwargs) diff --git a/src/clm/src/dataloaders/lra.py b/src/clm/src/dataloaders/lra.py deleted file mode 100644 index 624129f1..00000000 --- a/src/clm/src/dataloaders/lra.py +++ /dev/null @@ -1,689 +0,0 @@ -"""Long Range Arena datasets""" -import io -import logging -import os -import pickle -from pathlib import Path - -import torch -from torch import nn -import torch.nn.functional as F -import torchtext -import torchvision -from einops.layers.torch import Rearrange, Reduce -from PIL import Image # Only used for Pathfinder -from datasets import DatasetDict, Value, load_dataset - -from clm.src.dataloaders.base import default_data_path, SequenceDataset, ImageResolutionSequenceDataset - - -class IMDB(SequenceDataset): - _name_ = "imdb" - d_output = 2 - l_output = 0 - - @property - def init_defaults(self): - return { - "l_max": 4096, - "level": "char", - "min_freq": 15, - "seed": 42, - "val_split": 0.0, - "append_bos": False, - "append_eos": True, - # 'max_vocab': 135, - "n_workers": 4, # Only used for tokenizing dataset before caching - } - - @property - def n_tokens(self): - return len(self.vocab) - - def prepare_data(self): - if self.cache_dir is None: # Just download the dataset - load_dataset(self._name_, cache_dir=self.data_dir) - else: # Process the dataset and save it - self.process_dataset() - - def setup(self, stage=None): - """If cache_dir is not None, we'll cache the processed dataset there.""" - self.data_dir = self.data_dir or default_data_path / self._name_ - self.cache_dir = self.data_dir / "cache" - assert self.level in [ - "word", - "char", - ], f"level {self.level} not supported" - - if stage == "test" and hasattr(self, "dataset_test"): - return - dataset, self.tokenizer, self.vocab = self.process_dataset() - print( - f"IMDB {self.level} level | min_freq {self.min_freq} | vocab size {len(self.vocab)}" - ) - 
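# Condensed sketch of the char-level tokenize -> build_vocab -> numericalize
# pattern these LRA datamodules apply via dataset.map. Assumes a recent
# torchtext API and placeholder special-token names; the real pipeline also
# handles the optional bos/eos flags, min_freq=15, and HuggingFace dataset caching.
import torchtext

texts = ["great movie", "awful movie"]
tokenizer = list                                    # char level: a string becomes a list of chars
tokens = [tokenizer(t) for t in texts]

vocab = torchtext.vocab.build_vocab_from_iterator(
    tokens, min_freq=1, specials=["<pad>", "<unk>", "<eos>"]
)
vocab.set_default_index(vocab["<unk>"])             # unseen characters map to <unk>

input_ids = [vocab(t + ["<eos>"]) for t in tokens]  # append eos, then look up ids
print(len(vocab), input_ids[0][:5])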
dataset.set_format(type="torch", columns=["input_ids", "label"]) - - # Create all splits - dataset_train, self.dataset_test = dataset["train"], dataset["test"] - if self.val_split == 0.0: - # Use test set as val set, as done in the LRA paper - self.dataset_train, self.dataset_val = dataset_train, None - else: - train_val = dataset_train.train_test_split( - test_size=self.val_split, seed=self.seed - ) - self.dataset_train, self.dataset_val = ( - train_val["train"], - train_val["test"], - ) - - def _collate_fn(self, batch): - xs, ys = zip(*[(data["input_ids"], data["label"]) for data in batch]) - lengths = torch.tensor([len(x) for x in xs]) - xs = nn.utils.rnn.pad_sequence( - xs, padding_value=self.vocab[""], batch_first=True - ) - ys = torch.tensor(ys) - return xs, ys, {"lengths": lengths} - - # self._collate_fn = collate_batch - - def process_dataset(self): - cache_dir = ( - None if self.cache_dir is None else self.cache_dir / self._cache_dir_name - ) - if cache_dir is not None: - if cache_dir.is_dir(): - return self._load_from_cache(cache_dir) - - dataset = load_dataset(self._name_, cache_dir=self.data_dir) - dataset = DatasetDict(train=dataset["train"], test=dataset["test"]) - if self.level == "word": - tokenizer = torchtext.data.utils.get_tokenizer( - "spacy", language="en_core_web_sm" - ) - else: # self.level == 'char' - tokenizer = list # Just convert a string to a list of chars - # Account for and tokens - l_max = self.l_max - int(self.append_bos) - int(self.append_eos) - tokenize = lambda example: {"tokens": tokenizer(example["text"])[:l_max]} - dataset = dataset.map( - tokenize, - remove_columns=["text"], - keep_in_memory=True, - load_from_cache_file=False, - num_proc=max(self.n_workers, 1), - ) - vocab = torchtext.vocab.build_vocab_from_iterator( - dataset["train"]["tokens"], - min_freq=self.min_freq, - specials=( - ["", ""] - + ([""] if self.append_bos else []) - + ([""] if self.append_eos else []) - ), - ) - vocab.set_default_index(vocab[""]) - - numericalize = lambda example: { - "input_ids": vocab( - ([""] if self.append_bos else []) - + example["tokens"] - + ([""] if self.append_eos else []) - ) - } - dataset = dataset.map( - numericalize, - remove_columns=["tokens"], - keep_in_memory=True, - load_from_cache_file=False, - num_proc=max(self.n_workers, 1), - ) - - if cache_dir is not None: - self._save_to_cache(dataset, tokenizer, vocab, cache_dir) - return dataset, tokenizer, vocab - - def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir): - cache_dir = self.cache_dir / self._cache_dir_name - logger = logging.getLogger(__name__) - logger.info(f"Saving to cache at {str(cache_dir)}") - dataset.save_to_disk(str(cache_dir)) - with open(cache_dir / "tokenizer.pkl", "wb") as f: - pickle.dump(tokenizer, f) - with open(cache_dir / "vocab.pkl", "wb") as f: - pickle.dump(vocab, f) - - def _load_from_cache(self, cache_dir): - assert cache_dir.is_dir() - logger = logging.getLogger(__name__) - logger.info(f"Load from cache at {str(cache_dir)}") - dataset = DatasetDict.load_from_disk(str(cache_dir)) - with open(cache_dir / "tokenizer.pkl", "rb") as f: - tokenizer = pickle.load(f) - with open(cache_dir / "vocab.pkl", "rb") as f: - vocab = pickle.load(f) - return dataset, tokenizer, vocab - - @property - def _cache_dir_name(self): - return f"l_max-{self.l_max}-level-{self.level}-min_freq-{self.min_freq}-append_bos-{self.append_bos}-append_eos-{self.append_eos}" - -class TabularDataset(torch.utils.data.Dataset): - def __init__( - self, - path, - format, - col_idx=None, - 
skip_header=False, - csv_reader_params=None, - ): - """ - col_idx: the indices of the columns. - """ - if csv_reader_params is None: - csv_reader_params = {} - format = format.lower() - assert format in ["tsv", "csv"] - with io.open(os.path.expanduser(path), encoding="utf8") as f: - if format == "csv": - reader = torchtext.utils.unicode_csv_reader(f, **csv_reader_params) - elif format == "tsv": - reader = torchtext.utils.unicode_csv_reader( - f, delimiter="\t", **csv_reader_params - ) - else: - reader = f - if skip_header: - next(reader) - self._data = [ - line if col_idx is None else [line[c] for c in col_idx] - for line in reader - ] - - def __len__(self): - return len(self._data) - - def __getitem__(self, idx): - return self._data[idx] - - -# LRA tokenizer renames ']' to 'X' and delete parentheses as their tokenizer removes -# non-alphanumeric characters. -# https://github.com/google-research/long-range-arena/blob/264227cbf9591e39dd596d2dc935297a2070bdfe/lra_benchmarks/listops/input_pipeline.py#L46 -def listops_tokenizer(s): - return s.translate({ord("]"): ord("X"), ord("("): None, ord(")"): None}).split() - - -class ListOps(SequenceDataset): - _name_ = "listops" - d_output = 10 - l_output = 0 - - @property - def init_defaults(self): - return { - "l_max": 2048, - "append_bos": False, - "append_eos": True, - # 'max_vocab': 20, # Actual size 18 - "n_workers": 4, # Only used for tokenizing dataset - } - - @property - def n_tokens(self): - return len(self.vocab) - - @property - def _cache_dir_name(self): - return f"l_max-{self.l_max}-append_bos-{self.append_bos}-append_eos-{self.append_eos}" - - def init(self): - if self.data_dir is None: - self.data_dir = default_data_path / self._name_ - self.cache_dir = self.data_dir / self._cache_dir_name - - def prepare_data(self): - if self.cache_dir is None: - for split in ["train", "val", "test"]: - split_path = self.data_dir / f"basic_{split}.tsv" - if not split_path.is_file(): - raise FileNotFoundError( - f""" - File {str(split_path)} not found. - To get the dataset, download lra_release.gz from - https://github.com/google-research/long-range-arena, - then unzip it with tar -xvf lra_release.gz. - Then point data_dir to the listops-1000 directory. 
- """ - ) - else: # Process the dataset and save it - self.process_dataset() - - def setup(self, stage=None): - if stage == "test" and hasattr(self, "dataset_test"): - return - dataset, self.tokenizer, self.vocab = self.process_dataset() - self.vocab_size = len(self.vocab) - dataset.set_format(type="torch", columns=["input_ids", "Target"]) - self.dataset_train, self.dataset_val, self.dataset_test = ( - dataset["train"], - dataset["val"], - dataset["test"], - ) - - def collate_batch(batch): - xs, ys = zip(*[(data["input_ids"], data["Target"]) for data in batch]) - lengths = torch.tensor([len(x) for x in xs]) - xs = nn.utils.rnn.pad_sequence( - xs, padding_value=self.vocab[""], batch_first=True - ) - ys = torch.tensor(ys) - return xs, ys, {"lengths": lengths} - - self._collate_fn = collate_batch - - def process_dataset(self): - cache_dir = ( - None if self.cache_dir is None else self.cache_dir / self._cache_dir_name - ) - if cache_dir is not None: - if cache_dir.is_dir(): - return self._load_from_cache(cache_dir) - - dataset = load_dataset( - "csv", - data_files={ - "train": str(self.data_dir / "basic_train.tsv"), - "val": str(self.data_dir / "basic_val.tsv"), - "test": str(self.data_dir / "basic_test.tsv"), - }, - delimiter="\t", - keep_in_memory=True, - ) - - tokenizer = listops_tokenizer - - # Account for and tokens - l_max = self.l_max - int(self.append_bos) - int(self.append_eos) - tokenize = lambda example: {"tokens": tokenizer(example["Source"])[:l_max]} - dataset = dataset.map( - tokenize, - remove_columns=["Source"], - keep_in_memory=True, - load_from_cache_file=False, - num_proc=max(self.n_workers, 1), - ) - vocab = torchtext.vocab.build_vocab_from_iterator( - dataset["train"]["tokens"], - specials=( - ["", ""] - + ([""] if self.append_bos else []) - + ([""] if self.append_eos else []) - ), - ) - vocab.set_default_index(vocab[""]) - - numericalize = lambda example: { - "input_ids": vocab( - ([""] if self.append_bos else []) - + example["tokens"] - + ([""] if self.append_eos else []) - ) - } - dataset = dataset.map( - numericalize, - remove_columns=["tokens"], - keep_in_memory=True, - load_from_cache_file=False, - num_proc=max(self.n_workers, 1), - ) - - if cache_dir is not None: - self._save_to_cache(dataset, tokenizer, vocab, cache_dir) - return dataset, tokenizer, vocab - - def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir): - cache_dir = self.cache_dir / self._cache_dir_name - logger = logging.getLogger(__name__) - logger.info(f"Saving to cache at {str(cache_dir)}") - dataset.save_to_disk(str(cache_dir)) - with open(cache_dir / "tokenizer.pkl", "wb") as f: - pickle.dump(tokenizer, f) - with open(cache_dir / "vocab.pkl", "wb") as f: - pickle.dump(vocab, f) - - def _load_from_cache(self, cache_dir): - assert cache_dir.is_dir() - logger = logging.getLogger(__name__) - logger.info(f"Load from cache at {str(cache_dir)}") - dataset = DatasetDict.load_from_disk(str(cache_dir)) - with open(cache_dir / "tokenizer.pkl", "rb") as f: - tokenizer = pickle.load(f) - with open(cache_dir / "vocab.pkl", "rb") as f: - vocab = pickle.load(f) - return dataset, tokenizer, vocab - -class PathFinderDataset(torch.utils.data.Dataset): - """Path Finder dataset.""" - - # There's an empty file in the dataset - blacklist = {"pathfinder32/curv_baseline/imgs/0/sample_172.png"} - - def __init__(self, data_dir, transform=None): - """ - Args: - data_dir (string): Directory with all the images. - transform (callable, optional): Optional transform to be applied - on a sample. 
- """ - self.data_dir = Path(data_dir).expanduser() - assert self.data_dir.is_dir(), f"data_dir {str(self.data_dir)} does not exist" - self.transform = transform - samples = [] - # for diff_level in ['curv_baseline', 'curv_contour_length_9', 'curv_contour_length_14']: - for diff_level in ["curv_contour_length_14"]: - path_list = sorted( - list((self.data_dir / diff_level / "metadata").glob("*.npy")), - key=lambda path: int(path.stem), - ) - assert path_list, "No metadata found" - for metadata_file in path_list: - with open(metadata_file, "r") as f: - for metadata in f.read().splitlines(): - metadata = metadata.split() - image_path = Path(diff_level) / metadata[0] / metadata[1] - if ( - str(Path(self.data_dir.stem) / image_path) - not in self.blacklist - ): - label = int(metadata[3]) - samples.append((image_path, label)) - self.samples = samples - - def __len__(self): - return len(self.samples) - - def __getitem__(self, idx): - path, target = self.samples[idx] - # https://github.com/pytorch/vision/blob/9b29f3f22783112406d9c1a6db47165a297c3942/torchvision/datasets/folder.py#L247 - with open(self.data_dir / path, "rb") as f: - sample = Image.open(f).convert("L") # Open in grayscale - if self.transform is not None: - sample = self.transform(sample) - return sample, target - -class PathFinder(ImageResolutionSequenceDataset): - _name_ = "pathfinder" - d_input = 1 - d_output = 2 - l_output = 0 - - @property - def n_tokens(self): - if self.tokenize: - return 256 - - @property - def init_defaults(self): - return { - "resolution": 32, - "sequential": True, - "tokenize": False, - "center": True, - "pool": 1, - "val_split": 0.1, - "test_split": 0.1, - "seed": 42, # Controls the train/val/test split - } - - def default_transforms(self): - transform_list = [torchvision.transforms.ToTensor()] - if self.pool > 1: - transform_list.append( - Reduce( - "1 (h h2) (w w2) -> 1 h w", - "mean", - h2=self.pool, - w2=self.pool, - ) - ) - if self.tokenize: - transform_list.append( - torchvision.transforms.Lambda(lambda x: (x * 255).long()) - ) - else: - if self.center: - transform_list.append(torchvision.transforms.Normalize(mean=0.5, std=0.5)) - if self.sequential: - # If tokenize, it makes more sense to get rid of the channel dimension - transform_list.append( - Rearrange("1 h w -> (h w)") - if self.tokenize - else Rearrange("1 h w -> (h w) 1") - ) - else: - transform_list.append(Rearrange("1 h w -> h w 1")) - return torchvision.transforms.Compose(transform_list) - - def prepare_data(self): - if not self.data_dir.is_dir(): - raise FileNotFoundError( - f""" - Directory {str(self.data_dir)} not found. - To get the dataset, download lra_release.gz from - https://github.com/google-research/long-range-arena, - then unzip it with tar -xvf lra_release.gz. - Then point data_dir to the pathfinderX directory, where X is either 32, 64, 128, or 256. - """ - ) - - def setup(self, stage=None): - if self.data_dir is None: - self.data_dir = ( - default_data_path / self._name_ / f"pathfinder{self.resolution}" - ) - - if stage == "test" and hasattr(self, "dataset_test"): - return - # [2021-08-18] TD: I ran into RuntimeError: Too many open files. 
- # https://github.com/pytorch/pytorch/issues/11201 - # torch.multiprocessing.set_sharing_strategy("file_system") - dataset = PathFinderDataset(self.data_dir, transform=self.default_transforms()) - len_dataset = len(dataset) - val_len = int(self.val_split * len_dataset) - test_len = int(self.test_split * len_dataset) - train_len = len_dataset - val_len - test_len - ( - self.dataset_train, - self.dataset_val, - self.dataset_test, - ) = torch.utils.data.random_split( - dataset, - [train_len, val_len, test_len], - generator=torch.Generator().manual_seed(self.seed), - ) - -class AAN(SequenceDataset): - _name_ = "aan" - d_output = 2 # Use accuracy instead of binary_accuracy - l_output = 0 - - @property - def n_tokens(self): - return len(self.vocab) - - @property - def init_defaults(self): - return { - "l_max": 4000, - # 'max_vocab': 100, # Full size 98 - "append_bos": False, - "append_eos": True, - "n_workers": 4, # For tokenizing only - } - - @property - def _cache_dir_name(self): - return f"l_max-{self.l_max}-append_bos-{self.append_bos}-append_eos-{self.append_eos}" - - def init(self): - if self.data_dir is None: - self.data_dir = default_data_path / self._name_ - self.cache_dir = self.data_dir / self._cache_dir_name - - def prepare_data(self): - if self.cache_dir is None: - for split in ["train", "eval", "test"]: - split_path = self.data_dir / f"new_aan_pairs.{split}.tsv" - if not split_path.is_file(): - raise FileNotFoundError( - f""" - File {str(split_path)} not found. - To get the dataset, download lra_release.gz from - https://github.com/google-research/long-range-arena, - then unzip it with tar -xvf lra_release.gz. - Then point data_dir to the tsv_data directory. - """ - ) - else: # Process the dataset and save it - self.process_dataset() - - def setup(self, stage=None): - if stage == "test" and hasattr(self, "dataset_test"): - return - - # [2021-08-18] TD: I ran into RuntimeError: Too many open files. 
- # https://github.com/pytorch/pytorch/issues/11201 - # torch.multiprocessing.set_sharing_strategy("file_system") - - dataset, self.tokenizer, self.vocab = self.process_dataset() - # self.vocab_size = len(self.vocab) - print("AAN vocab size:", len(self.vocab)) - - dataset.set_format(type="torch", columns=["input_ids1", "input_ids2", "label"]) - self.dataset_train, self.dataset_val, self.dataset_test = ( - dataset["train"], - dataset["val"], - dataset["test"], - ) - - def collate_batch(batch): - xs1, xs2, ys = zip( - *[ - (data["input_ids1"], data["input_ids2"], data["label"]) - for data in batch - ] - ) - lengths1 = torch.tensor([len(x) for x in xs1]) - lengths2 = torch.tensor([len(x) for x in xs2]) - xs1 = nn.utils.rnn.pad_sequence( - xs1, padding_value=self.vocab[""], batch_first=True - ) - xs2 = nn.utils.rnn.pad_sequence( - xs2, padding_value=self.vocab[""], batch_first=True - ) - # Pad both to same length - # Shape (batch, length) - L = max(xs1.size(1), xs2.size(1)) - xs1 = F.pad(xs1, (0, L-xs1.size(1)), value=self.vocab[""]) - xs2 = F.pad(xs2, (0, L-xs2.size(1)), value=self.vocab[""]) - ys = torch.tensor(ys) - # return xs1, xs2, ys, lengths1, lengths2 - - # Concatenate two batches - xs = torch.cat([xs1, xs2], dim=0) - lengths = torch.cat([lengths1, lengths2], dim=0) - return xs, ys, {"lengths": lengths} - - self._collate_fn = collate_batch - - def process_dataset(self): - cache_dir = ( - None if self.cache_dir is None else self.cache_dir / self._cache_dir_name - ) - if cache_dir is not None: - if cache_dir.is_dir(): - return self._load_from_cache(cache_dir) - - dataset = load_dataset( - "csv", - data_files={ - "train": str(self.data_dir / "new_aan_pairs.train.tsv"), - "val": str(self.data_dir / "new_aan_pairs.eval.tsv"), - "test": str(self.data_dir / "new_aan_pairs.test.tsv"), - }, - delimiter="\t", - column_names=["label", "input1_id", "input2_id", "text1", "text2"], - keep_in_memory=True, - ) # True) - dataset = dataset.remove_columns(["input1_id", "input2_id"]) - new_features = dataset["train"].features.copy() - new_features["label"] = Value("int32") - dataset = dataset.cast(new_features) - - tokenizer = list # Just convert a string to a list of chars - # Account for and tokens - l_max = self.l_max - int(self.append_bos) - int(self.append_eos) - tokenize = lambda example: { - "tokens1": tokenizer(example["text1"])[:l_max], - "tokens2": tokenizer(example["text2"])[:l_max], - } - dataset = dataset.map( - tokenize, - remove_columns=["text1", "text2"], - keep_in_memory=True, - load_from_cache_file=False, - num_proc=max(self.n_workers, 1), - ) - vocab = torchtext.vocab.build_vocab_from_iterator( - dataset["train"]["tokens1"] + dataset["train"]["tokens2"], - specials=( - ["", ""] - + ([""] if self.append_bos else []) - + ([""] if self.append_eos else []) - ), - ) - vocab.set_default_index(vocab[""]) - - encode = lambda text: vocab( - ([""] if self.append_bos else []) - + text - + ([""] if self.append_eos else []) - ) - numericalize = lambda example: { - "input_ids1": encode(example["tokens1"]), - "input_ids2": encode(example["tokens2"]), - } - dataset = dataset.map( - numericalize, - remove_columns=["tokens1", "tokens2"], - keep_in_memory=True, - load_from_cache_file=False, - num_proc=max(self.n_workers, 1), - ) - - if cache_dir is not None: - self._save_to_cache(dataset, tokenizer, vocab, cache_dir) - return dataset, tokenizer, vocab - - def _save_to_cache(self, dataset, tokenizer, vocab, cache_dir): - cache_dir = self.cache_dir / self._cache_dir_name - logger = 
logging.getLogger(__name__) - logger.info(f"Saving to cache at {str(cache_dir)}") - dataset.save_to_disk(str(cache_dir)) - with open(cache_dir / "tokenizer.pkl", "wb") as f: - pickle.dump(tokenizer, f) - with open(cache_dir / "vocab.pkl", "wb") as f: - pickle.dump(vocab, f) - - def _load_from_cache(self, cache_dir): - assert cache_dir.is_dir() - logger = logging.getLogger(__name__) - logger.info(f"Load from cache at {str(cache_dir)}") - dataset = DatasetDict.load_from_disk(str(cache_dir)) - with open(cache_dir / "tokenizer.pkl", "rb") as f: - tokenizer = pickle.load(f) - with open(cache_dir / "vocab.pkl", "rb") as f: - vocab = pickle.load(f) - return dataset, tokenizer, vocab \ No newline at end of file diff --git a/src/clm/src/dataloaders/synthetics.py b/src/clm/src/dataloaders/synthetics.py deleted file mode 100644 index de7e0f0d..00000000 --- a/src/clm/src/dataloaders/synthetics.py +++ /dev/null @@ -1,335 +0,0 @@ -'''Synthetic datasets to test in-context learning ability.''' - -import os -import torch -from torch.utils.data import TensorDataset, Dataset, DataLoader -from typing import Dict -import numpy as np -from tqdm import tqdm -from collections import Counter - -from clm.src.dataloaders.base import SequenceDataset - -class Vocab: - """Custom vocab.""" - def __init__(self, vocab_size: int, special_vocabs: Dict): - # Special tokens hold copy_prefix and noop/pad token etc - assert "copy_prefix" in special_vocabs - self.special_vocabs = special_vocabs - vocab = [str(v) for v in list(range(vocab_size))] - self.non_special_vocab = sorted(list(vocab)) - self.vocab = sorted(list(set(vocab + list(self.special_vocabs.values())))) - self.v2id = {v:i for i,v in enumerate(self.vocab)} - self.vocab_size = len(vocab) - - def get_next_vocab(self, token: str): - """Gets next token excluding special_vocabs.""" - id = (self.get_id(token) + 1) % self.vocab_size - while self.get_vocab(id) in self.special_vocabs: - id = (id + 1) % self.vocab_size - return self.get_vocab(id) - - @property - def copy_prefix(self): - return self.special_vocabs["copy_prefix"] - - @property - def noop(self): - return self.special_vocabs["noop"] - - @property - def special_tokens(self): - return set(self.special_vocabs.values()) - - def get_id(self, token: str): - return self.v2id[token] - - def get_vocab(self, id: int): - return self.vocab[id] - - def __len__(self): - return len(self.vocab) - - -class Tokenizer: - """Custom Tokenizer for our own vocab.""" - def __init__(self, vocab: Vocab): - self.vocab = vocab - - def tokenize(self, text: str, return_tensor=False, mask_input=False): - input_ids = [self.vocab.get_id(t) for t in text.split()] - if self.vocab.get_id(self.vocab.copy_prefix) not in input_ids: - raise ValueError("Input text must contain copy_prefix token.") - copy_prefix_pos = input_ids.index(self.vocab.get_id(self.vocab.copy_prefix)) - labels = input_ids - if mask_input: - # Mask the input tokens for loss but do not mask the copied token - labels = [-100] * (copy_prefix_pos+1) + labels[copy_prefix_pos+1:] - if return_tensor: - input_ids = torch.LongTensor(input_ids) - labels = torch.LongTensor(labels) - return { - "input_ids": input_ids, - "labels": labels, - } - - def decode(self, ids: list): - return " ".join([self.vocab.get_vocab(id) for id in ids]) - -def generate_start_seq(vocab: Vocab, input_seq_len: int, rng: np.random.Generator): - """Generate token sequence up to and including the copy_prefix token.""" - vocab_seq = rng.choice( - vocab.vocab, - input_seq_len, - replace=True, - # Do not generate any 
special tokens - p=[1/(len(vocab)-len(vocab.special_tokens)) if p not in vocab.special_tokens else 0 for p in vocab.vocab]) - vocab_seq = np.append(vocab_seq, vocab.copy_prefix) - return vocab_seq.tolist() - -def generate_induction_head( - vocab: Vocab, - input_seq_len: int, - copy_prefix: str, - induction_len: int, - num_triggers: int, - rng: np.random.Generator, - valid_chars: list = None, -): - """Generate sequence where the copy prefix is inserted into the input - and then the character after the copy prefix is copied at the end. - """ - if valid_chars is not None: - raise NotImplementedError("Valid chars not implemented for induction heads.") - vocab_seq = generate_start_seq(vocab, input_seq_len, rng) - if rng.uniform() < 0.5: - num_triggers = 1 - pos = sorted(rng.integers( - input_seq_len - (1 + induction_len), size=num_triggers - )) - pos_filtered = [] - for i, p in enumerate(pos): - if i == 0: - pos_filtered.append(p) - elif p - pos_filtered[-1] > induction_len: - pos_filtered.append(p) - to_copy = [ - vocab_seq[pos_filtered[0]+1+i] - for i in range(induction_len) - ] - for pos in pos_filtered: - vocab_seq[pos] = copy_prefix - for i in range(induction_len): - vocab_seq[pos+1+i] = to_copy[i] - # if valid_chars is not None and to_copy not in valid_chars: - # vocab_seq[pos+1] = rng.choice(valid_chars) - # to_copy = vocab_seq[pos+1] - vocab_seq = vocab_seq + to_copy - return " ".join(vocab_seq) - -def generate_assoc_recall( - vocab: Vocab, - input_seq_len: int, - num_keys: int, - rng: np.random.Generator, - allow_dot: bool = True, - valid_chars: list = None, -): - """Generate sequence where the input has a sequence of key value pairs - and the copy prefix at the end, and then a key value pair is inserted - after the copy prefix.""" - non_special_vocab_size = len(vocab.non_special_vocab) - keys = vocab.non_special_vocab[:non_special_vocab_size // 2] - values = vocab.non_special_vocab[non_special_vocab_size // 2:] - keys_multi = [ [key] for key in keys ] - for i in range(num_keys-1): - keys_multi = [ key + [key2] for key in keys_multi for key2 in keys ] - kv_map = { - tuple(k): rng.choice(values) for k in keys_multi - } - - key_present = {} - vocab_seq = [] - for _ in range(input_seq_len // (num_keys + 1)): - k = tuple(rng.choice(list(kv_map.keys()))) - v = kv_map[k] - vocab_seq += list(k) + [v] - key_present[k] = True - # vocab_seq.append(v) - - - k = tuple(rng.choice(list(kv_map.keys()))) - if not allow_dot: - while k not in key_present: - k = tuple(rng.choice(list(key_present.keys()))) - to_copy = [vocab.copy_prefix] + list(k) + [ kv_map[k] if k in key_present else vocab.noop ] - vocab_seq = vocab_seq + to_copy - return " ".join(vocab_seq) - -class ICLDataModule(SequenceDataset): - _name_ = "icl_synthetics" - - def __init__( - self, - num_examples: int, - num_test_examples: int, - vocab_size: int, - input_seq_len: int, - copy_method: str, - number_duplicates_per_epoch: int = 0, - seed: int = 0, - batch_size: int = 32, - split_train_test: bool = False, - induction_len: int = 1, - induction_num_triggers: int = 1, - allow_dot: bool = False, - max_copy_len: int = 10, - test_seq_len: int = None, - num_keys: int = 1, # number of keys for associative recall, - data_dir: str = None, - *args, **kwargs - ): - self.num_examples = num_examples - self.num_test_examples = num_test_examples - self.input_seq_len = input_seq_len - self.vocab_size = vocab_size - self.copy_method = copy_method - assert copy_method in ["induction_head", "assoc_recall"] - self.number_duplicates_per_epoch = 
number_duplicates_per_epoch - self.seed = seed - self.batch_size = batch_size - self.split_train_test = split_train_test # let the same copy chars appear in train/test - self.induction_len = induction_len - self.induction_num_triggers = induction_num_triggers - self.allow_dot = allow_dot - self.max_copy_len = max_copy_len - self.data_dir = data_dir - - if test_seq_len is not None: - self.test_seq_len = test_seq_len - else: - self.test_seq_len = input_seq_len - self.num_keys = num_keys - - special_vocabs = { - "copy_prefix": "=>", - "noop": "." - } - self.special_vocabs = special_vocabs - self.vocab = Vocab(vocab_size-len(special_vocabs), special_vocabs=special_vocabs) - self.tokenizer = Tokenizer(self.vocab) - - self.num_extra_seq_len = 2 - - if self.copy_method == "induction_head": - self.copy_f = self.generate_induction_head - self.num_extra_seq_len = 1 + self.induction_len - elif self.copy_method == "assoc_recall": - self.copy_f = self.generate_assoc_recall - self.num_extra_seq_len = 1 + self.num_keys - else: - self.copy_f = None - - if self.number_duplicates_per_epoch > 0: - self.duplicate_ex = self.generate_example() - self.duplicate_index = max(int(self.num_examples / self.number_duplicates_per_epoch), 1) - else: - self.duplicate_ex = None - self.duplicate_index = -1 - - self.total_seq_len = self.input_seq_len + self.num_extra_seq_len - - def generate_induction_head(self, seqlen=None, valid_chars=None): - return generate_induction_head(self.vocab, seqlen if seqlen is not None else self.input_seq_len, self.special_vocabs["copy_prefix"], self.induction_len, self.induction_num_triggers, self.rng, valid_chars=valid_chars) - - def generate_assoc_recall(self, seqlen=None, valid_chars=None): - return generate_assoc_recall(self.vocab, seqlen if seqlen is not None else self.input_seq_len, self.num_keys, self.rng, allow_dot = self.allow_dot, valid_chars=valid_chars) - - def generate_example(self, seqlen=None, valid_chars=None): - vocab_seq = self.copy_f(seqlen=seqlen, valid_chars=valid_chars) - return self.tokenizer.tokenize(vocab_seq, return_tensor=True) - - def setup(self, stage=None): - train_tensor = test_tensor = None - if self.data_dir is not None: - try: - train_tensor = torch.load(os.path.join(self.data_dir, - f"train_{self.copy_method}_{self.num_examples}_{self.vocab_size}_{self.input_seq_len}.pt")) - test_tensor = torch.load(os.path.join(self.data_dir, - f"test_{self.copy_method}_{self.num_examples}_{self.vocab_size}_{self.input_seq_len}.pt")) - except: - pass - - if train_tensor is None or test_tensor is None: - if hasattr(self, 'dataset'): - return - self.rng = np.random.default_rng(self.seed) - - if self.split_train_test: - all_vocab = self.vocab.non_special_vocab - train_vocab = set(self.rng.choice(all_vocab, size=len(all_vocab) // 2, replace=False)) - test_vocab = set(all_vocab) - train_vocab - train_vocab = list(train_vocab) - test_vocab = list(test_vocab) - else: - train_vocab = None - test_vocab = None - - all_examples = [] - for i, (example_count, valid_vocab) in enumerate(zip([self.num_examples, self.num_test_examples], [train_vocab, test_vocab])): - examples = torch.stack([self.generate_example( - seqlen=self.input_seq_len if i == 0 else self.test_seq_len, - valid_chars=valid_vocab - )['input_ids'] for _ in tqdm(range(example_count))]) - examples = torch.unique(examples, dim=0, sorted=False).tolist() - - while len(examples) < example_count: - new_example = self.generate_example( - seqlen=self.input_seq_len if i == 0 else self.test_seq_len, - valid_chars=valid_vocab - 
)['input_ids'].tolist() - if new_example not in examples: - examples.append(new_example) - - self.rng.shuffle(examples) - all_examples.append(torch.LongTensor(examples)) - - # all_examples = torch.concat(all_examples) - train_tensor = torch.stack([torch.stack([example[:-1], example[1:]]) for example in all_examples[0]]) - test_tensor = torch.stack([torch.stack([example[:-1], example[1:]]) for example in all_examples[1]]) - test_tensor[:, 1, :-1 * (self.num_extra_seq_len - 1)] = -100 - if self.copy_method in ["assoc_recall"]: - test_tensor[:, 1, :-1] = -100 - if self.copy_method in ["majority", "fom1"]: - train_tensor[:, 1, :-1 * (self.num_extra_seq_len - 1)] = -100 - - if self.data_dir is not None: - torch.save(train_tensor, os.path.join(self.data_dir, - f"train_{self.copy_method}_{self.num_examples}_{self.vocab_size}_{self.input_seq_len}.pt") - ) - torch.save(test_tensor, os.path.join(self.data_dir, - f"test_{self.copy_method}_{self.num_examples}_{self.vocab_size}_{self.input_seq_len}.pt") - ) - - self.dataset = { - 'train': TensorDataset(train_tensor[:, 0, :], train_tensor[:, 1, :]), - 'test': TensorDataset(test_tensor[:, 0, :], test_tensor[:, 1, :]) - } - - def train_dataloader(self, *args, **kwargs): - return self._data_loader(self.dataset['train'], shuffle=True) - - def val_dataloader(self, *args, **kwargs): - return self._data_loader(self.dataset['test'], shuffle=False) - - def test_dataloader(self, *args, **kwargs): - return self._data_loader(self.dataset['test'], shuffle=False) - - def _data_loader(self, dataset: Dataset, shuffle: bool = False) -> DataLoader: - return DataLoader( - dataset, - batch_size=self.batch_size, - num_workers=10, - shuffle=shuffle, - persistent_workers=True - ) \ No newline at end of file diff --git a/src/clm/src/dataloaders/utils/cifar_augmentations.py b/src/clm/src/dataloaders/utils/cifar_augmentations.py deleted file mode 100644 index 3c063edb..00000000 --- a/src/clm/src/dataloaders/utils/cifar_augmentations.py +++ /dev/null @@ -1,138 +0,0 @@ -""" -Borrowed from https://github.com/hysts/pytorch_image_classification/tree/9ff4248905850c68aa9c09c17914307eb81769e7/pytorch_image_classification/transforms -""" -import torch -import numpy as np -import PIL -import PIL.Image -from PIL.Image import Image - - -class NpNormalize: - def __init__(self, mean: np.ndarray, std: np.ndarray): - self.mean = np.array(mean) - self.std = np.array(std) - - def __call__(self, image: PIL.Image.Image) -> np.ndarray: - image = np.asarray(image).astype(np.float32) / 255. - image = (image - self.mean) / self.std - return image - - -class Cutout(object): - """Randomly mask out one or more patches from an image. - Args: - n_holes (int): Number of patches to cut out of each image. - length (int): The length (in pixels) of each square patch. - """ - - def __init__(self, n_holes, length): - self.n_holes = n_holes - self.length = length - - def __call__(self, img): - """ - Args: - img (Tensor): Tensor image of size (C, H, W). - Returns: - Tensor: Image with n_holes of dimension length x length cut out of it. - """ - h = img.size(1) - w = img.size(2) - - mask = np.ones((h, w), np.float32) - - for n in range(self.n_holes): - y = np.random.randint(h) - x = np.random.randint(w) - - y1 = np.clip(y - self.length // 2, 0, h) - y2 = np.clip(y + self.length // 2, 0, h) - x1 = np.clip(x - self.length // 2, 0, w) - x2 = np.clip(x + self.length // 2, 0, w) - - mask[y1: y2, x1: x2] = 0. 
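# Note on the -100 fill used by the synthetic-data setup above: PyTorch's
# cross-entropy ignores targets equal to ignore_index (default -100), so only
# the positions the model must copy/recall contribute to the loss. Toy
# illustration with made-up numbers, independent of the classes above.
import torch
import torch.nn.functional as F

example = torch.tensor([5, 2, 7, 2, 9])              # token ids for one sequence
inputs, targets = example[:-1], example[1:].clone()  # next-token (input, target) pair
targets[:-1] = -100                                  # score only the final, recalled token

logits = torch.randn(inputs.size(0), 16)             # fake model output over a 16-token vocab
loss = F.cross_entropy(logits, targets)              # -100 positions are skipped
print(loss)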
- - mask = torch.from_numpy(mask) - mask = mask.expand_as(img) - img = img * mask - - return img - - -# -# class Cutout: -# def __init__(self, p=1.0, mask_size=16, cutout_inside=False, mask_color=0): -# # https://github.com/hysts/pytorch_image_classification/blob/9ff4248905850c68aa9c09c17914307eb81769e7/configs/augmentations/cifar/cutout.yaml -# self.p = p -# self.mask_size = mask_size -# self.cutout_inside = cutout_inside -# self.mask_color = mask_color -# -# self.mask_size_half = self.mask_size // 2 -# self.offset = 1 if self.mask_size % 2 == 0 else 0 -# -# def __call__(self, image: np.ndarray) -> np.ndarray: -# image = np.asarray(image).copy() -# -# if np.random.random() > self.p: -# return image -# -# h, w = image.shape[:2] -# -# if self.cutout_inside: -# cxmin = self.mask_size_half -# cxmax = w + self.offset - self.mask_size_half -# cymin = self.mask_size_half -# cymax = h + self.offset - self.mask_size_half -# else: -# cxmin, cxmax = 0, w + self.offset -# cymin, cymax = 0, h + self.offset -# -# cx = np.random.randint(cxmin, cxmax) -# cy = np.random.randint(cymin, cymax) -# xmin = cx - self.mask_size_half -# ymin = cy - self.mask_size_half -# xmax = xmin + self.mask_size -# ymax = ymin + self.mask_size -# xmin = max(0, xmin) -# ymin = max(0, ymin) -# xmax = min(w, xmax) -# ymax = min(h, ymax) -# image[ymin:ymax, xmin:xmax] = self.mask_color -# return image - - -class RandomErasing: - def __init__(self, p=0.5, max_attempt=20, sl=0.02, sh=0.4, rl=0.3, rh=1. / 0.3): - # https://github.com/hysts/pytorch_image_classification/blob/9ff4248905850c68aa9c09c17914307eb81769e7/configs/augmentations/cifar/random_erasing.yaml - self.p = 0.5 - self.max_attempt = 20 - self.sl, self.sh = 0.02, 0.4 - self.rl = 0.3 - self.rh = 1. / 0.3 - - def __call__(self, image: np.ndarray) -> np.ndarray: - image = np.asarray(image).copy() - - if np.random.random() > self.p: - return image - - h, w = image.shape[:2] - image_area = h * w - - for _ in range(self.max_attempt): - mask_area = np.random.uniform(self.sl, self.sh) * image_area - aspect_ratio = np.random.uniform(self.rl, self.rh) - mask_h = int(np.sqrt(mask_area * aspect_ratio)) - mask_w = int(np.sqrt(mask_area / aspect_ratio)) - - if mask_w < w and mask_h < h: - x0 = np.random.randint(0, w - mask_w) - y0 = np.random.randint(0, h - mask_h) - x1 = x0 + mask_w - y1 = y0 + mask_h - image[y0:y1, x0:x1] = np.random.uniform(0, 1) - break - - return image diff --git a/src/clm/src/dataloaders/utils/timm_mixup.py b/src/clm/src/dataloaders/utils/timm_mixup.py deleted file mode 100644 index 333a9c65..00000000 --- a/src/clm/src/dataloaders/utils/timm_mixup.py +++ /dev/null @@ -1,22 +0,0 @@ -import torch - -from timm.data import Mixup -from timm.data.mixup import mixup_target - - -class TimmMixup(Mixup): - """ Wrap timm.data.Mixup that avoids the assert that batch size must be even. 
- """ - def __call__(self, x, target, *args): - if self.mode == 'elem': - lam = self._mix_elem(x) - elif self.mode == 'pair': - # We move the assert from the beginning of the function to here - assert len(x) % 2 == 0, 'Batch size should be even when using this' - lam = self._mix_pair(x) - else: - lam = self._mix_batch(x) - # Another change is to set the right device here - target = mixup_target(target, self.num_classes, lam, self.label_smoothing, - device=target.device) - return x, target, *args \ No newline at end of file diff --git a/src/clm/src/dataloaders/utils/vocabulary.py b/src/clm/src/dataloaders/utils/vocabulary.py deleted file mode 100644 index bdb98936..00000000 --- a/src/clm/src/dataloaders/utils/vocabulary.py +++ /dev/null @@ -1,237 +0,0 @@ -# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import os -from collections import Counter -from collections import OrderedDict - -import torch - -import clm.src.utils as utils - - -class Vocab(object): - def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True, - delimiter=None, vocab_file=None): - self.counter = Counter() - self.special = special - self.min_freq = min_freq - self.max_size = max_size - self.lower_case = lower_case - self.delimiter = delimiter - self.vocab_file = vocab_file - - def tokenize(self, line, add_eos=False, add_double_eos=False): - line = line.strip() - # convert to lower case - if self.lower_case: - line = line.lower() - - # empty delimiter '' will evaluate False - if self.delimiter == '': - symbols = line - else: - symbols = line.split(self.delimiter) - - if add_double_eos: # lm1b - return [''] + symbols + [''] - elif add_eos: - return symbols + [''] - else: - return symbols - - def count_file(self, path, verbose=False, add_eos=False): - if verbose: - print('counting file {} ...'.format(path)) - assert os.path.exists(path) - - sents = [] - with open(path, 'r', encoding='utf-8') as f: - for idx, line in enumerate(f): - if verbose and idx > 0 and idx % 500000 == 0: - print(' line {}'.format(idx)) - symbols = self.tokenize(line, add_eos=add_eos) - self.counter.update(symbols) - sents.append(symbols) - - return sents - - def count_sents(self, sents, verbose=False): - """ - sents : a list of sentences, each a list of tokenized symbols - """ - if verbose: - print('counting {} sents ...'.format(len(sents))) - for idx, symbols in enumerate(sents): - if verbose and idx > 0 and idx % 500000 == 0: - print(' line {}'.format(idx)) - self.counter.update(symbols) - - def _build_from_file(self, vocab_file): - self.idx2sym = [] - self.sym2idx = OrderedDict() - - with open(vocab_file, 'r', encoding='utf-8') as f: - for line in f: - symb = line.strip().split()[0] - self.add_symbol(symb) - self.unk_idx = self.sym2idx[''] - - def build_vocab(self): - if self.vocab_file: - print('building vocab from {}'.format(self.vocab_file)) - self._build_from_file(self.vocab_file) - print('final vocab size {}'.format(len(self))) - else: - 
print('building vocab with min_freq={}, max_size={}'.format( - self.min_freq, self.max_size)) - self.idx2sym = [] - self.sym2idx = OrderedDict() - - for sym in self.special: - self.add_special(sym) - - for sym, cnt in self.counter.most_common(self.max_size): - if cnt < self.min_freq: - break - self.add_symbol(sym) - - print('final vocab size {} from {} unique tokens'.format( - len(self), len(self.counter))) - - def encode_file(self, path, ordered=False, verbose=False, add_eos=True, - add_double_eos=False): - if verbose: - print('encoding file {} ...'.format(path)) - assert os.path.exists(path) - encoded = [] - with open(path, 'r', encoding='utf-8') as f: - for idx, line in enumerate(f): - if verbose and idx > 0 and idx % 500000 == 0: - print(' line {}'.format(idx)) - symbols = self.tokenize(line, add_eos=add_eos, - add_double_eos=add_double_eos) - encoded.append(self.convert_to_tensor(symbols)) - - if ordered: - encoded = torch.cat(encoded) - - return encoded - - def encode_sents(self, sents, ordered=False, verbose=False): - if verbose: - print('encoding {} sents ...'.format(len(sents))) - encoded = [] - for idx, symbols in enumerate(sents): - if verbose and idx > 0 and idx % 500000 == 0: - print(' line {}'.format(idx)) - encoded.append(self.convert_to_tensor(symbols)) - - if ordered: - encoded = torch.cat(encoded) - - return encoded - - def add_special(self, sym): - if sym not in self.sym2idx: - self.idx2sym.append(sym) - self.sym2idx[sym] = len(self.idx2sym) - 1 - setattr(self, '{}_idx'.format(sym.strip('<>')), self.sym2idx[sym]) - - def add_symbol(self, sym): - if sym not in self.sym2idx: - self.idx2sym.append(sym) - self.sym2idx[sym] = len(self.idx2sym) - 1 - - def get_sym(self, idx): - assert 0 <= idx < len(self), 'Index {} out of range'.format(idx) - return self.idx2sym[idx] - - def get_idx(self, sym): - if sym in self.sym2idx: - return self.sym2idx[sym] - else: - # print('encounter unk {}'.format(sym)) - assert '' not in sym - assert hasattr(self, 'unk_idx') - return self.sym2idx.get(sym, self.unk_idx) - - def get_symbols(self, indices): - return [self.get_sym(idx) for idx in indices] - - def get_indices(self, symbols): - return [self.get_idx(sym) for sym in symbols] - - def convert_to_tensor(self, symbols): - return torch.LongTensor(self.get_indices(symbols)) - - def convert_to_sent(self, indices, exclude=None): - if exclude is None: - return ' '.join([self.get_sym(idx) for idx in indices]) - else: - return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude]) - - def __len__(self): - return len(self.idx2sym) - - -# Class OpenAIVocab has been adapted from -# https://github.com/cybertronai/transformer-xl/blob/master/utils/vocabulary.py -class OpenAIVocab(Vocab): - def __init__(self, max_size=None, vocab_file=None): - from transformers import GPT2Tokenizer - self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2') - self.EOT = self.tokenizer.encoder['<|endoftext|>'] - self.max_size = max_size - self.vocab_file = vocab_file - - pad = 8 - vocab_size = len(self.tokenizer) - padded_vocab_size = (vocab_size + pad - 1) // pad * pad - for i in range(0, padded_vocab_size - vocab_size): - token = f'madeupword{i:09d}' - self.tokenizer.add_tokens([token]) - - def __len__(self): - return len(self.tokenizer) - - def count_file(self, path, verbose=False, add_eos=False): - # TODO: train from scratch, respect self.max_size - pass - - def build_vocab(self): - pass - - def encode_file(self, path, ordered=False, verbose=False, add_eos=True, add_double_eos=False) -> torch.LongTensor: - 
cached = path + '.bpe' - if os.path.exists(cached): - return torch.load(cached) - print(f'encoding file {path} ...') - assert os.path.exists(path), f"{path} doesn't exist" - - with open(path, encoding='utf-8') as f: - # Suppress warnings about length. - with open(os.devnull, "w") as devnull, contextlib.redirect_stderr(devnull): - out = torch.LongTensor(self.tokenizer.encode(f.read()) + [self.EOT]) - with utils.distributed.sync_workers() as rank: - if rank == 0: - torch.save(out, cached) - return out - - def tokenize(self, line, add_eos=False, add_double_eos=False): - return self.tokenizer.encode(line) - - def convert_to_tensor(self, symbols): - return torch.LongTensor(symbols) diff --git a/src/clm/src/dataloaders/vision.py b/src/clm/src/dataloaders/vision.py deleted file mode 100644 index ac44763b..00000000 --- a/src/clm/src/dataloaders/vision.py +++ /dev/null @@ -1,279 +0,0 @@ -"""Miscellaneous vision datasets.""" - -import os - -import torch -from torch import nn -from torch.nn import functional as F -import torchvision - -from clm.src.dataloaders.base import default_data_path, SequenceDataset - - -class ImageNet(SequenceDataset): - """ - .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/ - Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png - :width: 400 - :alt: Imagenet - Specs: - - 1000 classes - - Each image is (3 x varies x varies) (here we default to 3 x 224 x 224) - Imagenet train, val and test dataloaders. - The train set is the imagenet train. - The val split is taken from train if a val_split % is provided, or will be the same as test otherwise - The test set is the official imagenet validation set. - - """ - - _name_ = "imagenet" - d_input = 3 - d_output = 1000 - l_output = 0 - - init_defaults = { - "data_dir": None, - "cache_dir": None, - "image_size": 224, - "val_split": None, # currently not implemented - "train_transforms": None, - "val_transforms": None, - "test_transforms": None, - "mixup": None, # augmentation - "num_aug_repeats": 0, - "num_gpus": 1, - "shuffle": True, # for train - "loader_fft": False, - } - - @property - def num_classes(self) -> int: - """ - Return: - 1000 - """ - return 1000 - - def _verify_splits(self, data_dir: str, split: str) -> None: - dirs = os.listdir(data_dir) - - if split not in dirs: - raise FileNotFoundError( - f"a {split} Imagenet split was not found in {data_dir}," - f" make sure the folder contains a subfolder named {split}" - ) - - def prepare_data(self) -> None: - """This method already assumes you have imagenet2012 downloaded. It validates the data using the meta.bin. - .. warning:: Please download imagenet on your own first. 
- """ - if not self.use_archive_dataset: - self._verify_splits(self.data_dir, "train") - self._verify_splits(self.data_dir, "val") - else: - if not self.data_dir.is_file(): - raise FileNotFoundError(f"""Archive file {str(self.data_dir)} not found.""") - - def setup(self, stage=None): - """Creates train, val, and test dataset.""" - - from typing import Any, Callable, List, Optional, Union - - import hydra # for mixup - from pl_bolts.transforms.dataset_normalizations import \ - imagenet_normalization - from torch.utils.data import Dataset - from torch.utils.data.dataloader import default_collate - from torchvision.datasets import ImageFolder - - # for access in other methods - self.imagenet_normalization = imagenet_normalization - self.default_collate = default_collate - self.hydra = hydra - self.ImageFolder = ImageFolder - - if self.mixup is not None: - self.mixup_fn = hydra.utils.instantiate(self.mixup) - else: - self.mixup_fn = None - - self.dir_path = self.data_dir or default_data_path / self._name_ - - if stage == "fit" or stage is None: - self.set_phase([self.image_size]) - - if stage == "test" or stage is None: - test_transforms = (self.val_transform() if self.test_transforms is None - else hydra.utils.instantiate(self.test_transforms)) - - self.dataset_test = ImageFolder(os.path.join(self.dir_path, 'val'), transform=test_transforms) - - # # modded, override (for debugging) - # self.dataset_test = self.dataset_val - - def set_phase(self, stage_params=[224], val_upsample=False, test_upsample=False): - """ - For progresive learning. - Will modify train transform parameters during training, just image size for now, - and create a new train dataset, which the train_dataloader will load every - n epochs (in config). - - Later, will be possible to change magnitude of RandAug here too, and mixup alpha - - stage_params: list, list of values to change. 
single [image_size] for now - """ - - img_size = int(stage_params[0]) - - if val_upsample: - self.val_transforms["input_size"] = img_size - - train_transforms = (self.train_transform() if self.train_transforms is None - else self.hydra.utils.instantiate(self.train_transforms)) - val_transforms = (self.val_transform() if self.val_transforms is None - else self.hydra.utils.instantiate(self.val_transforms)) - - if self.loader_fft: - train_transforms = torchvision.transforms.Compose( - train_transforms.transforms + [ - torchvision.transforms.Lambda(lambda x: torch.fft.rfftn(x, s=tuple([2*l for l in x.shape[1:]]))) - ] - ) - val_transforms = torchvision.transforms.Compose( - val_transforms.transforms + [ - torchvision.transforms.Lambda(lambda x: torch.fft.rfftn(x, s=tuple([2*l for l in x.shape[1:]]))) - ] - ) - - self.dataset_train = self.ImageFolder(self.dir_path / 'train', - transform=train_transforms) - - if self.val_split > 0.: - # this will create the val split - self.split_train_val(self.val_split) - # will use the test split as val by default - else: - self.dataset_val = self.ImageFolder(self.dir_path / 'val', transform=val_transforms) - - # # modded, override (for debugging) - # self.dataset_train = self.dataset_val - - # not sure if normally you upsample test also - if test_upsample: - self.test_transforms["input_size"] = img_size - test_transforms = (self.val_transform() if self.test_transforms is None - else self.hydra.utils.instantiate(self.test_transforms)) - self.dataset_test = self.ImageFolder(os.path.join(self.dir_path, 'val'), transform=test_transforms) - ## modded, override (for debugging) - # self.dataset_test = self.dataset_val - - # could modify mixup by reinstantiating self.mixup_fn (later maybe) - - def train_transform(self): - """The standard imagenet transforms. - .. code-block:: python - transforms.Compose([ - transforms.RandomResizedCrop(self.image_size), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225] - ), - ]) - """ - preprocessing = torchvision.transforms.Compose( - [ - torchvision.transforms.RandomResizedCrop(self.image_size), - torchvision.transforms.RandomHorizontalFlip(), - torchvision.transforms.ToTensor(), - self.imagenet_normalization(), - ] - ) - - return preprocessing - - def val_transform(self): - """The standard imagenet transforms for validation. - .. 
code-block:: python - transforms.Compose([ - transforms.Resize(self.image_size + 32), - transforms.CenterCrop(self.image_size), - transforms.ToTensor(), - transforms.Normalize( - mean=[0.485, 0.456, 0.406], - std=[0.229, 0.224, 0.225] - ), - ]) - """ - - preprocessing = torchvision.transforms.Compose( - [ - torchvision.transforms.Resize(self.image_size + 32), - torchvision.transforms.CenterCrop(self.image_size), - torchvision.transforms.ToTensor(), - self.imagenet_normalization(), - ] - ) - return preprocessing - - def train_dataloader(self, **kwargs): - """ The train dataloader """ - if self.num_aug_repeats == 0 or self.num_gpus == 1: - shuffle = self.shuffle - sampler = None - else: - shuffle = False - from timm.data.distributed_sampler import RepeatAugSampler - sampler = RepeatAugSampler(self.dataset_train, num_repeats=self.num_aug_repeats) - - # calculate resolution - resolution = self.image_size / self.train_transforms['input_size'] # usually 1.0 - - return (self._data_loader(self.dataset_train, shuffle=shuffle, mixup=self.mixup_fn, sampler=sampler, resolution=resolution, **kwargs)) - - def val_dataloader(self, **kwargs): - """ The val dataloader """ - kwargs['drop_last'] = False - - # update batch_size for eval if provided - batch_size = kwargs.get("batch_size_eval", None) or kwargs.get("batch_size") - kwargs["batch_size"] = batch_size - - # calculate resolution - resolution = self.image_size / self.val_transforms['input_size'] # usually 1.0 or 0.583 - - return (self._data_loader(self.dataset_val, resolution=resolution, **kwargs)) - - def test_dataloader(self, **kwargs): - """ The test dataloader """ - kwargs['drop_last'] = False - - # update batch_size for test if provided - batch_size = kwargs.get("batch_size_test", None) or kwargs.get("batch_size_eval", None) or kwargs.get("batch_size") - kwargs["batch_size"] = batch_size - - # calculate resolution - resolution = self.image_size / self.test_transforms.get("input_size", self.val_transforms['input_size']) - - return (self._data_loader(self.dataset_test, resolution=resolution, **kwargs)) - - def _data_loader(self, dataset, resolution, shuffle=False, mixup=None, sampler=None, **kwargs): - # collate_fn = (lambda batch: mixup(*self.default_collate(batch))) if mixup is not None else self.default_collate - collate_fn = (lambda batch: mixup(*self.collate_with_resolution(batch, resolution))) if mixup is not None else lambda batch: self.collate_with_resolution(batch, resolution) - - # hacked - can't pass this this arg to dataloader, but used to update the batch_size val / test - kwargs.pop('batch_size_eval', None) - kwargs.pop('batch_size_test', None) - - return torch.utils.data.DataLoader( - dataset, - collate_fn=collate_fn, - shuffle=shuffle, - sampler=sampler, - **kwargs, - ) - - def collate_with_resolution(self, batch, resolution): - stuff = self.default_collate(batch) - return *stuff, {"resolution": resolution} diff --git a/src/clm/src/models/__init__.py b/src/clm/src/models/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/clm/src/models/baselines/vit_all.py b/src/clm/src/models/baselines/vit_all.py deleted file mode 100644 index d2a18b6d..00000000 --- a/src/clm/src/models/baselines/vit_all.py +++ /dev/null @@ -1,433 +0,0 @@ -""" -The original Vision Transformer (ViT) from timm, copyright belongs to / Copyright 2020 Ross Wightman -""" -import math -import logging - -from functools import partial -from collections import OrderedDict -from copy import deepcopy - -import torch -import torch.nn as nn -import 
torch.nn.functional as F - -from timm.models.helpers import build_model_with_cfg, overlay_external_default_cfg -from timm.models.layers import PatchEmbed, Mlp, trunc_normal_, lecun_normal_ - -from clm.src.models.sequence.base import SequenceModule -from clm.src.models.nn.components import Normalization -from clm.src.models.sequence.block import SequenceResidualBlock -from clm.src.utils.config import to_list, to_dict - -_logger = logging.getLogger(__name__) - - -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, - 'input_size': (3, 224, 224), - 'pool_size': None, - 'classifier': 'head', - **kwargs, - } - - -default_cfgs = { - # patch models (my experiments) - 'vit_small_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', - ), - - # patch models (weights ported from official Google JAX impl) - 'vit_base_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - ), -} - - -class VisionTransformer(SequenceModule): - """ Vision Transformer - A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - - https://arxiv.org/abs/2010.11929 - Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` - - https://arxiv.org/abs/2012.12877 - """ - - def __init__( - self, - img_size=224, - patch_size=16, - in_chans=3, - num_classes=1000, - d_model=768, - depth=12, - expand=4, - representation_size=None, - distilled=False, - dropout=0., - drop_path_rate=0., - embed_layer=PatchEmbed, - norm='layer', - weight_init='', - layer=None, - transposed=False, - layer_reps=1, - use_pos_embed=False, - use_cls_token=False, - track_norms=False, - ): - """ - Args: - img_size (int, tuple): input image size - patch_size (int, tuple): patch size - in_chans (int): number of input channels - num_classes (int): number of classes for classification head - d_model (int): embedding dimension - depth (int): depth of transformer - num_heads (int): number of attention heads - mlp_ratio (int): ratio of mlp hidden dim to embedding dim - qkv_bias (bool): enable bias for qkv if True - qk_scale (float): override default qk scale of head_dim ** -0.5 if set - representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set - distilled (bool): model includes a distillation token and head as in DeiT models - dropout (float): dropout rate - attn_drop_rate (float): attention dropout rate - drop_path_rate (float): stochastic depth rate - embed_layer (nn.Module): patch embedding layer - norm_layer: (nn.Module): normalization layer - weight_init: (str): weight init scheme - """ - super().__init__() - self.num_classes = num_classes - self.num_features = self.d_model = d_model # num_features for consistency with other models - self.num_tokens = 2 if distilled else 1 - self.use_pos_embed = use_pos_embed - self.use_cls_token = use_cls_token - - self.track_norms = track_norms - - self.patch_embed = embed_layer( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=d_model, - ) - num_patches = self.patch_embed.num_patches - - self.cls_token = None - self.dist_token = None - if use_cls_token: - self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model)) - self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model)) if distilled else None - else: - assert not distilled, 'Distillation token 
not supported without class token' - - self.pos_embed = None - if use_pos_embed: - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, d_model)) - self.pos_drop = nn.Dropout(p=dropout) - - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - - self.transposed = transposed - - layer = to_list(layer, recursive=False) * layer_reps - - # Some special arguments are passed into each layer - for _layer in layer: - # If layers don't specify dropout, add it - if _layer.get('dropout', None) is None: - _layer['dropout'] = dropout - # Ensure all layers are shaped the same way - _layer['transposed'] = transposed - - # Config for the inverted bottleneck - ff_cfg = { - '_name_': 'ff', - 'expand': int(expand), - 'transposed': self.transposed, - 'activation': 'gelu', - 'initializer': None, - 'dropout': dropout, - } - - blocks = [] - for i in range(depth): - for _layer in layer: - blocks.append( - SequenceResidualBlock( - d_input=d_model, - i_layer=i, - prenorm=True, - dropout=dropout, - layer=_layer, - residual='R', - norm=norm, - pool=None, - drop_path=dpr[i], - ) - ) - if expand > 0: - blocks.append( - SequenceResidualBlock( - d_input=d_model, - i_layer=i, - prenorm=True, - dropout=dropout, - layer=ff_cfg, - residual='R', - norm=norm, - pool=None, - drop_path=dpr[i], - ) - ) - self.blocks = nn.Sequential(*blocks) - - if norm is None: - self.norm = None - elif isinstance(norm, str): - self.norm = Normalization(d_model, transposed=self.transposed, _name_=norm) - else: - self.norm = Normalization(d_model, transposed=self.transposed, **norm) - - # Representation layer: generally defaults to nn.Identity() - if representation_size and not distilled: - self.num_features = representation_size - self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(d_model, representation_size)), - ('act', nn.Tanh()) - ])) - else: - self.pre_logits = nn.Identity() - - # Classifier head(s): TODO: move to decoder - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = None - if distilled: - self.head_dist = nn.Linear(self.d_model, self.num_classes) if num_classes > 0 else nn.Identity() - - # Weight init - assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') - head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. 
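# Illustrative sketch (assumed example, not from the patched file): the `dpr` list built
# earlier in this constructor follows the linear stochastic-depth decay rule, i.e. the
# drop-path probability grows from 0 at the first block to `drop_path_rate` at the last,
# and each SequenceResidualBlock i receives drop_path=dpr[i]. A minimal standalone check:
def _demo_drop_path_schedule(depth: int = 12, drop_path_rate: float = 0.1):
    import torch  # local import so the sketch stays self-contained
    dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
    # first block is never dropped; the last is dropped with probability drop_path_rate
    assert dpr[0] == 0.0 and abs(dpr[-1] - drop_path_rate) < 1e-6
    return dpr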
- if self.pos_embed is not None: - trunc_normal_(self.pos_embed, std=.02) - if self.dist_token is not None: - trunc_normal_(self.dist_token, std=.02) - if weight_init.startswith('jax'): - # leave cls token as zeros to match jax impl - for n, m in self.named_modules(): - _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) - else: - if self.cls_token is not None: - trunc_normal_(self.cls_token, std=.02) - self.apply(_init_vit_weights) - - def _init_weights(self, m): - # this fn left here for compat with downstream users - _init_vit_weights(m) - - @torch.jit.ignore - def no_weight_decay(self): - return {'pos_embed', 'cls_token', 'dist_token'} - - def forward_features(self, x): - # TODO: move to encoder - x = self.patch_embed(x) - - if self.use_cls_token: - cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks - if self.dist_token is None: - x = torch.cat((cls_token, x), dim=1) - else: - x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) - - if self.use_pos_embed: - x = self.pos_drop(x + self.pos_embed) - - if self.track_norms: output_norms = [torch.mean(x.detach() ** 2)] - - for block in self.blocks: - x, _ = block(x) - if self.track_norms: output_norms.append(torch.mean(x.detach() ** 2)) - x = self.norm(x) - - if self.track_norms: - metrics = to_dict(output_norms, recursive=False) - self.metrics = {f'norm/{i}': v for i, v in metrics.items()} - - if self.dist_token is None: - if self.use_cls_token: - return self.pre_logits(x[:, 0]) - else: - # pooling: TODO move to decoder - return self.pre_logits(x.mean(1)) - else: - return x[:, 0], x[:, 1] - - def forward(self, x, rate=1.0, resolution=None, state=None): - x = self.forward_features(x) - if self.head_dist is not None: - x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple - if self.training and not torch.jit.is_scripting(): - # during inference, return the average of both classifier predictions - return x, x_dist - else: - return (x + x_dist) / 2 - else: - x = self.head(x) - return x, None - - -def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = False): - """ ViT weight initialization - * When called without n, head_bias, jax_impl args it will behave exactly the same - as my original init for compatibility with prev hparam / downstream use cases (ie DeiT). - * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl - """ - if isinstance(m, (nn.Linear)): - if n.startswith('head'): - nn.init.zeros_(m.weight) - nn.init.constant_(m.bias, head_bias) - elif n.startswith('pre_logits'): - lecun_normal_(m.weight) - nn.init.zeros_(m.bias) - else: - if jax_impl: - nn.init.xavier_uniform_(m.weight) - if m.bias is not None: - if 'mlp' in n: - nn.init.normal_(m.bias, std=1e-6) - else: - nn.init.zeros_(m.bias) - else: - if m.bias is not None: - nn.init.zeros_(m.bias) - dense_init_fn_ = partial(trunc_normal_, std=.02) - if isinstance(m, nn.Linear): - dense_init_fn_(m.weight) - - elif jax_impl and isinstance(m, nn.Conv2d): - # NOTE conv was left to pytorch default in my original init - lecun_normal_(m.weight) - if m.bias is not None: - nn.init.zeros_(m.bias) - elif isinstance(m, nn.LayerNorm): - nn.init.zeros_(m.bias) - nn.init.ones_(m.weight) - - -def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): - # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from - # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 - _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) - ntok_new = posemb_new.shape[1] - if num_tokens: - posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] - ntok_new -= num_tokens - else: - posemb_tok, posemb_grid = posemb[:, :0], posemb[0] - gs_old = int(math.sqrt(len(posemb_grid))) - if not len(gs_new): # backwards compatibility - gs_new = [int(math.sqrt(ntok_new))] * 2 - assert len(gs_new) >= 2 - _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) - posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) - posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') - posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) - posemb = torch.cat([posemb_tok, posemb_grid], dim=1) - return posemb - - -def checkpoint_filter_fn(state_dict, model): - """ convert patch embedding weight from manual patchify + linear proj to conv""" - out_dict = {} - if 'model' in state_dict: - # For deit models - state_dict = state_dict['model'] - for k, v in state_dict.items(): - if 'patch_embed.proj.weight' in k and len(v.shape) < 4: - # For old models that I trained prior to conv based patchification - O, I, H, W = model.patch_embed.proj.weight.shape - v = v.reshape(O, -1, H, W) - elif k == 'pos_embed' and v.shape != model.pos_embed.shape: - # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1), - model.patch_embed.grid_size) - out_dict[k] = v - return out_dict - - -def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) - repr_size = kwargs.pop('representation_size', None) - if repr_size is not None and num_classes != default_num_classes: - # Remove representation layer if fine-tuning. This may not always be the desired action, - # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? - _logger.warning("Removing representation layer for fine-tuning.") - repr_size = None - - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - - model = build_model_with_cfg( - VisionTransformer, - variant, - pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, - representation_size=repr_size, - pretrained_filter_fn=checkpoint_filter_fn, - **kwargs) - - return model - - -def vit_small_patch16_224(pretrained=False, **kwargs): - """ Tri's custom 'small' ViT model. d_model=768, depth=8, num_heads=8, mlp_ratio=3. 
- NOTE: - * this differs from the DeiT based 'small' definitions with d_model=384, depth=12, num_heads=6 - * this model does not have a bias for QKV (unlike the official ViT and DeiT models) - """ - print(kwargs) - model_kwargs = dict( - patch_size=16, - d_model=768, - depth=8, - expand=3, - norm='layer', - ) - model_kwargs = { - **model_kwargs, - **kwargs, - } - if pretrained: - # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model - model_kwargs.setdefault('qk_scale', 768 ** -0.5) - model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) - return model - - -def vit_base_patch16_224(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). - ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. - """ - model_kwargs = dict( - patch_size=16, - d_model=768, - depth=12, - # num_heads=12, - ) - model_kwargs = { - **model_kwargs, - **kwargs, - } - model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) - return model diff --git a/src/clm/src/models/nn/__init__.py b/src/clm/src/models/nn/__init__.py deleted file mode 100644 index aee8113e..00000000 --- a/src/clm/src/models/nn/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .components import LinearActivation, Activation, Normalization, DropoutNd diff --git a/src/clm/src/models/nn/adaptive_softmax.py b/src/clm/src/models/nn/adaptive_softmax.py deleted file mode 100644 index 4ac9e2f0..00000000 --- a/src/clm/src/models/nn/adaptive_softmax.py +++ /dev/null @@ -1,404 +0,0 @@ -# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
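# Illustrative sketch (assumed example, not from the patched file): the
# ProjectedAdaptiveLogSoftmax defined below uses the same cutoff/cluster idea as PyTorch's
# built-in adaptive softmax, where frequent tokens live in a cheap head shortlist and rare
# tokens are pushed into smaller tail clusters. Sizes and names here are made up for
# illustration only; this is a reference point, not the patch's implementation.
def _demo_adaptive_softmax():
    import torch
    import torch.nn as nn
    d_model, vocab = 64, 1000
    crit = nn.AdaptiveLogSoftmaxWithLoss(d_model, vocab, cutoffs=[100, 500], div_value=4.0)
    hidden = torch.randn(8, d_model)        # (batch, d_model) projected hidden states
    target = torch.randint(0, vocab, (8,))  # (batch,) token indices
    out = crit(hidden, target)              # out.output: per-token log-probs, out.loss: mean NLL
    return out.loss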
- -from typing import List, Optional -import functools - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class OptionalParameterList(nn.ParameterList): - def extra_repr(self): - child_lines = [] - for k, p in self._parameters.items(): - if p is not None: - size_str = 'x'.join(str(size) for size in p.size()) - device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device()) - parastr = 'Parameter containing: [{} of size {}{}]'.format( - torch.typename(p), size_str, device_str) - child_lines.append(' (' + str(k) + '): ' + parastr) - tmpstr = '\n'.join(child_lines) - return tmpstr - - -class ProjectedAdaptiveLogSoftmax(nn.Module): - def __init__(self, n_token, d_embed, d_proj, cutoffs, div_val=1, - tie_projs=None, out_layers_weights=None, out_projs=None, - keep_order=False, - bias_scale=0.0, - dropout=0.0, - ): - super().__init__() - - self.n_token = n_token - self.d_embed = d_embed - self.d_proj = d_proj - - self.cutoffs = list(cutoffs) + [n_token] - self.cutoff_ends = [0] + self.cutoffs - self.div_val = div_val - - self.shortlist_size = self.cutoffs[0] - self.n_clusters = len(self.cutoffs) - 1 - self.head_size = self.shortlist_size + self.n_clusters - - # bake the first False into the definition, just as [0] is built into the cutoffs - if tie_projs is None: tie_projs = [] - elif isinstance(tie_projs, bool): tie_projs = [tie_projs] * len(cutoffs) - else: tie_projs = list(tie_projs) - tie_projs = [False] + tie_projs - self.tie_projs = tie_projs - - if self.n_clusters > 0: - self.cluster_weight = nn.Parameter(torch.zeros(self.n_clusters, self.d_embed)) - self.cluster_bias = nn.Parameter(torch.zeros(self.n_clusters)) - - if not out_layers_weights: - self.out_layers_weights = nn.ParameterList() - else: - self.out_layers_weights = out_layers_weights - - self.out_layers_biases = nn.ParameterList() - - self.shared_out_projs = out_projs - self.out_projs = OptionalParameterList() - - self.dropout = dropout - self.drop = nn.Dropout(dropout) - - if div_val == 1: - if d_proj != d_embed: - for i in range(len(self.cutoffs)): - if tie_projs[i]: - self.out_projs.append(None) - else: - self.out_projs.append( - nn.Parameter(torch.zeros(d_proj, d_embed)) - ) - else: - self.out_projs.append(None) - - self.out_layers_biases.append( - nn.Parameter(torch.zeros(n_token)) - ) - - if not out_layers_weights: - self.out_layers_weights.append( - nn.Parameter(torch.zeros(n_token, d_embed)) - ) - else: - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i+1] - d_emb_i = d_embed // (div_val ** i) - - if tie_projs[i]: - self.out_projs.append(None) - else: - self.out_projs.append( - nn.Parameter(torch.zeros(d_proj, d_emb_i)) - ) - - self.out_layers_biases.append( - nn.Parameter(torch.zeros(r_idx - l_idx)) - ) - if not out_layers_weights: - self.out_layers_weights.append( - nn.Parameter(torch.zeros(r_idx - l_idx, d_emb_i)) - ) - for bias in self.out_layers_biases: - bound = bias_scale * d_proj ** -.5 - nn.init.uniform_(bias, -bound, bound) - - - self.keep_order = keep_order - - def _compute_logit(self, hidden, weight, bias, proj): - if proj is None: - logit = F.linear(hidden, weight, bias=bias) - else: - if self.dropout > 0.0: - logit = hidden @ proj - logit = self.drop(logit) - logit = logit @ weight.t() - else: - logit = torch.einsum('bd,de,ev->bv', (hidden, proj, weight.t())) - if bias is not None: - logit = logit + bias - return logit - - def get_out_proj(self, i): - if self.tie_projs[i]: - if len(self.shared_out_projs) == 0: - return None - elif 
len(self.shared_out_projs) == 1: - return self.shared_out_projs[0] - else: - return self.shared_out_projs[i] - else: - return self.out_projs[i] - - def forward(self, hidden, target, keep_order=False, key_padding_mask=None, *args, **kwargs): - # [21-09-15 AG]: TODO may need to handle key_padding_mask - ''' - hidden :: [len*bsz x d_proj] - target :: [len*bsz] - ''' - - hidden = hidden.reshape(-1, hidden.size(-1)) - target = target.reshape(-1) - if hidden.size(0) != target.size(0): - print(hidden.shape, target.shape) - raise RuntimeError('Input and target should have the same size ' - 'in the batch dimension.') - - if self.n_clusters == 0: - logit = self._compute_logit(hidden, self.out_layers_weights[0], - self.out_layers_biases[0], self.get_out_proj(0)) - nll = -F.log_softmax(logit, dim=-1) \ - .gather(1, target.unsqueeze(1)).squeeze(1) - else: - # construct weights and biases - weights, biases = [], [] - for i in range(len(self.cutoffs)): - if self.div_val == 1: - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - weight_i = self.out_layers_weights[0][l_idx:r_idx] - bias_i = self.out_layers_biases[0][l_idx:r_idx] - else: - weight_i = self.out_layers_weights[i] - bias_i = self.out_layers_biases[i] - - if i == 0: - weight_i = torch.cat( - [weight_i, self.cluster_weight], dim=0) - bias_i = torch.cat( - [bias_i, self.cluster_bias], dim=0) - - weights.append(weight_i) - biases.append(bias_i) - - head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0) - - head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) - head_logprob = F.log_softmax(head_logit, dim=1) - - nll = torch.zeros_like(target, dtype=hidden.dtype, device=hidden.device) - - offset = 0 - cutoff_values = [0] + self.cutoffs - for i in range(len(cutoff_values) - 1): - l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] - - mask_i = (target >= l_idx) & (target < r_idx) - indices_i = mask_i.nonzero(as_tuple=False).squeeze() - - if indices_i.numel() == 0: - continue - - target_i = target.index_select(0, indices_i) - l_idx - head_logprob_i = head_logprob.index_select(0, indices_i) - - if i == 0: - logprob_i = head_logprob_i.gather(1, target_i[:, None]).squeeze(1) - else: - weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i) - - hidden_i = hidden.index_select(0, indices_i) - - tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) - tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) - - # First term accounts for cluster probabilities - logprob_i = head_logprob_i[:, -i] \ - + tail_logprob_i.gather(1, target_i[:, None]).squeeze(1) - - if self.keep_order or keep_order: - nll.index_copy_(0, indices_i, -logprob_i) - else: - nll[offset:offset+logprob_i.size(0)].copy_(-logprob_i) - - offset += logprob_i.size(0) # TODO This should be a bug in the original implementation; it should go into the continue case above as well - - return nll.mean() # TODO maybe cases for length or padding_mask - - def compute_logits(self, hidden): - """Compute full vector of logits - - Adapted from https://github.com/kimiyoung/transformer-xl/issues/88 - """ - hidden = hidden.reshape(-1, hidden.size(-1)) - - if self.n_clusters == 0: - logits = self._compute_logit(hidden, self.out_layers_weights[0], - self.out_layers_biases[0], self.get_out_proj(0)) - return logits - else: - # construct weights and biases - weights, biases = [], [] - for i in range(len(self.cutoffs)): - if self.div_val == 1: - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - weight_i = 
self.out_layers_weights[0][l_idx:r_idx] - bias_i = self.out_layers_biases[0][l_idx:r_idx] - else: - weight_i = self.out_layers_weights[i] - bias_i = self.out_layers_biases[i] - - if i == 0: - weight_i = torch.cat( - [weight_i, self.cluster_weight], dim=0) - bias_i = torch.cat( - [bias_i, self.cluster_bias], dim=0) - - weights.append(weight_i) - biases.append(bias_i) - - head_weight, head_bias, head_proj = weights[0], biases[0], self.get_out_proj(0) - - head_logit = self._compute_logit(hidden, head_weight, head_bias, head_proj) - head_logprob = F.log_softmax(head_logit, dim=1) - - out_full_logps = [head_logprob[:, :self.cutoffs[0]]] - offset = 0 - cutoff_values = [0] + self.cutoffs - - for i in range(1, len(cutoff_values) - 1): - l_idx, r_idx = cutoff_values[i], cutoff_values[i + 1] - head_logprob_i = head_logprob # .index_select(0, indices_i) - - if i == 0: - logprob_i = head_logprob_i - else: - weight_i, bias_i, proj_i = weights[i], biases[i], self.get_out_proj(i) - - hidden_i = hidden # .index_select(0, indices_i) - - tail_logit_i = self._compute_logit(hidden_i, weight_i, bias_i, proj_i) - tail_logprob_i = F.log_softmax(tail_logit_i, dim=1) - logprob_i = head_logprob_i[:, -i].view(-1, 1) + tail_logprob_i - - offset += logprob_i.size(0) - out_full_logps.append(logprob_i) - out_full_logps = torch.cat(out_full_logps, dim = 1) - # print(torch.sum(out_full_ps), out_full_ps.shape) - return out_full_logps - - -class AdaptiveEmbedding(nn.Module): - """ Copy of transformers.AdaptiveEmbedding that works with fp16 by replacing the index_put_ operation - - Initialization has been fixed for the case when d_proj = d_embed - """ - def __init__(self, n_token, d_embed, d_proj, cutoffs : List[int], div_val=1, init_scale=1.0, sample_softmax=False, dropout=0.0): - super().__init__() - - self.n_token = n_token - self.d_embed = d_embed - - self.cutoffs = list(cutoffs) + [n_token] - self.div_val = div_val - self.d_proj = d_proj - self.drop = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() - - self.emb_scale = d_proj ** 0.5 - - self.cutoff_ends = [0] + self.cutoffs - - self.emb_layers = nn.ModuleList() - self.emb_projs = nn.ParameterList() - if div_val == 1: - self.emb_layers.append(nn.Embedding(n_token, d_embed, sparse=sample_softmax > 0)) - _init_embed(self.emb_layers[-1].weight, d_embed, init_scale) - # torch.nn.init.normal_(self.emb_layers[-1].weight, mean=0, std=init_scale * d_embed ** -.5) - if d_proj != d_embed: # TODO - # self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed))) - self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_embed))) - # torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale) - _init_proj(self.emb_projs[-1], d_proj, init_scale) - else: - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - d_emb_i = d_embed // (div_val ** i) - self.emb_layers.append(nn.Embedding(r_idx - l_idx, d_emb_i)) - # torch.nn.init.normal_(self.emb_layers[-1].weight, mean=0, std=init_scale * d_emb_i ** -.5) - _init_embed(self.emb_layers[-1].weight, d_emb_i, init_scale) - self.emb_projs.append(nn.Parameter(torch.FloatTensor(d_proj, d_emb_i))) - # torch.nn.init.normal_(self.emb_projs[-1], mean=0, std=init_scale * 1./self.emb_scale) - _init_proj(self.emb_projs[-1], d_proj, init_scale) - - def forward(self, inp): - if self.div_val == 1: - embed = self.emb_layers[0](inp) - embed = self.drop(embed) - if self.d_proj != self.d_embed: - embed = F.linear(embed, self.emb_projs[0]) - else: - param = 
next(self.parameters()) - inp_flat = inp.reshape(-1) - - # Changes from original impl - # emb_flat = torch.zeros([inp_flat.size(0), self.d_proj], dtype=param.dtype, device=param.device) - embeddings = [] - indices = torch.zeros_like(inp_flat) # empty should work as long as cutoffs[-1] > max token - _total_tokens = 0 - - # emb_flat = inp.new_zeros(inp_flat.size(0), self.d_proj) - for i in range(len(self.cutoffs)): - l_idx, r_idx = self.cutoff_ends[i], self.cutoff_ends[i + 1] - - mask_i = (inp_flat >= l_idx) & (inp_flat < r_idx) - indices_i = mask_i.nonzero().squeeze(-1) # shape (_tokens,) - - _tokens = indices_i.numel() - if _tokens == 0: - continue - - inp_i = inp_flat.index_select(0, indices_i) - l_idx - emb_i = self.emb_layers[i](inp_i) - emb_i = self.drop(emb_i) - emb_i = F.linear(emb_i, self.emb_projs[i]) - - # Changes - embeddings.append(emb_i) - indices.index_put_( - (indices_i,), - torch.arange(_tokens, device=inp.device) + _total_tokens - ) - _total_tokens += _tokens - - # emb_flat.index_copy_(0, indices_i, emb_i) - embeddings = torch.cat(embeddings, dim=0) - emb_flat = embeddings[indices] - - embed_shape = inp.size() + (self.d_proj,) - embed = emb_flat.view(embed_shape) - - embed.mul_(self.emb_scale) - # embed.div_(self.emb_scale) - - return embed - - -def _init_weight(weight, d : int, init_scale : Optional[float], default=None): - assert init_scale or default - if init_scale is None: - std = default - else: - std = init_scale * (d ** -0.5) - nn.init.normal_(weight, mean=0, std=std) - -_init_embed = functools.partial(_init_weight, default=0.02) -_init_proj = functools.partial(_init_weight, default=0.01) diff --git a/src/clm/src/models/nn/components.py b/src/clm/src/models/nn/components.py deleted file mode 100644 index b47e951e..00000000 --- a/src/clm/src/models/nn/components.py +++ /dev/null @@ -1,389 +0,0 @@ -""" Utility nn components, in particular handling activations, initializations, and normalization layers """ - -from functools import partial -import math -from typing import ForwardRef -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from opt_einsum import contract - - -def stochastic_depth(input: torch.tensor, p: float, mode: str, training: bool = True): - """ - Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth" - `_ used for randomly dropping residual - branches of residual architectures. - - Args: - input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one - being its batch i.e. a batch with ``N`` rows. - p (float): probability of the input to be zeroed. - mode (str): ``"batch"`` or ``"row"``. - ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes - randomly selected rows from the batch. - training: apply stochastic depth if is ``True``. Default: ``True`` - - Returns: - Tensor[N, ...]: The randomly zeroed tensor. - """ - if p < 0.0 or p > 1.0: - raise ValueError("drop probability has to be between 0 and 1, but got {}".format(p)) - if mode not in ["batch", "row"]: - raise ValueError("mode has to be either 'batch' or 'row', but got {}".format(mode)) - if not training or p == 0.0: - return input - - survival_rate = 1.0 - p - if mode == "row": - size = [input.shape[0]] + [1] * (input.ndim - 1) - else: - size = [1] * input.ndim - noise = torch.empty(size, dtype=input.dtype, device=input.device) - noise = noise.bernoulli_(survival_rate).div_(survival_rate) - return input * noise - -class StochasticDepth(nn.Module): - """ - See :func:`stochastic_depth`. 
- """ - def __init__(self, p: float, mode: str) -> None: - # TODO(karan): need to upgrade to torchvision==0.11.0 to use StochasticDepth directly - # from torchvision.ops import StochasticDepth - super().__init__() - self.p = p - self.mode = mode - - def forward(self, input): - return stochastic_depth(input, self.p, self.mode, self.training) - - def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + '(' - tmpstr += 'p=' + str(self.p) - tmpstr += ', mode=' + str(self.mode) - tmpstr += ')' - return tmpstr - -class DropoutNd(nn.Module): - def __init__(self, p: float = 0.5, tie=True, transposed=True): - """ - tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) - """ - super().__init__() - if p < 0 or p >= 1: - raise ValueError("dropout probability has to be in [0, 1), " "but got {}".format(p)) - self.p = p - self.tie = tie - self.transposed = transposed - self.binomial = torch.distributions.binomial.Binomial(probs=1-self.p) - - def forward(self, X): - """ X: (batch, dim, lengths...) """ - if self.training: - if not self.transposed: X = rearrange(X, 'b d ... -> b ... d') - # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow - mask_shape = X.shape[:2] + (1,)*(X.ndim-2) if self.tie else X.shape - # mask = self.binomial.sample(mask_shape) - mask = torch.rand(*mask_shape, device=X.device) < 1.-self.p - X = X * mask * (1.0/(1-self.p)) - if not self.transposed: X = rearrange(X, 'b ... d -> b d ...') - return X - return X - - -def Activation(activation=None, size=None, dim=-1): - if activation in [ None, 'id', 'identity', 'linear' ]: - return nn.Identity() - elif activation == 'tanh': - return nn.Tanh() - elif activation == 'relu': - return nn.ReLU() - elif activation == 'gelu': - return nn.GELU() - elif activation in ['swish', 'silu']: - return nn.SiLU() - elif activation == 'glu': - return nn.GLU(dim=dim) - elif activation == 'sigmoid': - return nn.Sigmoid() - elif activation == 'softplus': - return nn.Softplus() - elif activation in ['sqrelu', 'relu2']: - return SquaredReLU() - elif activation == 'laplace': - return Laplace() - elif activation == 'ln': - return TransposedLN(dim) - else: - raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) - -def get_initializer(name, activation=None): - if activation in [ None, 'id', 'identity', 'linear' ]: - nonlinearity = 'linear' - elif activation in ['relu', 'tanh', 'sigmoid']: - nonlinearity = activation - elif activation in ['gelu', 'swish', 'silu']: - nonlinearity = 'relu' # Close to ReLU so approximate with ReLU's gain - else: - raise NotImplementedError(f"get_initializer: activation {activation} not supported") - - if name == 'uniform': - initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity) - elif name == 'normal': - initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity) - elif name == 'xavier': - initializer = torch.nn.init.xavier_normal_ - elif name == 'zero': - initializer = partial(torch.nn.init.constant_, val=0) - elif name == 'one': - initializer = partial(torch.nn.init.constant_, val=1) - else: - raise NotImplementedError(f"get_initializer: initializer type {name} not supported") - - return initializer - -def LinearActivation( - d_input, d_output, bias=True, - zero_bias_init=False, - transposed=False, - initializer=None, - activation=None, - activate=False, # Apply activation as part of this module - weight_norm=False, - **kwargs, - ): - """ Returns a linear nn.Module with control over axes order, 
initialization, and activation """ - - # Construct core module - # linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear - linear_cls = TransposedLinear if transposed else nn.Linear - if activation == 'glu': d_output *= 2 - linear = linear_cls(d_input, d_output, bias=bias, **kwargs) - - # Initialize weight - if initializer is not None: - get_initializer(initializer, activation)(linear.weight) - - # Initialize bias - if bias and zero_bias_init: - nn.init.zeros_(linear.bias) - - # Weight norm - if weight_norm: - linear = nn.utils.weight_norm(linear) - - if activate and activation is not None: - activation = Activation(activation, d_output, dim=1 if transposed else -1) - linear = nn.Sequential(linear, activation) - return linear - -class SquaredReLU(nn.Module): - def forward(self, x): - # return F.relu(x)**2 - return torch.square(F.relu(x)) # Could this be faster? - -def laplace(x, mu=0.707107, sigma=0.282095): - x = (x - mu).div(sigma * math.sqrt(2.0)) - return 0.5 * (1.0 + torch.erf(x)) - -class Laplace(nn.Module): - def __init__(self, mu=0.707107, sigma=0.282095): - super().__init__() - self.mu = mu - self.sigma = sigma - - def forward(self, x): - return laplace(x, mu=self.mu, sigma=self.sigma) - - -class TransposedLinear(nn.Module): - """ Linear module on the second-to-last dimension - Assumes shape (B, D, L), where L can be 1 or more axis - """ - - def __init__(self, d_input, d_output, bias=True): - super().__init__() - - self.weight = nn.Parameter(torch.empty(d_output, d_input)) - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # nn.Linear default init - # nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent - - if bias: - self.bias = nn.Parameter(torch.empty(d_output)) - bound = 1 / math.sqrt(d_input) - nn.init.uniform_(self.bias, -bound, bound) - setattr(self.bias, "_optim", {"weight_decay": 0.0}) - else: - self.bias = 0.0 - - def forward(self, x): - num_axis = len(x.shape[2:]) # num_axis in L, for broadcasting bias - y = contract('b u ..., v u -> b v ...', x, self.weight) + self.bias.view(-1, *[1]*num_axis) - return y - - -class TransposedLN(nn.Module): - """ LayerNorm module over second dimension - Assumes shape (B, D, L), where L can be 1 or more axis - - This is slow and a dedicated CUDA/Triton implementation shuld provide substantial end-to-end speedup - """ - def __init__(self, d, scalar=True): - super().__init__() - self.scalar = scalar - if self.scalar: - self.m = nn.Parameter(torch.zeros(1)) - self.s = nn.Parameter(torch.ones(1)) - setattr(self.m, "_optim", {"weight_decay": 0.0}) - setattr(self.s, "_optim", {"weight_decay": 0.0}) - else: - self.ln = nn.LayerNorm(d) - - def forward(self, x): - if self.scalar: - # calc. stats over D dim / channels - s, m = torch.std_mean(x, dim=1, unbiased=False, keepdim=True) - y = (self.s/s) * (x-m+self.m) - else: - # move channel to last axis, apply layer_norm, then move channel back to second axis - _x = self.ln(rearrange(x, 'b d ... -> b ... d')) - y = rearrange(_x, 'b ... 
d -> b d ...') - return y - -class Normalization(nn.Module): - def __init__( - self, - d, - transposed=False, # Length dimension is -1 or -2 - _name_='layer', - **kwargs - ): - super().__init__() - self.transposed = transposed - self._name_ = _name_ - - if _name_ == 'layer': - self.channel = True # Normalize over channel dimension - if self.transposed: - self.norm = TransposedLN(d, **kwargs) - else: - self.norm = nn.LayerNorm(d, **kwargs) - elif _name_ == 'instance': - self.channel = False - norm_args = {'affine': False, 'track_running_stats': False} - norm_args.update(kwargs) - self.norm = nn.InstanceNorm1d(d, **norm_args) # (True, True) performs very poorly - elif _name_ == 'batch': - self.channel = False - norm_args = {'affine': True, 'track_running_stats': True} - norm_args.update(kwargs) - self.norm = nn.BatchNorm1d(d, **norm_args) - elif _name_ == 'group': - self.channel = False - self.norm = nn.GroupNorm(1, d, *kwargs) - elif _name_ == 'none': - self.channel = True - self.norm = nn.Identity() - else: raise NotImplementedError - - def forward(self, x): - # Handle higher dimension logic - shape = x.shape - if self.transposed: - x = rearrange(x, 'b d ... -> b d (...)') - else: - x = rearrange(x, 'b ... d -> b (...)d ') - - # The cases of LayerNorm / no normalization are automatically handled in all cases - # Instance/Batch Norm work automatically with transposed axes - if self.channel or self.transposed: - x = self.norm(x) - else: - x = x.transpose(-1, -2) - x = self.norm(x) - x = x.transpose(-1, -2) - - x = x.view(shape) - return x - - def step(self, x, **kwargs): - assert self._name_ in ["layer", "none"] - if self.transposed: x = x.unsqueeze(-1) - x = self.forward(x) - if self.transposed: x = x.squeeze(-1) - return x - -class TSNormalization(nn.Module): - - def __init__(self, method, horizon): - super().__init__() - - self.method = method - self.horizon = horizon - - - def forward(self, x): - # x must be BLD - if self.method == 'mean': - self.scale = x.abs()[:, :-self.horizon].mean(dim=1)[:, None, :] - return x / self.scale - elif self.method == 'last': - self.scale = x.abs()[:, -self.horizon-1][:, None, :] - return x / self.scale - return x - -class TSInverseNormalization(nn.Module): - - def __init__(self, method, normalizer): - super().__init__() - - self.method = method - self.normalizer = normalizer - - def forward(self, x): - if self.method == 'mean' or self.method == 'last': - return x * self.normalizer.scale - return x - -class ReversibleInstanceNorm1dInput(nn.Module): - def __init__(self, d, transposed=False): - super().__init__() - # BLD if transpoed is False, otherwise BDL - self.transposed = transposed - self.norm = nn.InstanceNorm1d(d, affine=True, track_running_stats=False) - - def forward(self, x): - # Means, stds - if not self.transposed: - x = x.transpose(-1, -2) - - self.s, self.m = torch.std_mean(x, dim=-1, unbiased=False, keepdim=True) - self.s += 1e-4 - - x = (x - self.m) / self.s - # x = self.norm.weight.unsqueeze(-1) * x + self.norm.bias.unsqueeze(-1) - - if not self.transposed: - return x.transpose(-1, -2) - return x - -class ReversibleInstanceNorm1dOutput(nn.Module): - - def __init__(self, norm_input): - super().__init__() - self.transposed = norm_input.transposed - self.weight = norm_input.norm.weight - self.bias = norm_input.norm.bias - self.norm_input = norm_input - - def forward(self, x): - if not self.transposed: - x = x.transpose(-1, -2) - - # x = (x - self.bias.unsqueeze(-1))/self.weight.unsqueeze(-1) - x = x * self.norm_input.s + self.norm_input.m - - 
if not self.transposed: - return x.transpose(-1, -2) - return x diff --git a/src/clm/src/models/nn/dxt.py b/src/clm/src/models/nn/dxt.py deleted file mode 100644 index a9813bc5..00000000 --- a/src/clm/src/models/nn/dxt.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Implementations of several types of Discrete Sin/Cosine Transforms with various reductions to FFT. - -Currently not used by S4 -""" - -import torch -import torch.nn as nn -import numpy as np -import scipy.fft -from einops import rearrange, repeat - -class DCT(nn.Module): - """ Reductions adapted from https://dsp.stackexchange.com/questions/2807/fast-cosine-transform-via-fft """ - - def __init__(self, N, norm='backward'): - super().__init__() - - self.N = N - - # Materialize DCT matrix - P = scipy.fft.dct(np.eye(N), norm=norm, type=2).T - P = torch.tensor(P, dtype=torch.float) - self.register_buffer('P', P) - - # TODO take care of normalization - Q = np.exp(-1j * np.pi / (2 * self.N) * np.arange(self.N)) - Q = torch.tensor(Q, dtype=torch.cfloat) - self.register_buffer('Q', Q) # half shift - - def forward(self, x, mode=2): - if mode == 0: - return self.forward_dense(x) - elif mode == 1: - return self.forward_n(x) - elif mode == 2: - return self.forward_2n(x) - elif mode == 4: - return self.forward_4n(x) - - def forward_dense(self, x): - """ Baseline DCT type II - matmul by DCT matrix """ - y = (self.P.to(x) @ x.unsqueeze(-1)).squeeze(-1) - return y - - def forward_4n(self, x): - """ DCT type II - reduction to FFT size 4N """ - assert self.N == x.shape[-1] - x = torch.cat([x, x.flip(-1)], dim=-1) - z = torch.zeros_like(x) - x = torch.stack([z, x], dim=-1) - x = x.view(x.shape[:-2] + (-1,)) - y = torch.fft.fft(x) - y = y[..., :self.N] - if torch.is_complex(x): - return y - else: - return torch.real(y) - - def forward_2n(self, x): - """ DCT type II - reduction to FFT size 2N mirrored - - The reduction from the DSP forum is not quite correct in the complex input case. 
- halfshift(FFT[a, b, c, d, d, c, b, a]) -> [A, B, C, D, 0, -D, -C, -B] - In the case of real input, the intermediate step after FFT has form [A, B, C, D, 0, D*, C*, B*] - """ - assert self.N == x.shape[-1] - x = torch.cat([x, x.flip(-1)], dim=-1) - y = torch.fft.fft(x)[..., :self.N] - y = y * self.Q - if torch.is_complex(x): - return y - else: - return torch.real(y) - - def forward_n(self, x): - """ DCT type II - reduction to size N """ - assert self.N == x.shape[-1] - x = torch.cat([x[..., 0::2], x[..., 1::2].flip(-1)], dim=-1) - y = torch.fft.fft(x) - y = y * 2 * self.Q - if torch.is_complex(x): - y = torch.cat([y[..., :1], (y[..., 1:] + 1j * y[..., 1:].flip(-1)) / 2], dim=-1) # TODO in-place sum - else: - y = torch.real(y) - return y - -class IDCT(nn.Module): - def __init__(self, N, norm='backward'): - super().__init__() - - self.N = N - - # Materialize DCT matrix - P = np.linalg.inv(scipy.fft.dct(np.eye(N), norm=norm, type=2).T) - P = torch.tensor(P, dtype=torch.float) - self.register_buffer('P', P) - - # TODO take care of normalization - Q = np.exp(-1j * np.pi / (2 * self.N) * np.arange(2*self.N)) - Q = torch.tensor(Q, dtype=torch.cfloat) - self.register_buffer('Q', Q) # half shift - - def forward(self, x, mode=2): - if mode == 0: - return self.forward_dense(x) - elif mode == 1: - return self.forward_n(x) - elif mode == 2: - return self.forward_2n(x) - elif mode == 4: - return self.forward_4n(x) - - def forward_dense(self, x): - """ Baseline DCT type II - matmul by DCT matrix """ - y = (self.P.to(x) @ x.unsqueeze(-1)).squeeze(-1) - return y - - def forward_4n(self, x): - """ DCT type II - reduction to FFT size 4N """ - assert self.N == x.shape[-1] - z = x.new_zeros(x.shape[:-1] + (1,)) - x = torch.cat([x, z, -x.flip(-1), -x[..., 1:], z, x[..., 1:].flip(-1)], dim=-1) - y = torch.fft.ifft(x) - y = y[..., 1:2*self.N:2] - if torch.is_complex(x): - return y - else: - return torch.real(y) - - def forward_2n(self, x): - """ DCT type II - reduction to FFT size 2N mirrored """ - assert self.N == x.shape[-1] - z = x.new_zeros(x.shape[:-1] + (1,)) - x = torch.cat([x, z, -x[..., 1:].flip(-1)], dim=-1) - x = x / self.Q - y = torch.fft.ifft(x)[..., :self.N] - if torch.is_complex(x): - return y - else: - return torch.real(y) - - def forward_n(self, x): - """ DCT type II - reduction to size N """ - assert self.N == x.shape[-1] - raise NotImplementedError # Straightforward by inverting operations of DCT-II reduction - -def test_dct_ii(): - N = 8 - dct = DCT(N) - - baseline = dct.forward_dense - methods = [dct.forward_4n, dct.forward_2n, dct.forward_n] - - # Real case - print("DCT-II Real input") - x = torch.randn(1, N) - y = baseline(x) - print(y) - for fn in methods: - y_ = fn(x) - print("err", torch.norm(y-y_)) - - # Complex case - print("DCT-II Complex input") - x = torch.randn(N) + 1j * torch.randn(N) - y = baseline(x) - print(y) - for fn in methods: - y_ = fn(x) - print("err", torch.norm(y-y_)) - -def test_dct_iii(): - N = 8 - dct = IDCT(N) - - baseline = dct.forward_dense - methods = [dct.forward_4n, dct.forward_2n] - - # Real case - print("DCT-III Real input") - x = torch.randn(1, N) - y = baseline(x) - print(y) - for fn in methods: - y_ = fn(x) - print("err", torch.norm(y-y_)) - - # Complex case - print("DCT-III Complex input") - # x = torch.randn(N) + 1j * torch.randn(N) - x = 1j * torch.ones(N) - y = baseline(x) - print(y) - for fn in methods: - y_ = fn(x) - print("err", torch.norm(y-y_)) diff --git a/src/clm/src/models/nn/gate.py b/src/clm/src/models/nn/gate.py deleted file mode 100644 
index d0a531f7..00000000 --- a/src/clm/src/models/nn/gate.py +++ /dev/null @@ -1,128 +0,0 @@ -""" Defines flexible gating mechanisms based on ideas from LSSL paper and UR-LSTM paper https://arxiv.org/abs/1910.09890 """ - -import torch -import torch.nn as nn - -class Gate(nn.Module): - """ Implements gating mechanisms. TODO update this with more detailed description with reference to LSSL paper when it's on arxiv - - Mechanisms: - N - No gate - G - Standard sigmoid gate - UR - Uniform refine gates - R - Refine gate - - FS - Forward discretization, Sigmoid activation [equivalent to G] - BE - Backward discretization, Exp activation [equivalent to G] - BR - Backward discretization, Relu activation - TE - Trapezoid discretization, Exp activation - TR - Trapezoid discretization, Relu activation - TS - Trapezoid discretization, Sigmoid activation (0 to 2) - """ - def __init__(self, size, preact_ctor, preact_args, mechanism='N'): - super().__init__() - self.size = size - self.mechanism = mechanism - - if self.mechanism == 'N': - pass - elif self.mechanism in ['G', 'FS', 'BE', 'BR', 'TE', 'TR', 'TS', 'ZE', 'ZR', 'ZS']: - self.W_g = preact_ctor(*preact_args) - elif self.mechanism in ['U', 'UT']: - self.W_g = preact_ctor(*preact_args) - b_g_unif = torch.empty(size) - torch.nn.init.uniform_(b_g_unif, 1./self.size, 1.-1./self.size) - self.b_g = nn.Parameter(torch.log(1./b_g_unif-1.).detach(), requires_grad=False) - elif self.mechanism == 'UR': - self.W_g = preact_ctor(*preact_args) - self.W_r = preact_ctor(*preact_args) - - b_g_unif = torch.empty(size) - torch.nn.init.uniform_(b_g_unif, 1./self.size, 1.-1./self.size) - self.b_g = nn.Parameter(torch.log(1./b_g_unif-1.).detach(), requires_grad=False) - elif self.mechanism == 'R': - self.W_g = preact_ctor(*preact_args) - self.W_r = preact_ctor(*preact_args) - elif self.mechanism in ['GT']: - self.W_g = preact_ctor(*preact_args) - else: - assert False, f'Gating type {self.mechanism} is not supported.' 
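# Illustrative sketch (assumed example, not from the patched file): the 'R'/'UR' branches
# in `forward` below apply the refine-gate update g <- (1 - 2r) * g**2 + 2r * g, which is
# linear in r and therefore interpolates between g**2 (at r = 0) and 2g - g**2 (at r = 1),
# keeping the result a valid gate value in [0, 1]. A standalone numeric check:
def _demo_refine_gate():
    import torch
    g = torch.sigmoid(torch.randn(5))
    r = torch.sigmoid(torch.randn(5))
    refined = (1 - 2 * r) * g**2 + 2 * r * g
    # refined stays between the two extremes g**2 and 2g - g**2 elementwise
    assert torch.all(refined >= torch.minimum(g**2, 2 * g - g**2) - 1e-6)
    assert torch.all(refined <= torch.maximum(g**2, 2 * g - g**2) + 1e-6)
    return refined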
- - def forward(self, *inputs): - if self.mechanism == 'N': - return 1.0 - - if self.mechanism == 'G': - g_preact = self.W_g(*inputs) - g = torch.sigmoid(g_preact) - if self.mechanism == 'U': - g_preact = self.W_g(*inputs) + self.b_g - g = torch.sigmoid(g_preact) - elif self.mechanism == 'UR': - g_preact = self.W_g(*inputs) + self.b_g - g = torch.sigmoid(g_preact) - r = torch.sigmoid(self.W_r(*inputs)) - g = (1-2*r)*g**2 + 2*r*g - elif self.mechanism == 'R': - g_preact = self.W_g(*inputs) - g = torch.sigmoid(g_preact) - r = torch.sigmoid(self.W_r(*inputs)) - g = (1-2*r)*g**2 + 2*r*g - elif self.mechanism == 'UT': - g_preact = self.W_g(*inputs) + self.b_g - g = torch.sigmoid(g_preact) - r = g - g = (1-2*r)*g**2 + 2*r*g - elif self.mechanism == 'GT': - g_preact = self.W_g(*inputs) - g = torch.sigmoid(g_preact) - r = g - g = (1-2*r)*g**2 + 2*r*g - else: - g_preact = self.W_g(*inputs) - # if self.mechanism[1] == 'S': - # g = torch.sigmoid(g_preact) - # elif self.mechanism[1] == 'E': - # g = torch.exp(g_preact) - # elif self.mechanism[1] == 'R': - # g = torch.relu(g_preact) - if self.mechanism == 'FS': - g = torch.sigmoid(g_preact) - g = self.forward_diff(g) - elif self.mechanism == 'BE': - g = torch.exp(g_preact) - g = self.backward_diff(g) - elif self.mechanism == 'BR': - g = torch.relu(g_preact) - g = self.backward_diff(g) - elif self.mechanism == 'TS': - g = 2 * torch.sigmoid(g_preact) - g = self.trapezoid(g) - elif self.mechanism == 'TE': - g = torch.exp(g_preact) - g = self.trapezoid(g) - elif self.mechanism == 'TR': - g = torch.relu(g_preact) - g = self.trapezoid(g) - elif self.mechanism == 'ZE': - g = torch.exp(g_preact) - g = self.zoh(g) - elif self.mechanism == 'ZR': - g = torch.relu(g_preact) - g = self.zoh(g) - elif self.mechanism == 'ZS': - g = torch.sigmoid(g_preact) - g = self.zoh(g) - return g - - def forward_diff(self, x): - return x - - def backward_diff(self, x): - return x / (1+x) - - def trapezoid(self, x): - return x / (1 + x/2) - - def zoh(self, x): - return 1 - torch.exp(-x) diff --git a/src/clm/src/models/nn/residual.py b/src/clm/src/models/nn/residual.py deleted file mode 100644 index 360697e2..00000000 --- a/src/clm/src/models/nn/residual.py +++ /dev/null @@ -1,108 +0,0 @@ -""" Implementations of different types of residual functions. """ - -import torch -from torch import nn - -class Residual(nn.Module): - """ Residual connection with constant affine weights. Can simulate standard residual, no residual, and "constant gates". 
""" - - def __init__(self, i_layer, d_input, d_model, alpha=1.0, beta=1.0): - # print("ConstantResidual extra kwargs", kwargs) - super().__init__() - assert (d_input == d_model) or alpha == 0.0 - self.i_layer = i_layer - self.d_input = d_input - self.d_model = d_model - self.alpha = alpha - self.beta = beta - - @property - def d_output(self): - return self.d_model - - def forward(self, x, y, transposed): # TODO documentation of transposed - y = self.beta*y if self.beta != 1.0 else y - return self.alpha * x + y if self.alpha else y - -class Affine(Residual): - """ Residual connection with learnable scalar multipliers on the main branch - scalar: Single scalar multiplier, or one per dimension - scale, power: Initialize to scale * layer_num**(-power) - """ - - def __init__(self, *args, scalar=True, gamma=0.0, **kwargs): - # print("ConstantResidual extra kwargs", kwargs) - super().__init__(*args, **kwargs) - self.scalar = scalar - self.gamma = gamma - - c = self.beta * self.i_layer ** (-self.gamma) - d = 1 if self.scalar else self.d_input - self.affine = nn.Parameter(c * torch.ones(d)) - - def forward(self, x, y, transposed): # TODO documentation of transposed - c = self.affine - if transposed: c = c.unsqueeze(-1) - return self.alpha * x + c * y - - -class Feedforward(Residual): - def __init__(self, *args): - # print("Feedforward extra kwargs", kwargs) - super().__init__(*args, alpha=0.0, beta=1.0) - - -class Highway(Residual): - def __init__(self, *args, scaling_correction=False, elemwise=False): - super().__init__(*args) - self.scaling_correction = 1.732 if scaling_correction else 1.0 # TODO - self.elemwise = elemwise - self.Wx = nn.Linear(self.d_input, self.d_input) - if self.elemwise: - self.Wy = nn.Parameter(torch.randn(self.d_input)) - else: - self.Wy = nn.Linear(self.d_input, self.d_input) - - def forward(self, x, y, transposed=False): # TODO handle this case - if self.elemwise: - y = self.Wy * y - else: - y = self.Wy(y) - r = torch.sigmoid(self.Wx(x) + y) - z = self.scaling_correction * (1.-r) * x + r * y - return z - - -class DecayResidual(Residual): - """ Residual connection that can decay the linear combination depending on depth. """ - - def __init__(self, *args, power=0.5, l2=True): - # print("DecayResidual extra kwargs", kwargs) - super().__init__(*args) - self.power = power - self.l2 = l2 - - def forward(self, x, y, transposed): - beta = self.i_layer ** (-self.power) - if self.l2: - alpha = (1. - beta**2)**0.5 - else: - alpha = 1. 
- beta - - return alpha * x + beta * y - -registry = { - 'F': Feedforward, - 'N': Feedforward, - 'R': Residual, - 'H': Highway, - 'D': DecayResidual, - 'A': Affine, - 'none': Feedforward, - 'ff': Feedforward, - 'feedforward': Feedforward, - 'residual': Residual, - 'highway': Highway, - 'decay': DecayResidual, - 'affine': Affine, -} diff --git a/src/clm/src/models/nn/utils.py b/src/clm/src/models/nn/utils.py deleted file mode 100644 index 2c4d18d9..00000000 --- a/src/clm/src/models/nn/utils.py +++ /dev/null @@ -1,125 +0,0 @@ -""" Utility wrappers around modules to let them handle Args and extra arguments """ - -import inspect -from functools import wraps -import torch -from torch import nn - -def wrap_kwargs(f): - """ - Given a callable f that can consume some named arguments, - wrap it with a kwargs that passes back any unused args - - EXAMPLES - -------- - - Basic usage: - def foo(x, y=None): - return x - - wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) - - -------- - - The wrapped function can return its own argument dictionary, - which gets merged with the new kwargs. - def foo(x, y=None): - return x, {} - wrap_kwargs(foo)(0, y=1, z=2) == (0, {'z': 2}) - - def foo(x, y=None): - return x, {"y": y, "z": None} - wrap_kwargs(foo)(0, y=1, z=2) == (0, {'y': 1, 'z': 2}) - - -------- - - The wrapped function can have its own kwargs parameter: - def foo(x, y=None, **kw_args): - return x, {} - wrap_kwargs(foo)(0, y=1, z=2) == (0, {}) - - -------- - - Partial functions and modules work automatically: - class Module: - def forward(self, x, y=0): - return x, {"y": y+1} - - m = Module() - - wrap_kwargs(m.forward)(0, y=1, z=2) == (0, {'y': 2, 'z': 2}) - - """ - sig = inspect.signature(f) - # Check if f already has kwargs - has_kwargs = any([ - param.kind == inspect.Parameter.VAR_KEYWORD - for param in sig.parameters.values() - ]) - if has_kwargs: - @wraps(f) - def f_kwargs(*args, **kwargs): - y = f(*args, **kwargs) - if isinstance(y, tuple) and isinstance(y[-1], dict): - return y - else: - return y, {} - else: - param_kwargs = inspect.Parameter("kwargs", kind=inspect.Parameter.VAR_KEYWORD) - sig_kwargs = inspect.Signature(parameters=list(sig.parameters.values())+[param_kwargs]) - @wraps(f) - def f_kwargs(*args, **kwargs): - bound = sig_kwargs.bind(*args, **kwargs) - if "kwargs" in bound.arguments: - kwargs = bound.arguments.pop("kwargs") - else: - kwargs = {} - y = f(**bound.arguments) - if isinstance(y, tuple) and isinstance(y[-1], dict): - return *y[:-1], {**y[-1], **kwargs} - else: - return y, kwargs - return f_kwargs - -def discard_kwargs(f): - if f is None: return None - f_kwargs = wrap_kwargs(f) - @wraps(f) - def f_(*args, **kwargs): - return f_kwargs(*args, **kwargs)[0] - return f_ - -def PassthroughSequential(*modules): - """Special Sequential module that chains kwargs. 
- - Semantics are the same as nn.Sequential, with extra convenience features: - - Discard None modules - - Flatten inner Sequential modules - - In case with 0 or 1 Module, rename the class for ease of inspection - """ - def flatten(module): - if isinstance(module, nn.Sequential): - return sum([flatten(m) for m in module], []) - else: - return [module] - - modules = flatten(nn.Sequential(*modules)) - modules = [module for module in modules if module if not None] - - class Sequential(nn.Sequential): - def forward(self, x, **kwargs): - for layer in self: - x, kwargs = wrap_kwargs(layer.forward)(x, **kwargs) - return x, kwargs - - def step(self, x, **kwargs): - for layer in self: - fn = getattr(layer, "step", layer.forward) - x, kwargs = wrap_kwargs(fn)(x, **kwargs) - return x, kwargs - - if len(modules) == 0: - Sequential.__name__ = "Identity" - elif len(modules) == 1: - Sequential.__name__ = type(modules[0]).__name__ - return Sequential(*modules) diff --git a/src/clm/src/models/sequence/__init__.py b/src/clm/src/models/sequence/__init__.py deleted file mode 100644 index 38669fb6..00000000 --- a/src/clm/src/models/sequence/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .base import SequenceModule, TransposedModule -from .model import SequenceModel -from .ff import FF diff --git a/src/clm/src/models/sequence/base.py b/src/clm/src/models/sequence/base.py deleted file mode 100644 index 4f8a4ffa..00000000 --- a/src/clm/src/models/sequence/base.py +++ /dev/null @@ -1,131 +0,0 @@ -from torch import nn -import functools - -class SequenceModule(nn.Module): - """Abstract sequence model class. All models must adhere to this interface - - A SequenceModule is generally a model that transforms an input of shape - (n_batch, l_sequence, d_model) to (n_batch, l_sequence, d_output) - - REQUIRED methods and attributes - forward, d_model, d_output: controls standard forward pass, a sequence-to-sequence transformation - __init__ should also satisfy the following interface; see SequenceIdentity for an example - def __init__(self, d_model, transposed=False, **kwargs) - - OPTIONAL methods - default_state, step: allows stepping the model recurrently with a hidden state - state_to_tensor, d_state: allows decoding from hidden state - """ - - @property - def d_model(self): - """Model dimension (generally same as input dimension). - - This attribute is required for all SequenceModule instantiations. - It is used by the rest of the pipeline (e.g. model backbone, encoder) to track the internal shapes of the full model. - """ - if getattr(self, "_d_model", None) is None: - raise NotImplementedError("SequenceModule instantiation must set d_model") - return self._d_model - - @d_model.setter - def d_model(self, d): - self._d_model = d - - @property - def d_output(self): - """Output dimension of model. - - This attribute is required for all SequenceModule instantiations. - It is used by the rest of the pipeline (e.g. model backbone, decoder) to track the internal shapes of the full model. - """ - if getattr(self, "_d_output", None) is None: - raise NotImplementedError("SequenceModule instantiation must specify d_output for decoder") - return self._d_output - - @d_output.setter - def d_output(self, d): - self._d_output = d - - def forward(self, x, state=None, **kwargs): - """Forward pass of sequence model, a sequence-to-sequence transformation with an optional state. 
- - Generally, this should map a tensor of shape (batch, length, self.d_model) to (batch, length, self.d_output) - - Additionally, it returns a "state" which can be any additional information - For example, RNN and SSM layers may return their hidden state, - while some types of transformer layers (e.g. Transformer-XL) may want to pass a state as well - """ - return x, None - - @property - def state_to_tensor(self): - """Returns a function mapping a state to a single tensor. - - This method should be implemented if one wants to use the hidden state instead of the output sequence for final prediction. - Currently only used with the StateDecoder. - """ - return lambda _: None - - @property - def d_state(self): - """ Returns dimension of output of self.state_to_tensor """ - return None - - - def default_state(self, *batch_shape, device=None): - """Create initial state for a batch of inputs.""" - - return None - - def step(self, x, state=None, **kwargs): - """Step the model recurrently for one step of the input sequence. - - For example, this should correspond to unrolling an RNN for one step. - If the forward pass has signature (B, L, H1) -> (B, L, H2), - this method should generally have signature (B, H1) -> (B, H2) with an optional recurrent state. - """ - raise NotImplementedError - -def TransposedModule(module): - """Wrap a SequenceModule class to accept transposed parameter, handle state, absorb kwargs""" - # https://stackoverflow.com/a/65470430/1980685 - @functools.wraps(module, updated=()) - class TransposedModule(module): - def __init__(self, *args, transposed=False, **kwargs): - super().__init__(*args, **kwargs) - self.transposed = transposed - - def forward(self, x, state=None, **kwargs): - if self.transposed: x = x.transpose(-1, -2) - x, next_state = super().forward(x, state) # Don't use kwarg because nn.LSTM - next_state = None if state is None else next_state - if self.transposed: x = x.transpose(-1,-2) - return x, next_state - # https://stackoverflow.com/questions/5352781/how-to-set-class-names-dynamically - # TransposedModule.__name__ = module.__name__ # functools wraps is better solution - return TransposedModule - -@TransposedModule -class SequenceIdentity(SequenceModule): - """Simple SequenceModule for testing purposes""" - - def __init__(self, d_model, dropout=0.0, **kwargs): - """Default interface for SequenceModule - - d_model: input dimension (sometimes denoted H for hidden dimension) - transposed: if True, inputs have axis ordering (B, H, L) instead of (B, H, L) - """ - super().__init__() - self.d_model = d_model - self.d_output = d_model - - - def forward(self, x, state=None): - return x, state - - def default_state(self, *batch_shape, device=None): - return None - - def step(self, x, state=None, **kwargs): - return x, state diff --git a/src/clm/src/models/sequence/block.py b/src/clm/src/models/sequence/block.py deleted file mode 100644 index f44ee109..00000000 --- a/src/clm/src/models/sequence/block.py +++ /dev/null @@ -1,129 +0,0 @@ -""" Implements a full residual block around a black box layer - -Configurable options include: -normalization position: prenorm or postnorm -normalization type: batchnorm, layernorm etc. -subsampling/pooling -residual options: feedforward, residual, affine scalars, depth-dependent scaling, etc. 
-""" - -from torch import nn - -from functools import partial -import clm.src.utils as utils -from clm.src.models.nn.components import Normalization, StochasticDepth, DropoutNd -from clm.src.models.sequence import SequenceModule -from clm.src.models.sequence.pool import registry as pool_registry -from clm.src.models.nn.residual import registry as residual_registry -import clm.src.utils.registry as registry - - -class SequenceResidualBlock(SequenceModule): - def __init__( - self, - d_input, - i_layer=None, # Only needs to be passed into certain residuals like Decay - prenorm=True, - dropout=0.0, - tie_dropout=False, - transposed=False, - layer=None, # Config for black box module - residual=None, # Config for residual function - norm=None, # Config for normalization layer - pool=None, - drop_path=0., - ): - super().__init__() - - self.i_layer = i_layer - self.d_input = d_input - self.layer = utils.instantiate(registry.layer, layer, d_input) - self.prenorm = prenorm - self.transposed = transposed - - # Residual - # d_residual is the output dimension after residual - if residual is None: - self.residual = None - self.d_residual = self.layer.d_output - else: - self.residual = utils.instantiate(residual_registry, residual, i_layer, d_input, self.layer.d_output) - self.d_residual = self.residual.d_output - - # Normalization - d_norm = d_input if self.prenorm else self.d_residual - # We don't use config to directly instantiate since Normalization has some special cases - if norm is None: - self.norm = None - elif isinstance(norm, str): - self.norm = Normalization(d_norm, transposed=self.transposed, _name_=norm) - else: - self.norm = Normalization(d_norm, transposed=self.transposed, **norm) - - # Pool - self.pool = utils.instantiate(pool_registry, pool, self.d_residual, transposed=self.transposed) - - # Dropout - dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout - self.drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() - - # Stochastic depth - self.drop_path = StochasticDepth(drop_path, mode='row') if drop_path > 0.0 else nn.Identity() - - - @property - def d_output(self): - return self.pool.d_output if self.pool is not None else self.d_residual - - @property - def d_state(self): - return self.layer.d_state - - @property - def state_to_tensor(self): - return self.layer.state_to_tensor - - def default_state(self, *args, **kwargs): - return self.layer.default_state(*args, **kwargs) - - def forward(self, x, state=None, **kwargs): - y = x - - # Pre-norm - if self.norm is not None and self.prenorm: y = self.norm(y) - - # Black box layer - y, state = self.layer(y, state=state, **kwargs) - - # Residual - if self.residual is not None: y = self.residual(x, self.drop_path(self.drop(y)), self.transposed) - - # Post-norm - if self.norm is not None and not self.prenorm: y = self.norm(y) - - # Pool - if self.pool is not None: y, _ = self.pool(y) - - return y, state - - def step(self, x, state, **kwargs): - y = x - - # Pre-norm - if self.norm is not None and self.prenorm: - y = self.norm.step(y) - - # Black box layer - y, state = self.layer.step(y, state, **kwargs) - - # Residual - if self.residual is not None: y = self.residual(x, y, transposed=False) # NOTE this would not work with concat residual function (catformer) - - # Post-norm - if self.norm is not None and not self.prenorm: - y = self.norm.step(y) - - # Pool - if self.pool is not None: y, _ = self.pool(y) - - return y, state diff --git a/src/clm/src/models/sequence/block_fft.py 
b/src/clm/src/models/sequence/block_fft.py deleted file mode 100644 index c0a1c568..00000000 --- a/src/clm/src/models/sequence/block_fft.py +++ /dev/null @@ -1,177 +0,0 @@ -'''PyTorch version of the block FFT convolution as described in the H3 paper.''' - -import torch -from einops import rearrange -import math -from torch import nn -from clm.src.models.nn import Activation -from clm.src.utils.train import OptimModule - -def ref_dft_matrix(N, H=1): - """Compute the DFT matrix of size N x N. - - This is where we could add extra compute for free.""" - # n = torch.arange(N) - n = torch.arange(N).cuda() - k = n.view(-1, 1) - M = torch.exp(-2j * torch.pi * n * k / N) - return torch.view_as_real(M.repeat(H, 1, 1)) - -def compute_twiddle_factors(n, m): - """Compute the twiddle factors of size n x m""" - # n_a = torch.arange(n).view(-1, 1) - # m_a = torch.arange(m) - n_a = torch.arange(n).cuda().view(-1, 1) - m_a = torch.arange(m).cuda() - N = n * m - M = torch.exp(-2j * torch.pi * n_a * m_a / N) - return torch.view_as_real(M) - -def _cooley_tukey( - k, n, m, - dft_matrix=ref_dft_matrix, - max_m=16, - activation=None, -): - ''' - Compute the FFT using the general Cooley-Tukey algorithm: - * Reshape to (m, n) - * Do n m-length FFTs along the rows - * Transpose to (n, m), multiply by twiddle factors - * Do m n-length FFTs along the rows - - This function assumes that m <= 16 and recurses on n. - The base case is n <= 16 (we are simulating tensor cores of 16x16 mm). - The dft_matrix function is overwriteable - so that we can replace it with learnable parameters in a model. - ''' - assert m <= max_m - - if activation is not None: - act_fn = Activation(activation) - - k = rearrange(k, '... (m n) -> ... m n', m=m, n=n) # (m, n) - - # do n m-length FFTs - if activation is None: - mat = torch.view_as_complex(dft_matrix(m)) - k_f = torch.einsum('... m o, ... o n -> ... m n', mat, k) # (..., m, n) - else: - mat = torch.view_as_complex(dft_matrix(m)) - k_f = torch.view_as_complex(act_fn( - torch.view_as_real(torch.einsum('... m o, ... o n -> ... m n', mat, k)) - )) # (..., m, n) - - # multiply by twiddle factors - twi = torch.view_as_complex(compute_twiddle_factors(n, m)) # (n, m) - k_f = torch.einsum('n m, ... m n -> ... n m', twi, k_f) # (..., n, m) - - if n <= max_m: - # do m n-length FFTs - if activation is None: - mat = torch.view_as_complex(dft_matrix(n)) - k_f = torch.einsum('... n o, ... o m -> ... n m', mat, k_f) # (.., n, m) - else: - mat = torch.view_as_complex(dft_matrix(n)) - k_f = torch.view_as_complex(act_fn( - torch.view_as_real(torch.einsum('... n o, ... o m -> ... n m', mat, k_f)) - )) # (.., n, m) - else: - # recurse - k_f = rearrange(k_f, '... h n m -> ... m h n') - k_f = _cooley_tukey(k_f, n // max_m, max_m, dft_matrix, max_m, activation) - k_f = rearrange(k_f, '... m h n -> ... h n m') - - # reshape for the output - k_f = rearrange(k_f, '... n m -> ... (n m)') # (..., n*m) - - return k_f - -def block_fft( - k, N, - dft_matrix=ref_dft_matrix, - max_m=16, - **kwargs, -): - ''' - Compute the FFT of size N of the vector k, using _block_fft_recurse. - - The dft_matrix function is overwriteable - so that we can replace it with learnable parameters in a model. - ''' - if not math.log(N, 2).is_integer(): - N = int(2 ** math.ceil(math.log(N, 2))) - # pad k with zeros if necessary (e.g. for causality) - if k.shape[-1] != N: - k = nn.ConstantPad1d((0, N - k.shape[-1]), 0)(k) - - if N <= max_m: - mat = torch.view_as_complex(dft_matrix(m)) - return torch.einsum('... n o, ... o -> ... 
n', mat, k) # (.., n, m) - n = N // max_m - m = max_m - return _cooley_tukey(k, n, m, dft_matrix, max_m, **kwargs) - -class BlockFFT(OptimModule): - ''' - Learnable Block FFT module. - - Args: - learn_dft_matrix (bool): If True, learn a different DFT matrix for lengths 2, 4, 8, and 16. If False, this module computes a normal FFT. - ''' - def __init__(self, learn_dft_matrices=True, H=1, max_m=16, dft_lr=0.001, dropout=0, learn_additive=False, **block_fft_args): - super().__init__() - self.learn_dft_matrices = learn_dft_matrices - self.block_fft_args = block_fft_args - self.max_m=max_m - self.drop = torch.nn.Dropout(p=dropout) - self.learn_additive=learn_additive - # get the powers of 2 up to max_m - assert math.log(max_m, 2).is_integer(), 'max_m must be a power of 2' - - self.powers = [ 2 ** (i + 1) for i in range(int(math.log(max_m, 2))) ] - - if learn_dft_matrices: - assert dft_lr>0,"If learn_dft_matrices=True dft_lr must be positive" - self.dft_matrices = nn.ParameterList() - for n in self.powers: - setattr(self,f"mat_{n}",nn.Parameter( - 0.01 * torch.randn(H, n, n, 2) if self.learn_additive - else ref_dft_matrix(n, H=H), - requires_grad=True)) - self.register(f"mat_{n}",getattr(self,f"mat_{n}"),dft_lr) - self.dft_matrices.append(getattr(self,"mat_{}".format(n))) - - def compute_dft_matrix(self, n): - if not self.learn_dft_matrices: - return ref_dft_matrix(n) - else: - assert n in self.powers - if self.learn_additive: - mat = ref_dft_matrix(n) - return mat + self.drop(self.dft_matrices[int(math.log(n, 2) - 1)]) - else: - return self.drop(self.dft_matrices[int(math.log(n, 2) - 1)]) - - def forward(self, x, N,forward=True): - '''Compute an FFT (forward=True) or iFFT (forward=False) of length N over x.''' - if forward: - return block_fft(x, N, dft_matrix=self.compute_dft_matrix, **self.block_fft_args) - else: - return (1/(N))*torch.conj(block_fft(torch.conj(x), N, dft_matrix=self.compute_dft_matrix, **self.block_fft_args)) - - -if __name__ == "__main__": - B = 128 - H = 29 - N = 8192 - n = 2 - m = 8 - k = torch.randn(B, H, N).to(torch.complex64) - - print(f'(B, H, N) = ({B}, {H}, {N})') - - # test FFT - k_f = block_fft(k, N) - k_f_ref = torch.fft.fft(k, N) - print('L-inf error in FFT: ', torch.max(torch.abs(k_f - k_f_ref)).item()) \ No newline at end of file diff --git a/src/clm/src/models/sequence/ff.py b/src/clm/src/models/sequence/ff.py deleted file mode 100644 index 804408dd..00000000 --- a/src/clm/src/models/sequence/ff.py +++ /dev/null @@ -1,50 +0,0 @@ -""" Implementation of FFN block in the style of Transformers """ - -from functools import partial -from torch import nn -from clm.src.models.sequence.base import SequenceModule -from clm.src.models.nn import LinearActivation, DropoutNd - -class FF(SequenceModule): - def __init__(self, d_input, expand=2, d_output=None, transposed=False, activation='gelu', initializer=None, dropout=0.0, tie_dropout=False): - super().__init__() - self.d_output = d_input if d_output is None else d_output - self.transposed = transposed - d_inner = expand * d_input - - linear1 = LinearActivation( - d_input, d_inner, - transposed=transposed, - activation=activation, - initializer=initializer, - activate=True, - ) - dropout_cls = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout - # dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout - drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() - - linear2 = LinearActivation( - d_inner, self.d_output, - transposed=transposed, - activation=None, - 
initializer=initializer, - activate=False, - ) - - self.ff = nn.Sequential( - linear1, - drop, - linear2, - ) - - def forward(self, x, *args, **kwargs): - return self.ff(x), None - - def step(self, x, state, **kwargs): - # x: [batch, d_input] - if self.transposed: - # expects: [batch, d_input, seq_len] - return self.ff(x.unsqueeze(-1)).squeeze(-1), state - else: - return self.ff(x), state - diff --git a/src/clm/src/models/sequence/h3.py b/src/clm/src/models/sequence/h3.py deleted file mode 100644 index 07dc4c89..00000000 --- a/src/clm/src/models/sequence/h3.py +++ /dev/null @@ -1,206 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import rearrange - -from clm.src.models.sequence.ssm.ss_kernel import SSKernel - -try: - from clm.src.ops.fftconv import fftconv_func -except ImportError: - fftconv_func = None - - -@torch.jit.script -def mul_sum(q, y): - return (q * y).sum(dim=1) - - -class H3(nn.Module): - - def __init__( - self, - d_model, - d_state=64, - l_max=None, - head_dim=1, - use_fast_fftconv=False, - dropout=0.0, # Just to absorb the kwarg - layer_idx=None, - device=None, dtype=None, - # SSM Kernel arguments - **kernel_args, - ): - """ - d_state: the dimension of the state, also denoted by N - l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel - - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" - - Other options are all experimental and should not need to be configured - """ - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.d_model = d_model - self.head_dim = head_dim - assert d_model % head_dim == 0 - self.H = d_model // head_dim - self.N = d_state - self.L = l_max - self.layer_idx = layer_idx - self.use_fast_fftconv = use_fast_fftconv - if self.use_fast_fftconv: - assert fftconv_func is not None, 'Need to install fftconv' - - self.q_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) - self.k_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) - self.v_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) - - # TODO: SSKernel doesn't take device argument yet - self.ssm_k_kernel = SSKernel(self.d_model, N=d_state, L=self.L, mode='shift', - lr=kernel_args.get('lr', None)) - self.ssm_k_D = nn.Parameter(torch.randn(self.d_model)) - # S4D Kernel - self.kernel = SSKernel(self.H, N=self.N, L=self.L, channels=1, **kernel_args) - self.D = nn.Parameter(torch.randn(self.H, **factory_kwargs)) - - # Pointwise - # position-wise output transform to mix features - # Don't use FusedDense since the layout is H first - self.output_linear = nn.Linear(self.d_model, self.d_model) - - def forward(self, u, inference_params=None): - """ - u: (B L H) - - Returns: same shape as u - """ - L_og = u.size(-2) - if self.use_fast_fftconv and L_og % 2 != 0: - u = F.pad(u, (0, 0, 0, 1)) - L = u.size(-2) - - use_fast_fftconv = self.use_fast_fftconv and inference_params is None - - state_k, state = None, None - if inference_params is not None: - assert self.layer_idx is not None - if self.layer_idx not in inference_params.key_value_memory_dict: - batch_shape = (u.shape[0] * self.head_dim * self.head_dim,) - state_k = self.ssm_k_kernel.default_state(*batch_shape) - state = self.kernel.default_state(*batch_shape) - inference_params.key_value_memory_dict[self.layer_idx] = (state_k, state) - else: - state_k, state = 
inference_params.key_value_memory_dict[self.layer_idx] - if inference_params.sequence_len_offset == 0: - self.ssm_k_kernel._setup_step() - self.kernel._setup_step() - - if inference_params is not None and inference_params.sequence_len_offset > 0: - y, next_state_k, next_state = self.step(u, state_k, state) - inference_params.key_value_memory_dict[self.layer_idx][0].copy_(next_state_k) - inference_params.key_value_memory_dict[self.layer_idx][1].copy_(next_state) - return y - - # Compute SS Kernel - L_kernel = L if self.L is None else min(L, self.L ) - ssm_kernel, k_state = self.kernel(L=L_kernel, state=state, rate=1.0) # (C H L) (B C H L) - ssm_kernel = rearrange(ssm_kernel, '1 h l -> h l') - - u = rearrange(u, 'b l h -> (b l) h') - dtype = (self.q_proj.weight.dtype if not torch.is_autocast_enabled() - else torch.get_autocast_gpu_dtype()) - q = self.q_proj.weight @ u.T + self.q_proj.bias.to(dtype).unsqueeze(-1) - k = self.k_proj.weight @ u.T + self.k_proj.bias.to(dtype).unsqueeze(-1) - v = self.v_proj.weight @ u.T + self.v_proj.bias.to(dtype).unsqueeze(-1) - q, k, v = [rearrange(x, 'h (b l) -> b h l', l=L) for x in [q, k, v]] - - k_og = k - ssm_k_kernel, _ = self.ssm_k_kernel(L=L_kernel, state=state_k, rate=1.0) # (C H L) (B C H L) - ssm_k_kernel = rearrange(ssm_k_kernel, '1 h l -> h l') - if not use_fast_fftconv: - fft_size = L_kernel + L - ssm_k_kernel_f = torch.fft.rfft(ssm_k_kernel, n=fft_size) # (H 2L) - k_f = torch.fft.rfft(k.to(ssm_kernel.dtype), n=fft_size) # (B H 2L) - shift_k_out = torch.fft.irfft(ssm_k_kernel_f * k_f, n=fft_size)[..., :L] - k = shift_k_out + rearrange(self.ssm_k_D, 'h -> h 1') * k - else: - dropout_mask = None - # No GeLU after the SSM - # We want output_hbl=True so that k has the same layout as q and v for the next - # fftconv - k = fftconv_func(k, ssm_k_kernel, self.ssm_k_D, dropout_mask, False, False, True) - # This line below looks like it doesn't do anything, but it gets the stride right - # for the case batch_size=1. In that case k has stride (L, L, 1), but q and v has - # stride (H * L, L, 1). The two strides are equivalent because batch_size=1, but - # the C++ code doesn't like that. - k = rearrange(rearrange(k, 'b h l -> h b l'), 'h b l -> b h l') - - if not use_fast_fftconv: - fft_size = L_kernel + L - # kv = k * v - kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim) - * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=self.head_dim)) # b d1 d2 h l - kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size - ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 - y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :L] # b d1 d2 h l - y = y + kv * self.D.unsqueeze(-1) # b d1 d2 h l - q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim) - # einsum is way slower than multiply and then sum. - if self.head_dim > 1: - y = mul_sum(y, q) - y = rearrange(y, 'b d h l -> b (d h) l') - else: - y = rearrange(y * q, 'b 1 1 h l -> b h l') - else: - dropout_mask = None - # No GeLU after the SSM - # Set output_hbl_layout=True since we'll be doing a matmul right after - y = fftconv_func(k, ssm_kernel, self.D, - dropout_mask, False, torch.is_autocast_enabled(), True, - v, self.head_dim, q) - - y = rearrange(y, 'b h l -> b l h') - - if state is not None: - assert inference_params is not None - # TODO: This doesn't ever happen? 
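For reference, the non-fused branch above computes a causal long convolution by zero-padding to 2L in the FFT domain; a minimal standalone sketch of that identity with made-up shapes (not part of the patch; the deleted code's division by fft_size combined with norm='forward' is an equivalent scaling):

import torch

B, H, L = 2, 4, 16                     # hypothetical sizes, illustration only
u = torch.randn(B, H, L)               # input, channels first
k = torch.randn(H, L)                  # one convolution kernel per channel
D = torch.randn(H)                     # skip-connection weights

fft_size = 2 * L                       # pad to 2L so circular convolution equals linear convolution
y = torch.fft.irfft(
    torch.fft.rfft(u, n=fft_size) * torch.fft.rfft(k, n=fft_size),
    n=fft_size,
)[..., :L]                             # y[t] = sum_{s<=t} u[s] * k[t-s]
y = y + u * D.unsqueeze(-1)            # skip term, as in y + kv * self.D.unsqueeze(-1)

# O(L^2) reference for one channel confirms the causal-convolution identity.
ref = torch.stack([(u[0, 0, :t + 1] * k[0, :t + 1].flip(0)).sum() for t in range(L)])
assert torch.allclose(y[0, 0], ref, atol=1e-4)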
- # if inference_params.sequence_len_offset > 0: - # y = y + k_state - inference_params.key_value_memory_dict[self.layer_idx][0].copy_( - self.ssm_k_kernel.forward_state(k_og, state_k) - ) - inference_params.key_value_memory_dict[self.layer_idx][1].copy_( - self.kernel.forward_state(rearrange(kv, 'b d1 d2 h l -> (b d1 d2) h l'), state) - ) - - # y could be in fp32 because of the SSMs - if not torch.is_autocast_enabled(): - y = y.to(dtype=self.output_linear.weight.dtype) - y = self.output_linear(y) - if L_og < L: - y = y[:, :L_og, :] - - return y - - def step(self, u, state_k, state): - q, k, v = self.q_proj(u), self.k_proj(u), self.v_proj(u) - shift_k, next_state_k = self.ssm_k_kernel.step(rearrange(k, 'b 1 h -> b h'), state_k) - k = shift_k + k * self.ssm_k_D - # kv = k * v - kv = (rearrange(k, 'b 1 (h d1) -> b d1 1 h', d1=self.head_dim) - * rearrange(v, 'b 1 (h d2) -> b 1 d2 h', d2=self.head_dim)) # b d1 d2 h - y, next_state = self.kernel.step(rearrange(kv, 'b d1 d2 h -> (b d1 d2) h'), state) - y = (rearrange(y, '(b d1 d2) 1 h -> b d1 d2 h', d1=self.head_dim, d2=self.head_dim) - + kv * self.D) - q = rearrange(q, 'b 1 (h d1) -> b d1 1 h', d1=self.head_dim) - if self.head_dim > 1: - y = mul_sum(y, q) - y = rearrange(y, 'b d h l -> b (d h) l') - else: - y = rearrange(y * q, 'b 1 1 h -> b 1 h') - # y could be in fp32 because of the SSMs - if not torch.is_autocast_enabled(): - y = y.to(dtype=self.output_linear.weight.dtype) - return self.output_linear(y), next_state_k, next_state diff --git a/src/clm/src/models/sequence/h3_conv.py b/src/clm/src/models/sequence/h3_conv.py deleted file mode 100644 index f2a7f7c0..00000000 --- a/src/clm/src/models/sequence/h3_conv.py +++ /dev/null @@ -1,150 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import rearrange - -from clm.src.models.sequence.long_conv_kernel import LongConvKernel - -try: - from clm.src.ops.fftconv import fftconv_func -except ImportError: - fftconv_func = None - - -@torch.jit.script -def mul_sum(q, y): - return (q * y).sum(dim=1) - - -class H3Conv(nn.Module): - - def __init__( - self, - d_model, - l_max=None, - head_dim=1, - use_fast_fftconv=False, - dropout=0.0, # Just to absorb the kwarg - layer_idx=None, - device=None, dtype=None, - # SSM Kernel arguments - **kernel_args, - ): - """ - d_state: the dimension of the state, also denoted by N - l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel - - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. 
Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" - - Other options are all experimental and should not need to be configured - """ - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.d_model = d_model - self.head_dim = head_dim - assert d_model % head_dim == 0 - self.H = d_model // head_dim - self.L = l_max - self.layer_idx = layer_idx - self.use_fast_fftconv = use_fast_fftconv - if self.use_fast_fftconv: - assert fftconv_func is not None, 'Need to install fftconv' - - self.q_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) - self.k_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) - self.v_proj = nn.Linear(self.d_model, self.d_model, **factory_kwargs) - self.k_kernel = LongConvKernel( - self.d_model, L=self.L, - **kernel_args) - self.k_D = nn.Parameter(torch.randn(self.d_model)) - self.kernel = LongConvKernel( - self.d_model, L=self.L, - **kernel_args) - self.D = nn.Parameter(torch.randn(self.H, **factory_kwargs)) - - # Pointwise - # position-wise output transform to mix features - # Don't use FusedDense since the layout is H first - self.output_linear = nn.Linear(self.d_model, self.d_model) - - def forward(self, u, inference_params=None): - """ - u: (B L H) - - Returns: same shape as u - """ - L_og = u.size(-2) - if self.use_fast_fftconv and L_og % 2 != 0: - u = F.pad(u, (0, 0, 0, 1)) - L = u.size(-2) - - use_fast_fftconv = self.use_fast_fftconv - - # Compute SS Kernel - ssm_kernel, _ = self.kernel() # (C H L) (B C H L) - ssm_kernel = rearrange(ssm_kernel, '1 h l -> h l') - - u = rearrange(u, 'b l h -> (b l) h') - dtype = (self.q_proj.weight.dtype if not torch.is_autocast_enabled() - else torch.get_autocast_gpu_dtype()) - q = self.q_proj.weight @ u.T + self.q_proj.bias.to(dtype).unsqueeze(-1) - k = self.k_proj.weight @ u.T + self.k_proj.bias.to(dtype).unsqueeze(-1) - v = self.v_proj.weight @ u.T + self.v_proj.bias.to(dtype).unsqueeze(-1) - q, k, v = [rearrange(x, 'h (b l) -> b h l', l=L) for x in [q, k, v]] - - k_og = k - k_kernel, _ = self.k_kernel() # (C H L) (B C H L) - k_kernel = rearrange(k_kernel, '1 h l -> h l') - if not use_fast_fftconv: - fft_size = 2 * L - k_kernel_f = torch.fft.rfft(k_kernel, n=fft_size) # (H 2L) - k_f = torch.fft.rfft(k.to(ssm_kernel.dtype), n=fft_size) # (B H 2L) - shift_k_out = torch.fft.irfft(k_kernel_f * k_f, n=fft_size)[..., :L] - k = shift_k_out + rearrange(self.k_D, 'h -> h 1') * k - else: - dropout_mask = None - # No GeLU after the SSM - # We want output_hbl=True so that k has the same layout as q and v for the next - # fftconv - k = fftconv_func(k, k_kernel, self.k_D, dropout_mask, False, False, True) - # This line below looks like it doesn't do anything, but it gets the stride right - # for the case batch_size=1. In that case k has stride (L, L, 1), but q and v has - # stride (H * L, L, 1). The two strides are equivalent because batch_size=1, but - # the C++ code doesn't like that. 
- k = rearrange(rearrange(k, 'b h l -> h b l'), 'h b l -> b h l') - - if not use_fast_fftconv: - fft_size = 2 * L - # kv = k * v - kv = (rearrange(k, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim) - * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=self.head_dim)) # b d1 d2 h l - kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size - ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 - y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :L] # b d1 d2 h l - y = y + kv * self.D.unsqueeze(-1) # b d1 d2 h l - q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=self.head_dim) - # einsum is way slower than multiply and then sum. - if self.head_dim > 1: - y = mul_sum(y, q) - y = rearrange(y, 'b d h l -> b (d h) l') - else: - y = rearrange(y * q, 'b 1 1 h l -> b h l') - else: - dropout_mask = None - # No GeLU after the SSM - # Set output_hbl_layout=True since we'll be doing a matmul right after - y = fftconv_func(k, ssm_kernel, self.D, - dropout_mask, False, torch.is_autocast_enabled(), True, - v, self.head_dim, q) - - y = rearrange(y, 'b h l -> b l h') - - # y could be in fp32 because of the SSMs - if not torch.is_autocast_enabled(): - y = y.to(dtype=self.output_linear.weight.dtype) - y = self.output_linear(y) - if L_og < L: - y = y[:, :L_og, :] - - return y diff --git a/src/clm/src/models/sequence/hyena.py b/src/clm/src/models/sequence/hyena.py deleted file mode 100644 index 3089c549..00000000 --- a/src/clm/src/models/sequence/hyena.py +++ /dev/null @@ -1,359 +0,0 @@ -import math - -from re import U -import torch -import torch.nn as nn -import torch.nn.functional as F -from functools import partial - -from einops import rearrange, repeat - -try: - from clm.src.ops.fftconv import fftconv_ref, fftconv_func -except ImportError: - fftconv_func = None - -try: - from flash_attn.ops.fused_dense import FusedDense -except ImportError: - FusedDense = None - -import clm.src.utils.registry as registry -from clm.src.utils.train import OptimModule -from clm.src.utils.config import instantiate, auto_assign_attrs -from clm.src.models.nn import Activation - - -# reference convolution with residual connection -def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): - seqlen = u.shape[-1] - fft_size = 2 * seqlen - k_f = torch.fft.rfft(k, n=fft_size) / fft_size - if k_rev is not None: - k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size - k_f = k_f + k_rev_f.conj() - u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) - - if len(u.shape) > 3: k_f = k_f.unsqueeze(1) - - y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] - - out = y + u * D.unsqueeze(-1) - if gelu: - out = F.gelu(out) - if dropout_mask is not None: - return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) - else: - return out.to(dtype=u.dtype) - - -@torch.jit.script -def mul_sum(q, y): - return (q * y).sum(dim=1) - - -class Sin(nn.Module): - def __init__(self, dim, w=10, train_freq=True): - super().__init__() - self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim) - - def forward(self, x): - return torch.sin(self.freq * x) - - -class PositionalEmbedding(OptimModule): - def __init__(self, emb_dim: int, seq_len: int, lr_pos_emb: float=1e-5, **kwargs): - """Complex exponential positional embeddings for Hyena filters.""" - super().__init__() - - self.seq_len = seq_len - # The time embedding fed to the filteres is normalized so that t_f = 1 - t = torch.linspace(0, 1, self.seq_len)[None, :, None] # 1, L, 1 - - if emb_dim > 
1: - bands = (emb_dim - 1) // 2 - # To compute the right embeddings we use the "proper" linspace - t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] - w = 2 * math.pi * t_rescaled / seq_len # 1, L, 1 - - f = torch.linspace(1e-4, bands - 1, bands)[None, None] - z = torch.exp(-1j * f * w) - z = torch.cat([t, z.real, z.imag], dim=-1) - self.register("z", z, lr=lr_pos_emb) - self.register("t", t, lr=0.0) - - def forward(self, L): - return self.z[:, :L], self.t[:, :L] - - -class ExponentialModulation(OptimModule): - def __init__( - self, - d_model, - fast_decay_pct=0.3, - slow_decay_pct=1.5, - target=1e-2, - modulation_lr=0.0, - modulate: bool=True, - shift: float = 0.0, - **kwargs - ): - super().__init__() - self.modulate = modulate - self.shift = shift - max_decay = math.log(target) / fast_decay_pct - min_decay = math.log(target) / slow_decay_pct - deltas = torch.linspace(min_decay, max_decay, d_model)[None, None] - self.register("deltas", deltas, lr=modulation_lr) - - def forward(self, t, x): - if self.modulate: - decay = torch.exp(-t * self.deltas.abs()) - x = x * (decay + self.shift) - return x - - -class HyenaFilter(OptimModule): - def __init__( - self, - d_model, - emb_dim=3, # dim of input to MLP, augments with positional encoding - order=16, # width of the implicit MLP - fused_fft_conv=False, - seq_len=1024, - lr=1e-3, - lr_pos_emb=1e-5, - dropout=0.0, - w=1, # frequency of periodic activations - wd=0, # weight decay of kernel parameters - bias=True, - num_inner_mlps=2, - normalized=False, - **kwargs - ): - """ - Implicit long filter with modulation. - - Args: - d_model: number of channels in the input - emb_dim: dimension of the positional encoding (`emb_dim` - 1) // 2 is the number of bands - order: width of the FFN - num_inner_mlps: number of inner linear layers inside filter MLP - - Note: - filter_dropout is not implemented - """ - super().__init__() - self.d_model = d_model - self.use_bias = bias - self.fused_fft_conv = fused_fft_conv - self.bias = nn.Parameter(torch.randn(self.d_model)) - self.dropout = nn.Dropout(dropout) - - act = Sin(dim=order, w=w) - self.emb_dim = emb_dim - assert emb_dim % 2 != 0 and emb_dim >= 3, "emb_dim must be odd and greater or equal to 3 (time, sine and cosine)" - self.seq_len = seq_len - - self.pos_emb = PositionalEmbedding(emb_dim, seq_len, lr_pos_emb) - - # uses a variable number of inner linear layers - self.implicit_filter = nn.Sequential( - nn.Linear(emb_dim, order), - act, - ) - for i in range(num_inner_mlps): - self.implicit_filter.append(nn.Linear(order, order)) - self.implicit_filter.append(act) - # final linear layer - self.implicit_filter.append(nn.Linear(order, d_model, bias=False)) - - self.modulation = ExponentialModulation(d_model, **kwargs) - - self.normalized = normalized - for c in self.implicit_filter.children(): - for name, v in c.state_dict().items(): - optim = {"weight_decay": wd, "lr": lr} - setattr(getattr(c, name), "_optim", optim) - - def filter(self, L, *args, **kwargs): - z, t = self.pos_emb(L) - h = self.implicit_filter(z) - - h = self.modulation(t, h) - - if self.normalized: h = h / torch.norm(h, dim=-1, p=1, keepdim=True) - - return h - - def forward(self, x, L, k=None, bias=None, *args, **kwargs): - if k is None: k = self.filter(L) - - # Ensure compatibility with filters that return a tuple - k = k[0] if type(k) is tuple else k - if bias is None: bias = self.bias - bias = bias if self.use_bias else 0 * bias - - if self.fused_fft_conv: - bias = bias.to(dtype=torch.float32) - y = fftconv_func( - x, k, 
bias, dropout_mask=None, gelu=False, - force_fp16_output=torch.is_autocast_enabled() - ) - else: - y = fftconv_ref(x, k, bias, dropout_mask=None, gelu=False) - - return y - - -class HyenaOperator(nn.Module): - def __init__( - self, - d_model, - l_max, - order=2, - filter_order=64, - num_heads=1, - inner_factor=1, - num_blocks=1, - fused_bias_fc=False, - outer_mixing=False, - dropout=0.0, - filter_dropout=0.0, - filter_cls='hyena-filter', - post_order_ffn=False, - jit_filter=False, - short_filter_order=3, - activation="id", - return_state=False, - **filter_args, - ): - r""" - Hyena operator described in the paper https://arxiv.org/pdf/2302.10866.pdf - - Args: - d_model (int): Dimension of the input and output embeddings (width of the layer) - l_max: (int): Maximum input sequence length. Defaults to None - order: (int): Depth of the Hyena recurrence. Defaults to 2 - filter_order: (int): Width of the FFN parametrizing the implicit filter. Defaults to 64 - num_heads: (int): Number of heads. Defaults to 1 - inner_factor: (int): Width multiplier. Defaults to 1 - num_blocks: (int): Number of blocks in sequence length. Defaults to 1 - fused_bias_fc: (bool): Whether to use fused bias FC. Defaults to False - dropout: (float): Dropout probability. Defaults to 0.0 - filter_dropout: (float): Dropout probability for the filter. Defaults to 0.0 - post_order_ffn: (bool): Apply a dense layer between steps of the recurrence. Defaults to False - jit_filter: (bool): Whether JIT the implicit filter function. Defaults to False - short_filter_order: (int): Length of the explicit input convolutional filter. Defaults to 3 - activation: (str): type of act between kernel output and FF (default identity) - return_state: (bool): whether to return a state - """ - super().__init__() - assert d_model % num_heads == 0, f'Model dimension {d_model} must be divisible by num heads {num_heads}' - assert l_max % num_blocks == 0, f'Maximum signal length {l_max} must be divisible by block dimension {num_blocks}' - block_dim = l_max // num_blocks - head_dim = d_model // num_heads - - auto_assign_attrs( - self, d_model=d_model, order=order, l_max=l_max, num_heads=num_heads, inner_factor=inner_factor, - block_dim=block_dim, head_dim=head_dim, filter_order=filter_order, post_order_ffn=post_order_ffn, - short_filter_order=short_filter_order, num_blocks = num_blocks, filter_dropout=filter_dropout, - jit_filter=jit_filter, outer_mixing=outer_mixing, activation=activation, return_state=return_state, - ) - self.activation = Activation(activation) - self.dropout = nn.Dropout(dropout) - self.setup_projections(fused_bias_fc, inner_factor) - self.setup_filters(filter_cls, filter_args) - - - def setup_projections(self, fused_bias_fc, inner_factor): - "Initializes input and output projections (over the width dimension)" - if fused_bias_fc and FusedDense is None: - raise ImportError('fused_dense is not installed') - linear_cls = nn.Linear if not fused_bias_fc else FusedDense - self.out_proj = linear_cls(self.d_model * inner_factor, self.d_model) - self.in_proj = linear_cls(self.d_model, (self.order + 1) * self.d_model) - if self.post_order_ffn: - self.ord_proj_w = nn.Parameter(torch.randn(self.order, self.num_heads, self.num_heads) / math.sqrt(self.head_dim)) - - - def setup_filters(self, filter_cls, filter_args): - "Initializes the explicit and implicit filters" - assert self.order >= 2, f'Order must be at least 2, (got {self.order})' - total_width = self.d_model * self.inner_factor * (self.order + 1) - - self.short_filter = nn.Conv1d( - 
in_channels=total_width, - out_channels=total_width, - kernel_size=self.short_filter_order, - groups=total_width, - padding=self.short_filter_order - 1 - ) - - filter_cls = instantiate(registry.layer, filter_cls, partial=True) - - self.filter_fn = filter_cls( - self.head_dim * self.inner_factor * (self.order - 1), - order=self.filter_order, - seq_len=self.l_max, - channels=1, - dropout=self.filter_dropout, - **filter_args - ) - if self.jit_filter: self.filter_fn = torch.jit.script(self.filter_fn, self.L) - - def recurrence(self, u , state): - "Fast inference mode via distilled recurrence" - raise NotImplementedError("Working on it!") - - def forward(self, u, *args, **kwargs): - l = u.size(-2) - l_filter = min(l, self.l_max) - u = self.in_proj(u) - u = rearrange(u, 'b l d -> b d l') - - uc = self.short_filter(u)[...,:l_filter] - - uc = rearrange(uc, 'b (ho v) (z l) -> b ho v z l', - z=self.num_blocks, - ho=self.num_heads, - v=self.head_dim * (self.order + 1) - ) - - *x, v = uc.split(self.d_model, dim=2) - k = self.filter_fn.filter(l_filter) - - # `c` is always 1 by default - k = rearrange(k, 'c l (v o) -> c o v l', v=self.head_dim, o=self.order - 1)[0] - - bias = rearrange(self.filter_fn.bias, '(v o) -> o v', v=self.head_dim, o=self.order - 1) - - for o, x_i in enumerate(reversed(x[1:])): - if self.outer_mixing: - v = rearrange(v, 'b h v z l -> b h 1 v z l') - v = self.dropout( - v * rearrange(x_i, 'b h v z l -> b h v 1 z l') - ) - v = v.sum(dim=2) - else: - v = self.dropout(v * x_i) - - # the bias term is broadcasted. Last dimension (l) is handled by fftconv - v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None]) - - if self.post_order_ffn: - w = self.ord_proj_w[o] - v = mul_sum( - rearrange(w, 'h1 h2 -> 1 h1 h2 1 1 1'), rearrange(v, 'b h v z l -> b h 1 v z l') - ) - - y = self.activation(rearrange(v * x[0], 'b h v z l -> b (z l) (h v)', z=self.num_blocks, h=self.num_heads)) - y = self.out_proj(y) - - if self.return_state: - return y, None - return y - - @property - def d_output(self): - return self.d_model \ No newline at end of file diff --git a/src/clm/src/models/sequence/hyena_components.py b/src/clm/src/models/sequence/hyena_components.py deleted file mode 100644 index 99a55cff..00000000 --- a/src/clm/src/models/sequence/hyena_components.py +++ /dev/null @@ -1,255 +0,0 @@ -""" -Standalone Hyena components without registry dependencies. 
-""" - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange - - -def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): - """Reference convolution with residual connection""" - seqlen = u.shape[-1] - fft_size = 2 * seqlen - k_f = torch.fft.rfft(k, n=fft_size) / fft_size - if k_rev is not None: - k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size - k_f = k_f + k_rev_f.conj() - u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) - - if len(u.shape) > 3: - k_f = k_f.unsqueeze(1) - - y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] - - out = y + u * D.unsqueeze(-1) - if gelu: - out = F.gelu(out) - if dropout_mask is not None: - return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) - else: - return out.to(dtype=u.dtype) - - -@torch.jit.script -def mul_sum(q, y): - return (q * y).sum(dim=1) - - -class Sin(nn.Module): - """Sinusoidal activation function""" - def __init__(self, dim, w=10, train_freq=True): - super().__init__() - self.freq = nn.Parameter(w * torch.ones(1, dim)) if train_freq else w * torch.ones(1, dim) - - def forward(self, x): - return torch.sin(self.freq * x) - - -class PositionalEmbedding(nn.Module): - """Complex exponential positional embeddings for Hyena filters""" - def __init__(self, emb_dim: int, seq_len: int, **kwargs): - super().__init__() - - self.seq_len = seq_len - t = torch.linspace(0, 1, self.seq_len)[None, :, None] - - if emb_dim > 1: - bands = (emb_dim - 1) // 2 - - t_rescaled = torch.linspace(0, seq_len - 1, seq_len)[None, :, None] - w = 2 * math.pi * t_rescaled / seq_len - - f = torch.linspace(1e-4, bands - 1, bands)[None, None] - z = torch.exp(-1j * f * w) - z = torch.cat([t, z.real, z.imag], dim=-1) - - self.register_buffer('z', z) - self.register_buffer('t', t) - - def forward(self, L): - return self.z[:, :L], self.t[:, :L] - - -class ExponentialModulation(nn.Module): - """Exponential modulation for implicit filters""" - def __init__( - self, - d_model, - fast_decay_pct=0.3, - slow_decay_pct=1.5, - target=1e-2, - modulate: bool = True, - shift: float = 0.0, - **kwargs - ): - super().__init__() - self.modulate = modulate - self.shift = shift - max_decay = math.log(target) / fast_decay_pct - min_decay = math.log(target) / slow_decay_pct - deltas = torch.linspace(min_decay, max_decay, d_model)[None, None] - self.register_buffer('deltas', deltas) - - def forward(self, t, x): - if self.modulate: - decay = torch.exp(-t * self.deltas.abs()) - x = x * (decay + self.shift) - return x - - -class HyenaFilter(nn.Module): - """Standalone Hyena filter without registry dependencies""" - def __init__( - self, - d_model, - emb_dim=3, - order=16, - seq_len=1024, - dropout=0.0, - w=1, - bias=True, - num_inner_mlps=2, - **kwargs - ): - super().__init__() - self.d_model = d_model - self.use_bias = bias - self.bias = nn.Parameter(torch.randn(self.d_model)) - self.dropout = nn.Dropout(dropout) - - act = Sin(dim=order, w=w) - self.emb_dim = emb_dim - self.seq_len = seq_len - - self.pos_emb = PositionalEmbedding(emb_dim, seq_len) - - # Build MLP - layers = [nn.Linear(emb_dim, order), act] - for i in range(num_inner_mlps): - layers.append(nn.Linear(order, order)) - layers.append(act) - layers.append(nn.Linear(order, d_model, bias=False)) - - self.implicit_filter = nn.Sequential(*layers) - self.modulation = ExponentialModulation(d_model, **kwargs) - - def filter(self, L): - z, t = self.pos_emb(L) - h = self.implicit_filter(z) - h = self.modulation(t, h) - return h - - def 
forward(self, x, L, k=None, bias=None): - if k is None: - k = self.filter(L) - - k = k[0] if type(k) is tuple else k - if bias is None: - bias = self.bias - bias = bias if self.use_bias else 0 * bias - - # Use reference implementation - y = fftconv_ref(x, k, bias, dropout_mask=None, gelu=False) - return y - - -class HyenaOperator(nn.Module): - """Standalone Hyena operator without registry dependencies""" - def __init__( - self, - d_model, - l_max, - order=2, - filter_order=64, - num_heads=1, - inner_factor=1, - num_blocks=1, - dropout=0.0, - filter_dropout=0.0, - short_filter_order=3, - **filter_args, - ): - super().__init__() - - assert d_model % num_heads == 0, f'Model dimension {d_model} must be divisible by num heads {num_heads}' - assert l_max % num_blocks == 0, f'Maximum signal length {l_max} must be divisible by block dimension {num_blocks}' - - self.d_model = d_model - self.order = order - self.l_max = l_max - self.num_heads = num_heads - self.inner_factor = inner_factor - self.num_blocks = num_blocks - self.filter_order = filter_order - self.short_filter_order = short_filter_order - self.filter_dropout = filter_dropout - - self.block_dim = l_max // num_blocks - self.head_dim = d_model // num_heads - - self.dropout = nn.Dropout(dropout) - - # Projections - self.out_proj = nn.Linear(self.d_model * inner_factor, self.d_model) - self.in_proj = nn.Linear(self.d_model, (self.order + 1) * self.d_model) - - # Short filter - total_width = self.d_model * self.inner_factor * (self.order + 1) - self.short_filter = nn.Conv1d( - in_channels=total_width, - out_channels=total_width, - kernel_size=self.short_filter_order, - groups=total_width, - padding=self.short_filter_order - 1 - ) - - # Long implicit filter - self.filter_fn = HyenaFilter( - self.head_dim * self.inner_factor * (self.order - 1), - order=self.filter_order, - seq_len=self.l_max, - dropout=self.filter_dropout, - **filter_args - ) - - def forward(self, u): - l = u.size(-2) - l_filter = min(l, self.l_max) - - u = self.in_proj(u) - u = rearrange(u, 'b l d -> b d l') - - uc = self.short_filter(u)[..., :l_filter] - - uc = rearrange( - uc, 'b (ho v) (z l) -> b ho v z l', - z=self.num_blocks, - ho=self.num_heads, - v=self.head_dim * (self.order + 1) - ) - - *x, v = uc.split(self.d_model, dim=2) - k = self.filter_fn.filter(l_filter) - - k = rearrange(k, 'c l (v o) -> c o v l', v=self.head_dim, o=self.order - 1)[0] - bias = rearrange(self.filter_fn.bias, '(v o) -> o v', v=self.head_dim, o=self.order - 1) - - for o, x_i in enumerate(reversed(x[1:])): - v = self.dropout(v * x_i) - v = self.filter_fn(v, l_filter, k=k[o], bias=bias[o, None, :, None]) - - y = rearrange( - v * x[0], 'b h v z l -> b (z l) (h v)', - z=self.num_blocks, - h=self.num_heads - ) - y = self.out_proj(y) - - return y - - @property - def d_output(self): - return self.d_model \ No newline at end of file diff --git a/src/clm/src/models/sequence/long_conv.py b/src/clm/src/models/sequence/long_conv.py deleted file mode 100644 index 7b5a53c1..00000000 --- a/src/clm/src/models/sequence/long_conv.py +++ /dev/null @@ -1,170 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -import opt_einsum as oe - -optimized = True - -if optimized: - contract = oe.contract -else: - contract = torch.einsum - -from clm.src.models.nn import LinearActivation, Activation, DropoutNd -from clm.src.models.sequence.block_fft import BlockFFT -from clm.src.models.sequence.long_conv_kernel import LongConvKernel - -class LongConv(nn.Module): - def 
__init__( - self, - d_model, - l_max=1024, - channels=1, - bidirectional=False, - # Arguments for position-wise feedforward components - activation='gelu', # activation between conv and FF - postact='glu', # activation after FF - initializer=None, # initializer on FF - weight_norm=False, # weight normalization on FF - dropout=0.0, tie_dropout=False, - transposed=True, # axis ordering (B, L, D) or (B, D, L) - verbose=False, - block_fft_conv=False, # replace the FFT conv with Monarch blocks - block_fft_conv_args={}, - - # SSM Kernel arguments - **kernel_args, - ): - """ - d_state: the dimension of the state, also denoted by N - l_max: the maximum kernel length, also denoted by L - channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models - bidirectional: if True, convolution kernel will be two-sided - - Position-wise feedforward components: - -------------------- - activation: activation in between SS and FF - postact: activation after FF ('id' for no activation, None to remove FF layer) - initializer: initializer on FF - weight_norm: weight normalization on FF - dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d - - Other arguments: - -------------------- - transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension] - """ - - super().__init__() - if verbose: - import clm.src.utils.train - log = clm.src.utils.train.get_logger(__name__) - log.info(f"Constructing Long Conv (H, L) = ({d_model}, {l_max})") - - self.d_model = d_model - self.H = d_model - self.L = l_max - self.bidirectional = bidirectional - self.channels = channels - self.transposed = transposed - self.block_fft_conv = block_fft_conv - self.block_fft_conv_args = block_fft_conv_args - - self.D = nn.Parameter(torch.randn(channels, self.H)) - - if self.bidirectional: - channels *= 2 - - # SSM Kernel - self.kernel = LongConvKernel(self.H, L=self.L, channels=channels, verbose=verbose, **kernel_args) - - if self.block_fft_conv: - self.block_fft_u = BlockFFT(**self.block_fft_conv_args) - self.block_fft_k = BlockFFT(**self.block_fft_conv_args) - - # Pointwise - self.activation = Activation(activation) - # dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout # Broken in torch==1.11 - dropout_fn = DropoutNd if tie_dropout else nn.Dropout - self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() - - # position-wise output transform to mix features - if postact is None: - self.output_linear = nn.Identity() - else: - self.output_linear = LinearActivation( - self.d_model * self.channels, - self.d_model, - # self.H*self.channels, - # self.d_model*(1 if self.gate is None else self.gate), - transposed=self.transposed, - initializer=initializer, - activation=postact, - activate=True, - weight_norm=weight_norm, - ) - - - - def forward(self, u, state=None, rate=1.0, lengths=None, **kwargs): # absorbs return_output and transformer src mask - """ - u: (B H L) if self.transposed else (B L H) - state: (H N) never needed, remnant from state spaces repo - - Returns: same shape as u - """ - if not self.transposed: u = u.transpose(-1, -2) - L = u.size(-1) - # Mask out padding tokens - # TODO handle option for mask - instead of lengths, which assumes suffix padding - if isinstance(lengths, int): - if lengths != L: - 
lengths = torch.tensor(lengths, dtype=torch.long, device=u.device) - else: - lengths = None - if lengths is not None: - assert isinstance(lengths, torch.Tensor) and lengths.ndim == 1 and lengths.size(0) in [1, u.size(0)] - mask = torch.where(torch.arange(L, device=lengths.device) < lengths[:, None, None], 1., 0.) - u = u * mask - - # Compute SS Kernel - L_kernel = L if self.L is None else min(L, round(self.L / rate)) - k, _ = self.kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L) - - # Convolution - if self.bidirectional: - k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2) - k = F.pad(k0, (0, L)) \ - + F.pad(k1.flip(-1), (L, 0)) - - if self.block_fft_conv: - k_f = self.block_fft_k(k.to(torch.complex64), N=L_kernel+L) # (C H L) - u_f = self.block_fft_u(u.to(torch.complex64), N=L_kernel+L) # (B H L) - y_f = contract('bhl,chl->bchl', u_f, k_f) - y = torch.fft.ifft(y_f, n=L_kernel+L, dim=-1).real[..., :L] # (B C H L) - else: - k_f = torch.fft.rfft(k, n=L_kernel+L) # (C H L) - u_f = torch.fft.rfft(u, n=L_kernel+L) # (B H L) - y_f = contract('bhl,chl->bchl', u_f, k_f) - y = torch.fft.irfft(y_f, n=L_kernel+L)[..., :L] # (B C H L) - - # Compute skip connection - y = y + contract('bhl,ch->bchl', u, self.D) - - # Reshape to flatten channels - y = rearrange(y, '... c h l -> ... (c h) l') - - if not self.transposed: y = y.transpose(-1, -2) - y = self.activation(y) - y = self.dropout(y) - y = self.output_linear(y) - - return y, None - - @property - def d_state(self): - return self.H - - @property - def d_output(self): - return self.d_model diff --git a/src/clm/src/models/sequence/long_conv_kernel.py b/src/clm/src/models/sequence/long_conv_kernel.py deleted file mode 100644 index f54b8b88..00000000 --- a/src/clm/src/models/sequence/long_conv_kernel.py +++ /dev/null @@ -1,82 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import repeat - -from clm.src.utils.train import OptimModule - -class LongConvKernel(OptimModule): - def __init__( - self, - H, - L, - channels=1, - learning_rate=None, - lam=0.1, - causal=True, - kernel_dropout=0, - weight_init="random", - use_ma_smoothing = False, - ma_window_len = 7, - smooth_freq = False, - **kwargs - ): - super().__init__() - - self.drop = torch.nn.Dropout(p=kernel_dropout) - self.H = H - self.weight_init = weight_init - self.causal = causal - self.L = L*2 if not causal else L - - self.channels = channels - self.lam = lam - self.kernel = torch.nn.Parameter(self._parameter_initialization()) #(c,H,L) - - self.register("kernel", self.kernel, learning_rate) - - self.use_ma_smoothing=use_ma_smoothing - self.smooth_freq = smooth_freq - self.ma_window_len = ma_window_len - if self.use_ma_smoothing: - if smooth_freq: - weight = torch.arange(ma_window_len, dtype = self.kernel.dtype) - weight = torch.exp(-0.5 * torch.abs(weight - ma_window_len // 2) ** 2) - weight = repeat(weight, 'l -> h1 h2 l', h1 = self.H, h2 = 1) - weight = weight.type(torch.fft.rfft(self.kernel).dtype) - self.smooth_weight = weight - else: - self.ma_window_len = ma_window_len - assert self.ma_window_len%2!=0, "window size must be odd" - padding = (self.ma_window_len//2) - self.smooth = torch.nn.AvgPool1d(kernel_size=self.ma_window_len,stride=1,padding=padding) - - def _parameter_initialization(self): - if self.weight_init=="random": - return torch.randn(self.channels, self.H, self.L) * 0.002 - elif self.weight_init=="double_exp": - K = torch.randn(self.channels, self.H, self.L,dtype=torch.float32) * 0.02 - double_exp = 
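# Sketch of the bidirectional kernel assembly in LongConv.forward above: the
# causal half is right-padded, the anticausal half is flipped and left-padded,
# so a single length-2L FFT convolution covers both directions. Shapes are
# illustrative.
import torch
import torch.nn.functional as F

C, H, L = 1, 4, 16
k0 = torch.randn(C, H, L)          # forward-in-time kernel
k1 = torch.randn(C, H, L)          # backward-in-time kernel
k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0))   # (C, H, 2L)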
torch.zeros((self.H,self.L),dtype=torch.float32) - for i in range(self.H): - for j in range(self.L): - double_exp[i,j] = torch.exp(-(j/self.L)*torch.pow(torch.tensor(int(self.H/2)),torch.tensor(i/self.H))) - K = torch.einsum("c h l, h l -> c h l",K,double_exp) - return K - else: raise NotImplementedError(f"{self.weight_init} is not valid") - - def forward(self, **kwargs): - k = self.kernel - if self.use_ma_smoothing: - if self.smooth_freq: - k_f = torch.fft.rfft(k, dim=-1) - k_f = F.conv1d(k_f, self.smooth_weight.to(k_f.device), padding='same', groups=self.H) - k = torch.fft.irfft(k_f, dim=-1) - else: - k = self.smooth(k) - k = F.relu(torch.abs(k)-self.lam)*torch.sign(k) - k = self.drop(k) - return k, None - - @property - def d_output(self): - return self.H \ No newline at end of file diff --git a/src/clm/src/models/sequence/long_conv_lm.py b/src/clm/src/models/sequence/long_conv_lm.py deleted file mode 100644 index 88222e5e..00000000 --- a/src/clm/src/models/sequence/long_conv_lm.py +++ /dev/null @@ -1,397 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Dan Fu. - -import copy -import math -import re -from functools import partial - -from collections import namedtuple, OrderedDict -from collections.abc import Sequence - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from transformers.models.gpt2.configuration_gpt2 import GPT2Config - -from einops import rearrange - -from flash_attn.modules.mha import MHA, ParallelMHA -from flash_attn.modules.mlp import Mlp, FusedMLP, ParallelFusedMLP -from flash_attn.modules.block import Block -from flash_attn.modules.embedding import GPT2Embeddings, ParallelGPT2Embeddings -from flash_attn.utils.generation import GenerationMixin -from flash_attn.utils.distributed import sync_shared_params, all_gather_raw - -try: - from flash_attn.ops.fused_dense import ColumnParallelLinear -except ImportError: - ColumnParallelLinear = None - -try: - from flash_attn.ops.layer_norm import dropout_add_layer_norm -except ImportError: - dropout_add_layer_norm = None - -from clm.src.utils import instantiate -import clm.src.utils.registry as registry - -def create_mixer_cls(layer=None, process_group=None, - attn_layer_idx=None, attn_cfg=None, layer_idx=None, - sequence_parallel=True, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} - parallel_kwargs = ({'process_group': process_group, 'sequence_parallel': sequence_parallel} - if process_group is not None else {}) - if attn_layer_idx is not None and layer_idx in attn_layer_idx: - causal = True if attn_cfg is None else attn_cfg.pop('causal', True) - fused_bias_fc = False if attn_cfg is None else attn_cfg.get('fused_bias_fc', False) - if not fused_bias_fc: - assert process_group is None, 'TensorParallel MHA requires fused_bias_fc' - mha_cls = MHA if process_group is None else ParallelMHA - # ParallelMHA doesn't take 'fused_bias_fc', it is assumed that we fuse matmul + bias - if process_group is not None: - attn_cfg = copy.deepcopy(attn_cfg) # Don't modify the original cfg - attn_cfg.pop('fused_bias_fc', None) - mixer_cls = partial(mha_cls, causal=causal, layer_idx=layer_idx, - **(attn_cfg if attn_cfg is not None else {}), - **parallel_kwargs, **factory_kwargs) - else: - fused_bias_fc = False if layer is None else layer.get('fused_bias_fc', False) - if process_group is not None: - assert fused_bias_fc, 'TensorParallel SSM requires fused_bias_fc' - mixer_cls = instantiate(registry.layer, layer, partial=True, layer_idx=layer_idx, **factory_kwargs, **parallel_kwargs) - # mixer_cls = 
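# LongConvKernel.forward above applies soft-thresholding to the learned kernel:
# entries are shrunk toward zero by `lam` and anything smaller is zeroed, which
# promotes sparse/smooth kernels. A minimal numeric sketch:
import torch
import torch.nn.functional as F

lam = 0.1
k = torch.tensor([0.25, -0.05, 0.40, -0.30])
k = F.relu(torch.abs(k) - lam) * torch.sign(k)   # -> roughly [0.15, 0.00, 0.30, -0.20]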
partial(ssm_cls, layer_idx=layer_idx, - # **(ssm_cfg if ssm_cfg is not None else {}), - # **parallel_kwargs, **factory_kwargs) - return mixer_cls - - -def create_mlp_cls(d_model, d_inner=None, process_group=None, fused_mlp=False, - sequence_parallel=True, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} - inner_dim = d_inner if d_inner is not None else 4 * d_model - if process_group is not None: - assert fused_mlp, 'Tensor Parallel is only implemented for FusedMLP' - if not fused_mlp: - mlp_cls = partial(Mlp, hidden_features=inner_dim, - activation=partial(F.gelu, approximate='tanh'), **factory_kwargs) - else: - mlp_cls = FusedMLP if process_group is None else ParallelFusedMLP - parallel_kwargs = ({'process_group': process_group, 'sequence_parallel': sequence_parallel} - if process_group is not None else {}) - mlp_cls = partial(mlp_cls, hidden_features=inner_dim, **parallel_kwargs, **factory_kwargs) - return mlp_cls - - -def create_block(d_model, d_inner=None, process_group=None, - layer=None, attn_layer_idx=None, - attn_cfg=None, layer_norm_epsilon=1e-5, - resid_dropout1=0.0, resid_dropout2=0.0, residual_in_fp32=False, - fused_mlp=False, fused_dropout_add_ln=False, layer_idx=None, - sequence_parallel=True, - device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} - mixer_cls = create_mixer_cls(layer=layer, process_group=process_group, - attn_layer_idx=attn_layer_idx, - attn_cfg=attn_cfg, layer_idx=layer_idx, - sequence_parallel=sequence_parallel, - **factory_kwargs) - mlp_cls = create_mlp_cls(d_model, d_inner=d_inner, process_group=process_group, - fused_mlp=fused_mlp, sequence_parallel=sequence_parallel, - **factory_kwargs) - norm_cls = partial(nn.LayerNorm, eps=layer_norm_epsilon, **factory_kwargs) - block = Block(d_model, mixer_cls, mlp_cls, norm_cls=norm_cls, - prenorm=True, resid_dropout1=resid_dropout1, resid_dropout2=resid_dropout2, - fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=residual_in_fp32, - sequence_parallel=sequence_parallel and process_group is not None, - mark_shared_params=process_group is not None) - block.layer_idx = layer_idx - return block - - -# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 -def _init_weights(module, n_layer, initializer_range=0.02, rescale_prenorm_residual=True, - glu_act=False): - if isinstance(module, nn.Linear): - nn.init.normal_(module.weight, std=initializer_range) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=initializer_range) - - if rescale_prenorm_residual: - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. 
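# The MLP built by create_mlp_cls above is a standard Transformer feed-forward
# block: Linear(d_model -> d_inner) -> GELU(tanh approximation) -> Linear(d_inner
# -> d_model), with d_inner defaulting to 4 * d_model. A plain-PyTorch sketch of
# the same shape (the flash-attn Mlp/FusedMLP classes used here add kernel
# fusion rather than a different computation, as far as this sketch assumes):
import torch.nn as nn

d_model = 256
d_inner = 4 * d_model
mlp = nn.Sequential(
    nn.Linear(d_model, d_inner),
    nn.GELU(approximate='tanh'),
    nn.Linear(d_inner, d_model),
)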
- # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name in ["out_proj.weight", "fc2.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) - # If using GLU activation for now, we scale the std by 2 - elif name in ["output_linear.0.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - if not glu_act: - nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) - else: - out_features = p.shape[0] - # Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5 - # on average. - nn.init.normal_(p[:out_features // 2], mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) * 2) - - -class LMBackbone(nn.Module): - - def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, - process_group=None, layer=None, - attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, - resid_dropout: float = 0.0, embed_dropout: float = 0.1, dropout_cls=nn.Dropout, - layer_norm_epsilon: float = 1e-5, initializer_cfg=None, - fused_mlp=False, fused_dropout_add_ln=False, residual_in_fp32=False, - sequence_parallel=True, - device=None, dtype=None, **kwargs) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.process_group = process_group - self.sequence_parallel = sequence_parallel - self.residual_in_fp32 = residual_in_fp32 - - if process_group is None: - self.embeddings = GPT2Embeddings(d_model, vocab_size, max_position_embeddings, - **factory_kwargs) - else: - self.embeddings = ParallelGPT2Embeddings( - d_model, vocab_size, max_position_embeddings, - process_group=process_group, sequence_parallel=self.sequence_parallel, - **factory_kwargs - ) - - # We change the order of dropout, residual and layer norm: - # Instead of LN -> Attn / MLP -> Dropout -> Add, we do: - # Dropout -> Add -> LN -> Attn / MLP, returning both the residual branch (output of Add) and - # the main branch (output of MLP). The model definition is unchanged, but the mapping of the - # nn.Dropout probabilities are changed. - # This is for performance reason: we can fuse dropout + add + layer_norm. - self.fused_dropout_add_ln = fused_dropout_add_ln - if self.fused_dropout_add_ln and dropout_add_layer_norm is None: - raise ImportError('dropout_add_layer_norm is not installed') - - self.layers = nn.ModuleList([create_block( - d_model, d_inner=d_inner, process_group=process_group, - layer=layer, attn_layer_idx=attn_layer_idx, - attn_cfg=attn_cfg, layer_norm_epsilon=layer_norm_epsilon, - resid_dropout1=embed_dropout if i == 0 else resid_dropout, - resid_dropout2=resid_dropout, residual_in_fp32=residual_in_fp32, - fused_mlp=fused_mlp, fused_dropout_add_ln=fused_dropout_add_ln, layer_idx=i, - sequence_parallel=self.sequence_parallel, - **factory_kwargs, - ) for i in range(n_layer)]) - - self.drop_f = nn.Dropout(resid_dropout) - self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs) - - if process_group is not None: - for p in self.ln_f.parameters(): - # Mark the norm parameters as "shared_params" so that we sync their values at init. - p._shared_params = True - # Mark the norm params as "sequence_parallel" so we run all-reduce on their grads. 
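# Sketch of the GPT-2-style rescaling performed by _init_weights above: the
# residual output projections ("out_proj.weight", "fc2.weight") are
# re-initialized with a std shrunk by 1/sqrt(2 * n_layer), because every one of
# the n_layer blocks adds two residual contributions. Numbers are illustrative.
import math
import torch.nn as nn

n_layer, initializer_range = 12, 0.02
out_proj = nn.Linear(256, 256)
nn.init.normal_(out_proj.weight, mean=0.0,
                std=initializer_range / math.sqrt(2 * n_layer))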
- if self.sequence_parallel: - p._sequence_parallel = True - - self.apply(partial(_init_weights, n_layer=n_layer, - **(initializer_cfg if initializer_cfg is not None else {}))) - self.tie_weights() - - def tie_weights(self): - if self.process_group is not None: - sync_shared_params(self, self.process_group) - - def forward(self, input_ids, position_ids=None, inference_params=None): - # If using Tensor Parallel with sequence parallel, we combine the batch and the seqlen - # dimensions so that we can split on it easily, in case of small batch size. - # Only the attention/SSM layers need to know the seqlen. - embedding_kwargs = ({'combine_batch_seqlen_dim': True} - if self.process_group is not None and self.sequence_parallel else {}) - hidden_states = self.embeddings(input_ids, position_ids=position_ids, **embedding_kwargs) - residual = None - mixer_kwargs = ({'seqlen': input_ids.shape[1]} - if self.process_group is not None and self.sequence_parallel else {}) - if inference_params is not None: - mixer_kwargs['inference_params'] = inference_params - for layer in self.layers: - hidden_states, residual = layer(hidden_states, residual, mixer_kwargs=mixer_kwargs) - if not self.fused_dropout_add_ln: - dropped = self.drop_f(hidden_states) - residual = (dropped + residual) if residual is not None else dropped - hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) - else: - # Set prenorm=False here since we don't need the residual - hidden_states = dropout_add_layer_norm( - hidden_states, residual, self.ln_f.weight, self.ln_f.bias, - self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False, - residual_in_fp32=self.residual_in_fp32 - ) - return hidden_states - - -class ConvLMHeadModel(nn.Module, GenerationMixin): - - def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, - process_group=None, layer=None, - attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, - resid_dropout: float = 0.0, embed_dropout: float = 0.1, dropout_cls=nn.Dropout, - layer_norm_epsilon: float = 1e-5, initializer_cfg=None, - fused_mlp=False, fused_dropout_add_ln=False, residual_in_fp32=False, - pad_vocab_size_multiple: int = 1, sequence_parallel=True, - device=None, dtype=None, **kwargs) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.process_group = process_group - if vocab_size % pad_vocab_size_multiple != 0: - vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) - self.backbone = LMBackbone( - d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size, - process_group=process_group, - layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, - max_position_embeddings=max_position_embeddings, - resid_dropout=resid_dropout, embed_dropout=embed_dropout, - dropout_cls=dropout_cls, layer_norm_epsilon=layer_norm_epsilon, - initializer_cfg=initializer_cfg, fused_mlp=fused_mlp, - fused_dropout_add_ln=fused_dropout_add_ln, residual_in_fp32=residual_in_fp32, - sequence_parallel=sequence_parallel, - **factory_kwargs, **kwargs - ) - if process_group is None: - self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) - else: - if ColumnParallelLinear is None: - raise ImportError('fused_dense_lib is not installed') - self.lm_head = ColumnParallelLinear( - d_model, vocab_size, process_group, bias=False, - sequence_parallel=sequence_parallel, **factory_kwargs - ) - # Initialize weights and apply final processing - self.apply(partial(_init_weights, n_layer=n_layer, - 
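# ConvLMHeadModel above pads the vocabulary up to a multiple of
# pad_vocab_size_multiple before building the embedding and LM head (handy for
# tensor-parallel sharding and well-shaped matmuls). A small worked example:
vocab_size, pad_vocab_size_multiple = 10_003, 8
if vocab_size % pad_vocab_size_multiple != 0:
    vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple)
assert vocab_size == 10_008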
**(initializer_cfg if initializer_cfg is not None else {}))) - self.tie_weights() - - def tie_weights(self): - self.lm_head.weight = self.backbone.embeddings.word_embeddings.weight - if self.process_group is not None: - sync_shared_params(self, self.process_group) - - def forward(self, input_ids, position_ids=None, inference_params=None, state=None): # state for the repo interface - hidden_states = self.backbone(input_ids, position_ids=position_ids, - inference_params=inference_params) - lm_logits = self.lm_head(hidden_states) - # During inference, we want the full logit for sampling - if isinstance(self.lm_head, ColumnParallelLinear) and inference_params is not None: - lm_logits, _ = all_gather_raw(lm_logits, self.lm_head.process_group) - lm_logits = rearrange(lm_logits, '(n b) s d -> b s (n d)', b=hidden_states.shape[0]) - CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) - return CausalLMOutput(logits=lm_logits), None - - def load_state_dict(self, state_dict, strict=True): - # Remapping from our checkpoints that used different names - def key_mapping_backbone(key): - key = re.sub(r'^s4seq.encoder.', 'backbone.', key) - key = re.sub(r'^embedding.', 'backbone.embeddings.word_embeddings.', key) - key = re.sub(r'^backbone.norm', 'backbone.ln_0', key) - key = re.sub(r'^backbone.layers.(\d+).mixer.output_linear.', - r'backbone.layers.\1.mixer.out_proj.', key) - return key - state_dict = OrderedDict((key_mapping_backbone(k), v) for k, v in state_dict.items()) - # Remapping from our checkpoints that used a different ordering of layers in the block - # Previous: Mixer / MLP -> Dropout -> Add -> LN - # Current: Dropout -> Add -> LN -> Attn / MLP - if 'backbone.ln_0.weight' in state_dict: - n_layers = len(self.backbone.layers) - ln_weight = state_dict.pop(f'backbone.layers.{n_layers - 1}.norm2.weight') - ln_bias = state_dict.pop(f'backbone.layers.{n_layers - 1}.norm2.bias') - state_dict['backbone.ln_f.weight'] = ln_weight - state_dict['backbone.ln_f.bias'] = ln_bias - for l in reversed(range(n_layers)): - ln_weight = state_dict.pop(f'backbone.layers.{l}.norm1.weight') - ln_bias = state_dict.pop(f'backbone.layers.{l}.norm1.bias') - state_dict[f'backbone.layers.{l}.norm2.weight'] = ln_weight - state_dict[f'backbone.layers.{l}.norm2.bias'] = ln_bias - if l > 0: - ln_weight = state_dict.pop(f'backbone.layers.{l - 1}.norm2.weight') - ln_bias = state_dict.pop(f'backbone.layers.{l - 1}.norm2.bias') - state_dict[f'backbone.layers.{l}.norm1.weight'] = ln_weight - state_dict[f'backbone.layers.{l}.norm1.bias'] = ln_bias - ln_weight = state_dict.pop('backbone.ln_0.weight') - ln_bias = state_dict.pop('backbone.ln_0.bias') - state_dict[f'backbone.layers.0.norm1.weight'] = ln_weight - state_dict[f'backbone.layers.0.norm1.bias'] = ln_bias - # Previously we have separate projection matrices for q, k, v, now we stack them - if 'backbone.layers.0.mixer.q_proj.weight' in state_dict: - n_layers = len(self.backbone.layers) - for l in range(n_layers): - if f'backbone.layers.{l}.mixer.q_proj.weight' in state_dict: - Wq = state_dict.pop(f'backbone.layers.{l}.mixer.q_proj.weight') - Wk = state_dict.pop(f'backbone.layers.{l}.mixer.k_proj.weight') - Wv = state_dict.pop(f'backbone.layers.{l}.mixer.v_proj.weight') - bq = state_dict.pop(f'backbone.layers.{l}.mixer.q_proj.bias') - bk = state_dict.pop(f'backbone.layers.{l}.mixer.k_proj.bias') - bv = state_dict.pop(f'backbone.layers.{l}.mixer.v_proj.bias') - state_dict[f'backbone.layers.{l}.mixer.Wqkv.weight'] = torch.cat( - [Wq, Wk, Wv], dim=0 - ) - 
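# Sketch of the checkpoint remapping above: separate q/k/v projection weights are
# stacked row-wise into one Wqkv matrix, so a single matmul yields all three
# projections. Dimensions are illustrative.
import torch

d = 64
Wq, Wk, Wv = torch.randn(3, d, d).unbind(0)
Wqkv = torch.cat([Wq, Wk, Wv], dim=0)           # (3d, d)
x = torch.randn(5, d)
q, k, v = (x @ Wqkv.t()).split(d, dim=-1)       # same result as three separate matmuls
assert torch.allclose(q, x @ Wq.t(), atol=1e-5)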
state_dict[f'backbone.layers.{l}.mixer.Wqkv.bias'] = torch.cat( - [bq, bk, bv], dim=0 - ) - return super().load_state_dict(state_dict, strict=strict) - - -def shard_state_dict_tp(state_dict, world_size, rank, pad_vocab_size_multiple=1): - """Convert the state_dict of a standard SSM model to the state_dict of a SSM model - with tensor parallel. - """ - layer_idx_match = [re.search(r'backbone\.layers\.(\d+)\.', k) for k in state_dict.keys()] - num_hidden_layers = len(set(m.group(1) for m in layer_idx_match if m is not None)) - vocab_size = state_dict['backbone.embeddings.word_embeddings.weight'].shape[0] - inner_dim, hidden_size = state_dict['backbone.layers.0.mlp.fc1.weight'].shape - vocab_size = (math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple) - assert vocab_size % world_size == 0 - assert hidden_size % world_size == 0 - assert inner_dim % world_size == 0 - - def shard_dim(state_dict, key, dim=0): - x = state_dict[key] - dimension = x.shape[dim] // world_size - state_dict[key] = x.narrow(dim, rank * dimension, dimension) - - def shard_qkv_headdim(state_dict, key): - x = rearrange(state_dict[key], '(three d) ... -> three d ...', three=3) - dim = x.shape[1] // world_size - state_dict[key] = rearrange(x[:, rank * dim:(rank + 1) * dim], - 'three d ... -> (three d) ...') - - shard_dim(state_dict, 'backbone.embeddings.word_embeddings.weight', 0) - if 'lm_head.weight' in state_dict: - shard_dim(state_dict, 'lm_head.weight', 0) - if 'backbone.embeddings.position_embeddings.weight' in state_dict: - shard_dim(state_dict, 'backbone.embeddings.position_embeddings.weight', -1) - for i in range(num_hidden_layers): - shard_qkv_headdim(state_dict, f'backbone.layers.{i}.mixer.Wqkv.weight') - shard_qkv_headdim(state_dict, f'backbone.layers.{i}.mixer.Wqkv.bias') - shard_dim(state_dict, f'backbone.layers.{i}.mixer.out_proj.weight', -1) - if rank != 0: - state_dict.pop(f'backbone.layers.{i}.mixer.out_proj.bias') - shard_dim(state_dict, f'backbone.layers.{i}.mlp.fc1.weight', 0) - shard_dim(state_dict, f'backbone.layers.{i}.mlp.fc1.bias', 0) - shard_dim(state_dict, f'backbone.layers.{i}.mlp.fc2.weight', -1) - if rank != 0: - state_dict.pop(f'backbone.layers.{i}.mlp.fc2.bias') - if f'backbone.layers.{i}.mixer.kernel.kernel.B' in state_dict: - for name in ['D', 'ssm_k_D', 'kernel.kernel.B', 'kernel.kernel.inv_A_real', - 'kernel.kernel.A_imag', 'ssm_k_kernel.kernel.B', 'kernel.kernel.log_dt']: - if f'backbone.layers.{i}.mixer.{name}' in state_dict: - shard_dim(state_dict, f'backbone.layers.{i}.mixer.{name}', 0) - for name in ['kernel.kernel.C', 'ssm_k_kernel.kernel.C']: - if f'backbone.layers.{i}.mixer.{name}' in state_dict: - shard_dim(state_dict, f'backbone.layers.{i}.mixer.{name}', 1) - return state_dict diff --git a/src/clm/src/models/sequence/mha.py b/src/clm/src/models/sequence/mha.py deleted file mode 100644 index 12d55fda..00000000 --- a/src/clm/src/models/sequence/mha.py +++ /dev/null @@ -1,122 +0,0 @@ -""" Wrapper around nn.MultiheadAttention to adhere to SequenceModule interface. 
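# Sketch of shard_qkv_headdim above: the packed (3*d, ...) Wqkv weight is viewed
# as (three, d, ...), each rank keeps its contiguous slice of d, and the slice is
# re-packed. Sizes are illustrative; assumes d is divisible by world_size.
import torch
from einops import rearrange

d, world_size, rank = 8, 2, 0
Wqkv = torch.randn(3 * d, 16)
x = rearrange(Wqkv, '(three d) ... -> three d ...', three=3)   # (3, d, 16)
per_rank = x.shape[1] // world_size
shard = rearrange(x[:, rank * per_rank:(rank + 1) * per_rank],
                  'three d ... -> (three d) ...')              # (3 * d/world_size, 16)
assert shard.shape == (3 * d // world_size, 16)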
""" - -import torch -import torch.nn.functional as F -from torch import nn -import hydra -from clm.src.models.sequence.base import SequenceModule, TransposedModule -import clm.src.models.nn.utils as U -from einops import rearrange - -@TransposedModule -class MultiheadAttention(SequenceModule): - """ Simple wrapper for MultiheadAttention """ - def __init__(self, d_model, n_heads, *args, causal=True, **kwargs): - super().__init__() - self.d_model = d_model - self.d_output = d_model - self.mha = nn.MultiheadAttention(d_model, n_heads, *args, batch_first=True, **kwargs) - self.causal = causal - - def forward(self, src, attn_mask=None, key_padding_mask=None, state=None, **kwargs): - """ state should represent a mask and key padding mask """ - if self.causal and attn_mask is None: - attn_mask = torch.triu(torch.ones(src.size(-2), clm.src.size(-2), - dtype=torch.bool, device=src.device), - diagonal=1) - # attn_mask, key_padding_mask = state - # Note that this returns None for the second argument - y, _ = self.mha(src, src, src, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False) - return y, None - - def step(self, x, state): - # TODO proper cached inference - # x: (B, D) - pass - - -class VitAttention(SequenceModule): - """Copied from implementation for ViT: only used for ViT model - - This attention class makes several simplifying assumptions (commonly satisfied in vision - applications): - 1. q = k = v - 2. No masks: no attention mask, no key padding mask - 3. Embed dimension = Input dimension, i.e. projection matrices are square. - """ - - @property - def d_output(self): - return self.dim - - def __init__( - self, - dim, - num_heads=8, - qkv_bias=False, - qk_scale=None, - attn_drop=0., - # proj_drop=0., - packed_linear=True, - linear_cfg=None, - **kwargs, - ): - """packed_linear: whether to pack all 3 q_proj, k_proj, v_proj into 2 matrix. - This option is to be compatible with T2T-ViT pretrained weights, where there's only one - projection weight matrix. 
- """ - super().__init__() - self.dim = dim - self.num_heads = num_heads - head_dim = dim // num_heads - - self.scale = qk_scale or head_dim ** -0.5 - - if linear_cfg is not None: - packed_linear = False - self.packed_linear = packed_linear - if packed_linear: - self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) - else: - if linear_cfg is None: - linear_cfg = {'_target_': 'torch.nn.Linear'} - self.q_proj = hydra.utils.instantiate(linear_cfg, dim, dim, bias=qkv_bias, - _recursive_=False) - self.k_proj = hydra.utils.instantiate(linear_cfg, dim, dim, bias=qkv_bias, - _recursive_=False) - self.v_proj = hydra.utils.instantiate(linear_cfg, dim, dim, bias=qkv_bias, - _recursive_=False) - - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Linear(dim, dim) - - # Removing this dropout because we do this in SequenceResidualBlock - # self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x, state=None): - B, N, C = x.shape - if self.packed_linear: - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) - else: - q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x) - q, k, v = [rearrange(x, 'b n (h d) -> b h n d', h=self.num_heads) for x in (q, k, v)] - - # attn = (q @ k.transpose(-2, -1) * self.scale) - # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM) - bsz, num_heads, q_seq_len, dk = q.size() - _, _, k_seq_len, _ = k.size() - q = rearrange(q, 'b h t d -> (b h) t d') - k = rearrange(k, 'b h s d -> (b h) d s') - # Preallocate attn_weights for `baddbmm` - attn = torch.empty(bsz * num_heads, q_seq_len, k_seq_len, dtype=q.dtype, device=q.device) - attn = rearrange(torch.baddbmm(attn, q, k, beta=0, alpha=self.scale), - '(b h) t s -> b h t s', h = self.num_heads) - - attn = F.softmax(attn, dim=-1, dtype=v.dtype) - attn = self.attn_drop(attn) - - x = (attn @ v).transpose(1, 2).reshape(B, N, C) - x = self.proj(x) - # x = self.proj_drop(x) - return x, None diff --git a/src/clm/src/models/sequence/model.py b/src/clm/src/models/sequence/model.py deleted file mode 100644 index 930c19c6..00000000 --- a/src/clm/src/models/sequence/model.py +++ /dev/null @@ -1,134 +0,0 @@ -""" Isotropic deep sequence model backbone, in the style of ResNets / Transformers. - -The SequenceModel class implements a generic (batch, length, d_input) -> (batch, length, d_output) transformation -""" - -from functools import partial - -import torch -import torch.nn as nn -from einops import rearrange - -from clm.src.utils.config import to_list, to_dict -from clm.src.models.sequence.block import SequenceResidualBlock -from clm.src.models.sequence.base import SequenceModule -from clm.src.models.nn.components import Normalization, DropoutNd - - -class SequenceModel(SequenceModule): - def __init__( - self, - d_model, # Resize input (useful for deep models with residuals) - n_layers=1, # Number of layers - transposed=False, # Transpose inputs so each layer receives (batch, dim, length) - dropout=0.0, # Dropout parameter applied on every residual and every layer - tie_dropout=False, # Tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d - prenorm=True, # Pre-norm vs. post-norm - n_repeat=1, # Each layer is repeated n times per stage before applying pooling - layer=None, # Layer config, must be specified - residual=None, # Residual config - norm=None, # Normalization config (e.g. 
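# VitAttention above computes attention scores with torch.baddbmm(beta=0,
# alpha=self.scale), i.e. a pre-scaled batched matmul into a preallocated buffer
# (with beta=0 the buffer's contents are ignored, per the PyTorch docs). A sketch
# of the equivalence, with illustrative shapes:
import torch

bh, t, s, d = 6, 5, 5, 16
q = torch.randn(bh, t, d)
k = torch.randn(bh, d, s)
scale = d ** -0.5
buf = torch.empty(bh, t, s)
scores = torch.baddbmm(buf, q, k, beta=0, alpha=scale)
assert torch.allclose(scores, scale * torch.bmm(q, k), atol=1e-5)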
layer vs batch) - pool=None, # Config for pooling layer per stage - track_norms=True, # Log norms of each layer output - dropinp=0.0, # Input dropout - ): - super().__init__() - # Save arguments needed for forward pass - self.d_model = d_model - self.transposed = transposed - self.track_norms = track_norms - - # Input dropout (not really used) - dropout_fn = partial(DropoutNd, transposed=self.transposed) if tie_dropout else nn.Dropout - self.drop = dropout_fn(dropinp) if dropinp > 0.0 else nn.Identity() - - layer = to_list(layer, recursive=False) - - # Some special arguments are passed into each layer - for _layer in layer: - # If layers don't specify dropout, add it - if _layer.get('dropout', None) is None: - _layer['dropout'] = dropout - # Ensure all layers are shaped the same way - _layer['transposed'] = transposed - - # Duplicate layers - layers = layer * n_layers * n_repeat - - # Instantiate layers - _layers = [] - d = d_model - for l, layer in enumerate(layers): - # Pool at the end of every n_repeat blocks - pool_cfg = pool if (l+1) % n_repeat == 0 else None - block = SequenceResidualBlock(d, l+1, prenorm=prenorm, dropout=dropout, tie_dropout=tie_dropout, transposed=transposed, layer=layer, residual=residual, norm=norm, pool=pool_cfg) - _layers.append(block) - d = block.d_output - - self.d_output = d - self.layers = nn.ModuleList(_layers) - if prenorm: - if norm is None: - self.norm = None - elif isinstance(norm, str): - self.norm = Normalization(self.d_output, transposed=self.transposed, _name_=norm) - else: - self.norm = Normalization(self.d_output, transposed=self.transposed, **norm) - else: - self.norm = nn.Identity() - - def forward(self, inputs, *args, state=None, **kwargs): - """ Inputs assumed to be (batch, sequence, dim) """ - if self.transposed: inputs = rearrange(inputs, 'b ... d -> b d ...') - inputs = self.drop(inputs) - - # Track norms - if self.track_norms: output_norms = [torch.mean(inputs.detach() ** 2)] - - # Apply layers - outputs = inputs - prev_states = [None] * len(self.layers) if state is None else state - next_states = [] - for layer, prev_state in zip(self.layers, prev_states): - outputs, state = layer(outputs, *args, state=prev_state, **kwargs) - next_states.append(state) - if self.track_norms: output_norms.append(torch.mean(outputs.detach() ** 2)) - if self.norm is not None: outputs = self.norm(outputs) - - if self.transposed: outputs = rearrange(outputs, 'b d ... -> b ... 
d') - - if self.track_norms: - metrics = to_dict(output_norms, recursive=False) - self.metrics = {f'norm/{i}': v for i, v in metrics.items()} - - return outputs, next_states - - @property - def d_state(self): - d_states = [layer.d_state for layer in self.layers] - return sum([d for d in d_states if d is not None]) - - @property - def state_to_tensor(self): - # Slightly hacky way to implement this in a curried manner (so that the function can be extracted from an instance) - # Somewhat more sound may be to turn this into a @staticmethod and grab subclasses using hydra.utils.get_class - def fn(state): - x = [_layer.state_to_tensor(_state) for (_layer, _state) in zip(self.layers, state)] - x = [_x for _x in x if _x is not None] - return torch.cat( x, dim=-1) - return fn - - def default_state(self, *batch_shape, device=None): - return [layer.default_state(*batch_shape, device=device) for layer in self.layers] - - def step(self, x, state, **kwargs): - # Apply layers - prev_states = [None] * len(self.layers) if state is None else state - next_states = [] - for layer, prev_state in zip(self.layers, prev_states): - x, state = layer.step(x, state=prev_state, **kwargs) - next_states.append(state) - - x = self.norm(x) - - return x, next_states diff --git a/src/clm/src/models/sequence/pool.py b/src/clm/src/models/sequence/pool.py deleted file mode 100644 index e5b8f4c6..00000000 --- a/src/clm/src/models/sequence/pool.py +++ /dev/null @@ -1,459 +0,0 @@ -"""Implements downsampling and upsampling on sequences.""" - -import torch -from torch import nn -import torch.nn.functional as F -from einops import rearrange, repeat, reduce - -from clm.src.models.sequence import SequenceModule -from clm.src.models.nn import LinearActivation - -""" Simple pooling functions that just downsample or repeat - -stride: Subsample on the layer dimension -expand: Repeat on the feature dimension -""" - - -class DownSample(SequenceModule): - def __init__(self, d_input, stride=1, expand=1, transposed=True): - super().__init__() - self.d_input = d_input - self.stride = stride - self.expand = expand - self.transposed = transposed - - def forward(self, x): - if x is None: return None - if self.stride > 1: - assert x.ndim == 3, "Downsampling with higher-dimensional inputs is currently not supported. It is recommended to use average or spectral pooling instead." - if self.transposed: - x = x[..., 0::self.stride] - else: - x = x[..., 0::self.stride, :] - - if self.expand > 1: - if self.transposed: - x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand) - else: - x = repeat(x, 'b ... d -> b ... (d e)', e=self.expand) - - return x, None - - - def step(self, x, state, **kwargs): - if self.stride > 1 or self.expand > 1: - raise NotImplementedError - return x, state - - @property - def d_output(self): - return self.d_input * self.expand - -class DownAvgPool(SequenceModule): - def __init__(self, d_input, stride=1, expand=None, transposed=True): - super().__init__() - self.d_input = d_input - self.stride = stride - self.expand = expand - self.transposed = transposed - - if self.expand is not None: - self.linear = LinearActivation( - d_input, - d_input * expand, - transposed=transposed, - ) - - def forward(self, x): - if not self.transposed: - x = rearrange(x, 'b ... d -> b d ...') - - if self.stride > 1: - # einops appears slower than F - if x.ndim == 3: - x = F.avg_pool1d(x, self.stride, self.stride) - elif x.ndim == 4: - x = F.avg_pool2d(x, self.stride, self.stride) - else: - # Reduction string e.g. 
"b d (l1 2) (l2 2) -> b d l1 l2" - reduce_str = "b d " + " ".join([f"(l{i} {self.stride})" for i in range(x.ndim-2)]) \ - + " -> b d " + " ".join([f"l{i}" for i in range(x.ndim-2)]) - x = reduce(x, reduce_str, 'mean') - - # if self.expand > 1: - # x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand) - - if not self.transposed: - x = rearrange(x, 'b d ... -> b ... d') - if self.expand is not None: - x = self.linear(x) - return x, None - - def step(self, x, state, **kwargs): - if self.stride > 1 or self.expand > 1: - raise NotImplementedError - return x, state - - @property - def d_output(self): - if self.expand is None: - return self.d_input - else: - return self.d_input * self.expand - -class DownSpectralPool(SequenceModule): - def __init__(self, d_input, stride=1, expand=1, transposed=True): - super().__init__() - self.d_input = d_input - self.stride = stride - self.expand = expand - self.transposed = transposed - - def forward(self, x): - """ - x: (B, L..., D) - """ - if not self.transposed: - x = rearrange(x, 'b ... d -> b d ...') - shape = x.shape[2:] - x_f = torch.fft.ifftn(x, s=shape) - - for axis, l in enumerate(shape): - assert l % self.stride == 0, 'input length must be divisible by stride' - new_l = l // self.stride - idx = torch.cat([torch.arange(0, new_l-new_l//2), l+torch.arange(-new_l//2, 0)]).to(x_f.device) - x_f = torch.index_select(x_f, 2+axis, idx) - x = torch.fft.ifftn(x_f, s=[l//self.stride for l in shape]) - x = x.real - - if self.expand > 1: - x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand) - if not self.transposed: - x = rearrange(x, 'b d ... -> b ... d') - return x, None - - def step(self, x, state, **kwargs): - if self.stride > 1 or self.expand > 1: - raise NotImplementedError - return x, state - - @property - def d_output(self): - return self.d_input * self.expand - -class UpSample(SequenceModule): - def __init__(self, d_input, stride=1, expand=1, transposed=True): - super().__init__() - self.d_input = d_input - self.stride = stride - self.expand = expand - self.transposed = transposed - - def forward(self, x): - if x is None: return None - if self.expand > 1: - if self.transposed: - x = reduce(x, '... (d e) l -> ... d l', 'mean', e=self.expand) - else: - x = reduce(x, '... (d e) -> ... d', 'mean', e=self.expand) - if self.stride > 1: - if self.transposed: - x = repeat(x, '... l -> ... (l e)', e=self.stride) - else: - x = repeat(x, '... l d -> ... (l e) d', e=self.stride) - return x, None - - @property - def d_output(self): - return self.d_input // self.expand - def step(self, x, state, **kwargs): - if self.stride > 1 or self.expand > 1: - raise NotImplementedError - return x, state -class UpAvgPool(SequenceModule): - def __init__(self, d_input, stride=1, expand=1, causal=False, transposed=True): - super().__init__() - assert d_input % expand == 0 - self.d_input = d_input - self.stride = stride - self.expand = expand - self.causal = causal - self.transposed = transposed - - self.linear = LinearActivation( - d_input, - d_input // expand, - transposed=transposed, - ) - - def forward(self, x): - # TODO only works for 1D right now - if x is None: return None - x = self.linear(x) - if self.stride > 1: - if self.transposed: - if self.causal: - x = F.pad(x[..., :-1], (1, 0)) # Shift to ensure causality - x = repeat(x, '... l -> ... (l e)', e=self.stride) - else: - if self.causal: - x = F.pad(x[..., :-1, :], (0, 0, 1, 0)) # Shift to ensure causality - x = repeat(x, '... l d -> ... 
(l e) d', e=self.stride) - return x, None - - @property - def d_output(self): - return self.d_input // self.expand - def step(self, x, state, **kwargs): - if self.stride > 1 or self.expand > 1: - raise NotImplementedError - return x, state - -class DownLinearPool(SequenceModule): - def __init__(self, d_model, stride=1, expand=1, causal=False, transposed=True): - super().__init__() - - self.d_model = d_model - self.stride = stride - self.expand = expand - self.transposed = transposed - - self.linear = LinearActivation( - d_model * stride, - d_model * expand, - transposed=transposed, - ) - - def forward(self, x): - if self.transposed: - x = rearrange(x, '... h (l s) -> ... (h s) l', s=self.stride) - else: - x = rearrange(x, '... (l s) h -> ... l (h s)', s=self.stride) - x = self.linear(x) - return x, None - - def step(self, x, state, **kwargs): - # if self.stride > 1 or self.expand > 1: - # raise NotImplementedError - # return x, state - if x is None: return None, state - state.append(x) - if len(state) == self.stride: - x = rearrange(torch.stack(state, dim=-1), '... h s -> ... (h s)') - if self.transposed: x = x.unsqueeze(-1) - x = self.linear(x) - if self.transposed: x = x.squeeze(-1) - return x, [] - else: - return None, state - - def default_state(self, *batch_shape, device=None): - return [] - - @property - def d_output(self): - return self.d_input * self.expand - -class UpLinearPool(SequenceModule): - def __init__(self, d, stride=1, expand=1, causal=False, transposed=True): - super().__init__() - - # self.d_model = d * expand - # self.d_output = d - assert d % expand == 0 - self.d_model = d - self.d_output = d // expand - # self._d_output = d_output - self.stride = stride - self.causal = causal - self.transposed = transposed - - self.linear = LinearActivation( - self.d_model, - self.d_output * stride, - transposed=transposed, - ) - - def forward(self, x, skip=None): - x = self.linear(x) - if self.transposed: - if self.causal: - x = F.pad(x[..., :-1], (1, 0)) # Shift to ensure causality - x = rearrange(x, '... (h s) l -> ... h (l s)', s=self.stride) - else: - if self.causal: - x = F.pad(x[..., :-1, :], (0, 0, 1, 0)) # Shift to ensure causality - x = rearrange(x, '... l (h s) -> ... (l s) h', s=self.stride) - if skip is not None: - x = x + skip - return x, None - - def step(self, x, state, **kwargs): - """ - x: (..., H) - """ - - assert len(state) > 0 - y, state = state[0], state[1:] - if len(state) == 0: - assert x is not None - if self.transposed: x = x.unsqueeze(-1) - x = self.linear(x) - if self.transposed: x = x.squeeze(-1) - x = rearrange(x, '... (h s) -> ... 
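# Sketch of the DownLinearPool trick above (and the near-identical DownPool
# below), in the non-transposed layout: fold every `stride` consecutive timesteps
# into the feature axis, then mix them with a single linear layer, shortening the
# sequence by `stride`. Shapes are illustrative.
import torch
import torch.nn as nn
from einops import rearrange

B, L, H, stride = 2, 16, 8, 4
x = torch.randn(B, L, H)
x = rearrange(x, '... (l s) h -> ... l (h s)', s=stride)   # (B, 4, 32)
x = nn.Linear(H * stride, H)(x)                            # (B, 4, 8)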
h s', s=self.stride) - state = list(torch.unbind(x, dim=-1)) - else: assert x is None - return y, state - - def default_state(self, *batch_shape, device=None): - state = torch.zeros(batch_shape + (self.d_output, self.stride), device=device) # (batch, h, s) - state = list(torch.unbind(state, dim=-1)) # List of (..., H) - return state - - # @property - # def d_output(self): return self._d_output - -""" Pooling functions with trainable parameters """ # TODO make d_output expand instead - -class DownPool2d(SequenceModule): - - def __init__(self, d_input, d_output, stride=1, transposed=True, weight_norm=True): - super().__init__() - - self.linear = LinearActivation( - d_input, - d_output, - transposed=transposed, - weight_norm=weight_norm, - ) - - self.pool = nn.AvgPool2d(kernel_size=stride, stride=stride), - - def forward(self, x): - if self.transposed: - x = self.pool(x) - -# TODO DownPool/UpPool are currently used by unet/sashimi backbones -# DownLinearPool is used by the registry (for isotropic backbone) -# DownPool is essentially the same as DownLinearPool. These should be consolidated -class DownPool(SequenceModule): - def __init__(self, d_input, d_output=None, expand=None, stride=1, transposed=True, weight_norm=True, initializer=None, activation=None): - super().__init__() - assert (d_output is None) + (expand is None) == 1 - if d_output is None: d_output = d_input * expand - - self.d_output = d_output - self.stride = stride - self.transposed = transposed - - self.linear = LinearActivation( - d_input * stride, - d_output, - transposed=transposed, - initializer=initializer, - weight_norm = weight_norm, - activation=activation, - activate=True if activation is not None else False, - ) - - def forward(self, x): - if self.transposed: - x = rearrange(x, '... h (l s) -> ... (h s) l', s=self.stride) - else: - x = rearrange(x, '... (l s) h -> ... l (h s)', s=self.stride) - x = self.linear(x) - return x, None - - def step(self, x, state, **kwargs): - """ - x: (..., H) - """ - - if x is None: return None, state - state.append(x) - if len(state) == self.stride: - x = rearrange(torch.stack(state, dim=-1), '... h s -> ... (h s)') - if self.transposed: x = x.unsqueeze(-1) - x = self.linear(x) - if self.transposed: x = x.squeeze(-1) - return x, [] - else: - return None, state - - def default_state(self, *batch_shape, device=None): - return [] - - -class UpPool(SequenceModule): - def __init__(self, d_input, d_output, stride, transposed=True, weight_norm=True, initializer=None, activation=None): - super().__init__() - - self.d_input = d_input - self._d_output = d_output - self.stride = stride - self.transposed = transposed - - self.linear = LinearActivation( - d_input, - d_output * stride, - transposed=transposed, - initializer=initializer, - weight_norm = weight_norm, - activation=activation, - activate=True if activation is not None else False, - ) - - def forward(self, x, skip=None): - x = self.linear(x) - if self.transposed: - x = F.pad(x[..., :-1], (1, 0)) # Shift to ensure causality - x = rearrange(x, '... (h s) l -> ... h (l s)', s=self.stride) - else: - x = F.pad(x[..., :-1, :], (0, 0, 1, 0)) # Shift to ensure causality - x = rearrange(x, '... l (h s) -> ... 
(l s) h', s=self.stride) - if skip is not None: - x = x + skip - return x, None - - def step(self, x, state, **kwargs): - """ - x: (..., H) - """ - - assert len(state) > 0 - y, state = state[0], state[1:] - if len(state) == 0: - assert x is not None - if self.transposed: x = x.unsqueeze(-1) - x = self.linear(x) - if self.transposed: x = x.squeeze(-1) - x = rearrange(x, '... (h s) -> ... h s', s=self.stride) - state = list(torch.unbind(x, dim=-1)) - else: assert x is None - return y, state - - def default_state(self, *batch_shape, device=None): - state = torch.zeros(batch_shape + (self.d_output, self.stride), device=device) # (batch, h, s) - state = list(torch.unbind(state, dim=-1)) # List of (..., H) - return state - - @property - def d_output(self): return self._d_output - -registry = { - 'sample': DownSample, - 'pool': DownAvgPool, - 'avg': DownAvgPool, - 'linear': DownLinearPool, - 'spectral': DownSpectralPool, -} - -up_registry = { - # 'sample': UpSample, - 'pool': UpAvgPool, - 'avg': UpAvgPool, - 'linear': UpLinearPool, - # 'spectral': UpSpectralPool, # Not implemented and no way to make this causal -} - diff --git a/src/clm/src/models/sequence/simple_lm.py b/src/clm/src/models/sequence/simple_lm.py deleted file mode 100644 index bc525d55..00000000 --- a/src/clm/src/models/sequence/simple_lm.py +++ /dev/null @@ -1,469 +0,0 @@ -# Copyright (c) 2023, Tri Dao, Dan Fu. -# Simplified, mostly standalone version of LongConvLM for synthetics. - -import math -from functools import partial - -from collections import namedtuple - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchvision.ops import StochasticDepth - -from einops import rearrange - -from clm.src.utils import instantiate -import clm.src.utils.registry as registry - -class LinearResidual(nn.Linear): - """Wrap nn.Linear to return the residual as well. For compatibility with FusedDense. - """ - - def forward(self, input: torch.Tensor) -> torch.Tensor: - return super().forward(input), input - -class SelfAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): - super().__init__() - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward(self, qkv, causal=None, key_padding_mask=None): - """Implements the multihead softmax attention. - Arguments - --------- - qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) - causal: if passed, will override self.causal - key_padding_mask: boolean mask to apply to the attention weights. True means to keep, - False means to mask out. 
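# Sketch of the causal shift that UpAvgPool/UpPool above apply before expanding
# the length axis: drop the last step and left-pad with a zero, so position t of
# the upsampled output never depends on inputs beyond t. Transposed (B, H, L)
# layout; values illustrative.
import torch
import torch.nn.functional as F
from einops import repeat

x = torch.arange(1., 5.).view(1, 1, 4)      # [[[1, 2, 3, 4]]]
x = F.pad(x[..., :-1], (1, 0))              # [[[0, 1, 2, 3]]]  (shifted right)
x = repeat(x, '... l -> ... (l e)', e=2)    # [[[0, 0, 1, 1, 2, 2, 3, 3]]]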
(B, S) - """ - batch_size, seqlen = qkv.shape[0], qkv.shape[1] - causal = self.causal if causal is None else causal - q, k, v = qkv.unbind(dim=2) - softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) - scores = torch.einsum('bthd,bshd->bhts', q, k * softmax_scale) - if key_padding_mask is not None: - padding_mask = torch.full((batch_size, seqlen), -10000.0, dtype=scores.dtype, - device=scores.device) - padding_mask.masked_fill_(key_padding_mask, 0.0) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + rearrange(padding_mask, 'b s -> b 1 1 s') - if causal: - # "triu_tril_cuda_template" not implemented for 'BFloat16' - # So we have to construct the mask in float - causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1) - # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) - scores = scores + causal_mask.to(dtype=scores.dtype) - attention = torch.softmax(scores, dim=-1, dtype=v.dtype) - attention_drop = F.dropout(attention, self.dropout_p if self.training else 0.0) - output = torch.einsum('bhts,bshd->bthd', attention_drop, v) - return output - -class MHA(nn.Module): - """Multi-head self-attention and cross-attention - """ - - def __init__(self, embed_dim, num_heads, bias=True, dropout=0.0, - softmax_scale=None, causal=False, layer_idx=None, dwconv=False,return_residual=False,device=None, dtype=None) -> None: - """ - return_residual: whether to return the input x along with the output. This is for - performance reason: for post-norm architecture, returning the input allows us - to fuse the backward of nn.Linear with the residual connection. - """ - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.embed_dim = embed_dim - self.causal = causal - self.layer_idx = layer_idx - self.dwconv = dwconv - self.return_residual = return_residual - - self.num_heads = num_heads - assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" - self.head_dim = self.embed_dim // num_heads - - linear_cls = nn.Linear - linear_resid_cls = LinearResidual - inner_attn_cls = SelfAttention - - if not self.return_residual: - self.Wqkv = linear_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) - else: - self.Wqkv = linear_resid_cls(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) - if self.dwconv: - self.dwconv_qkv = nn.Conv1d(3 * embed_dim, 3 * embed_dim, kernel_size=3, padding=2, - groups=3 * embed_dim) - - self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale, - attention_dropout=dropout) - - # output projection always have the bias (for now) - self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs) - - def forward(self, x, key_padding_mask=None, **kwargs): - """ - Arguments: - x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if - cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total - is the is the sum of the sequence lengths in the batch. - cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths - of the sequences in the batch, used to index into x. Only applicable when using - FlashAttention. - max_seqlen: int. Maximum sequence length in the batch. - key_padding_mask: boolean mask, True means to keep, False means to mask out. - (batch, seqlen). Only applicable when not using FlashAttention. - mixer_subset: for cross-attention only. 
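# SelfAttention above enforces causality with an additive mask of large negative
# values built in float (the inline comment notes triu is not implemented for
# bfloat16), rather than masked_fill_. A minimal sketch of the score masking:
import torch

seqlen = 4
scores = torch.randn(1, 1, seqlen, seqlen)                       # (B, H, T, S)
causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0), 1)
attn = torch.softmax(scores + causal_mask, dim=-1)
# attn[..., t, s] is (near-)zero wherever s > t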
If not None, will take a subset of x - before applying the query projection. Useful for e.g., ViT where we only care - about the CLS token in the last layer. - inference_params: for generation. Adapted from Megatron-LM (and Apex) - https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 - """ - - kwargs = ({'key_padding_mask': key_padding_mask, **kwargs}) - - if not self.return_residual: - qkv = self.Wqkv(x) - else: - qkv, x = self.Wqkv(x) - if self.dwconv: - qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2], - 'b d s -> b s d').contiguous() - qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim) - - context = self.inner_attn(qkv, **kwargs) - - out = self.out_proj(rearrange(context, '... h d -> ... (h d)')) - return out if not self.return_residual else (out, x) - - -class GPT2Embeddings(nn.Module): - - def __init__(self, embed_dim, vocab_size, max_position_embeddings, padding_idx=None, - word_embed_proj_dim=None, device=None, dtype=None): - """ - If max_position_embeddings <= 0, there's no position embeddings - If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension - the project up to embed_dim - """ - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - if word_embed_proj_dim is None: - self.word_embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx, - **factory_kwargs) - self.project_in = None - else: - self.word_embeddings = nn.Embedding(vocab_size, word_embed_proj_dim, - padding_idx=padding_idx, **factory_kwargs) - self.project_in = nn.Linear(word_embed_proj_dim, embed_dim, bias=False, - **factory_kwargs) - self.max_position_embeddings = max_position_embeddings - if self.max_position_embeddings > 0: - self.position_embeddings = nn.Embedding(max_position_embeddings, embed_dim, - **factory_kwargs) - - def forward(self, input_ids, position_ids=None): - """ - input_ids: (batch, seqlen) - position_ids: (batch, seqlen) - """ - batch_size, seqlen = input_ids.shape - embeddings = self.word_embeddings(input_ids) - if self.project_in is not None: - embeddings = self.project_in(embeddings) - if self.max_position_embeddings > 0: - if position_ids is None: - position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) - position_embeddings = self.position_embeddings(position_ids) - embeddings = embeddings + position_embeddings - return embeddings - -class Mlp(nn.Module): - - def __init__(self, in_features, hidden_features=None, out_features=None, activation=F.gelu, - return_residual=False, device=None, dtype=None): - """ - From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/mlp.py - """ - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - out_features = out_features or in_features - hidden_features = hidden_features or in_features - self.return_residual = return_residual - self.fc1 = nn.Linear(in_features, hidden_features, **factory_kwargs) - self.activation = activation - self.fc2 = nn.Linear(hidden_features, out_features, **factory_kwargs) - - def forward(self, x): - y = self.fc1(x) - y = self.activation(y) - y = self.fc2(y) - return y if not self.return_residual else (y, x) - -class Block(nn.Module): - - def __init__(self, dim, mixer_cls=None, mlp_cls=None, norm_cls=nn.LayerNorm, - dropout_cls=nn.Dropout, prenorm=True, resid_dropout1=0., resid_dropout2=0., - drop_path1=0., drop_path2=0., - return_residual=False, - 
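# Sketch of what GPT2Embeddings.forward above computes when
# max_position_embeddings > 0: learned token embeddings plus learned absolute
# position embeddings indexed by arange(seqlen). Sizes are illustrative.
import torch
import torch.nn as nn

vocab, d_model, max_pos = 100, 32, 512
word_emb = nn.Embedding(vocab, d_model)
pos_emb = nn.Embedding(max_pos, d_model)
input_ids = torch.randint(0, vocab, (2, 10))               # (batch, seqlen)
position_ids = torch.arange(input_ids.shape[1])            # (seqlen,)
embeddings = word_emb(input_ids) + pos_emb(position_ids)   # (2, 10, 32)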
residual_in_fp32=False): - """ - From https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/modules/block.py - For prenorm=True, this Block has a slightly different structure compared to a regular - prenorm Transformer block. - The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. - [Ref: https://arxiv.org/abs/2002.04745] - Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both - the hidden_states (output of the MLP) and the residual. - This is for performance reasons, as we can fuse the dropout, add and LayerNorm. - The residual needs to be provided (except for the very first block). - For prenorm=False, this Block has the same structure as a regular postnorm Transformer - block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. - return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. - This is for performance reason: for post-norm architecture, returning the input allows us - to fuse the backward of nn.Linear with the residual connection. - """ - super().__init__() - self.prenorm = prenorm - self.return_residual = return_residual - self.residual_in_fp32 = residual_in_fp32 - if self.residual_in_fp32: - assert self.prenorm, 'residual_in_fp32 is only compatible with prenorm=True' - if mixer_cls is None: - mixer_cls = partial(MHA, num_heads=dim // 64) - if mlp_cls is None: - mlp_cls = partial(Mlp, hidden_features=4 * dim) - self.mixer = mixer_cls(dim) - self.dropout1 = dropout_cls(resid_dropout1) - self.drop_path1 = StochasticDepth(drop_path1, mode='row') - self.norm1 = norm_cls(dim) - self.mlp = mlp_cls(dim) - if not isinstance(self.mlp, nn.Identity): - self.dropout2 = dropout_cls(resid_dropout2) - self.drop_path2 = StochasticDepth(drop_path2, mode='row') - self.norm2 = norm_cls(dim) - - def forward(self, hidden_states, residual = None, - mixer_subset=None, mixer_kwargs=None): - r"""Pass the input through the encoder layer. - Args: - hidden_states: the sequence to the encoder layer (required). - residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) - mixer_subset: for cross-attention only. If not None, will take a subset of x - before applying the query projection. Useful for e.g., ViT where we only care - about the CLS token in the last layer. 
- """ - if self.prenorm: - dropped = self.drop_path1(self.dropout1(hidden_states)) - residual = (dropped + residual) if residual is not None else dropped - hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - if mixer_kwargs is None: - mixer_kwargs = {} - if mixer_subset is not None: - mixer_kwargs['mixer_subset'] = mixer_subset - hidden_states = self.mixer(hidden_states, **mixer_kwargs) - if mixer_subset is not None: - residual = residual[:, mixer_subset] - if not isinstance(self.mlp, nn.Identity): - dropped = self.drop_path2(self.dropout2(hidden_states)) - residual = (dropped + residual) if residual is not None else dropped - hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) - if self.residual_in_fp32: - residual = residual.to(torch.float32) - - hidden_states = self.mlp(hidden_states) - return hidden_states, residual - else: - assert residual is None - mixer_out = self.mixer( - hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {}) - ) - if self.return_residual: # mixer out is actually a pair here - mixer_out, hidden_states = mixer_out - - hidden_states = self.norm1((self.drop_path1(self.dropout1(mixer_out)) - + hidden_states).to(dtype=self.norm1.weight.dtype)) - - if not isinstance(self.mlp, nn.Identity): - mlp_out = self.mlp(hidden_states) - if self.return_residual: # mlp out is actually a pair here - mlp_out, hidden_states = mlp_out - - hidden_states = self.norm2((self.drop_path2(self.dropout2(mlp_out)) - + hidden_states).to(dtype=self.norm2.weight.dtype)) - - return hidden_states - -def create_mixer_cls(layer=None, - attn_layer_idx=None, attn_cfg=None, layer_idx=None, - device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} - if attn_layer_idx is not None and layer_idx in attn_layer_idx: - causal = True if attn_cfg is None else attn_cfg.pop('causal', True) - - mha_cls = MHA - - mixer_cls = partial(mha_cls, causal=causal, layer_idx=layer_idx, - **(attn_cfg if attn_cfg is not None else {}),**factory_kwargs) - else: - mixer_cls = instantiate(registry.layer, layer, partial=True, layer_idx=layer_idx, **factory_kwargs) - return mixer_cls - - -def create_mlp_cls(d_model, d_inner=None, device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} - inner_dim = d_inner if d_inner is not None else 4 * d_model - - mlp_cls = partial(Mlp, hidden_features=inner_dim, - activation=partial(F.gelu, approximate='tanh'), **factory_kwargs) - - return mlp_cls - - -def create_block(d_model, d_inner=None, - layer=None, attn_layer_idx=None, - attn_cfg=None, layer_norm_epsilon=1e-5, - resid_dropout1=0.0, resid_dropout2=0.0, residual_in_fp32=False, - layer_idx=None, - device=None, dtype=None): - factory_kwargs = {'device': device, 'dtype': dtype} - mixer_cls = create_mixer_cls(layer=layer, - attn_layer_idx=attn_layer_idx, - attn_cfg=attn_cfg, layer_idx=layer_idx, - **factory_kwargs) - mlp_cls = create_mlp_cls(d_model, d_inner=d_inner, - **factory_kwargs) - norm_cls = partial(nn.LayerNorm, eps=layer_norm_epsilon, **factory_kwargs) - block = Block(d_model, mixer_cls, mlp_cls, norm_cls=norm_cls, - prenorm=True, resid_dropout1=resid_dropout1, resid_dropout2=resid_dropout2,residual_in_fp32=residual_in_fp32) - block.layer_idx = layer_idx - return block - - -# https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454 -def _init_weights(module, n_layer, 
initializer_range=0.02, rescale_prenorm_residual=True, - glu_act=False): - if isinstance(module, nn.Linear): - nn.init.normal_(module.weight, std=initializer_range) - if module.bias is not None: - nn.init.zeros_(module.bias) - elif isinstance(module, nn.Embedding): - nn.init.normal_(module.weight, std=initializer_range) - - if rescale_prenorm_residual: - # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme: - # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale - # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers. - # > -- GPT-2 :: https://openai.com/blog/better-language-models/ - # - # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py - for name, p in module.named_parameters(): - if name in ["out_proj.weight", "fc2.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) - # If using GLU activation for now, we scale the std by 2 - elif name in ["output_linear.0.weight"]: - # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block - if not glu_act: - nn.init.normal_(p, mean=0.0, std=initializer_range / math.sqrt(2 * n_layer)) - else: - out_features = p.shape[0] - # Multiplying the first half of the matrix by 2 since sigmoid scales it down by 0.5 - # on average. - nn.init.normal_(p[:out_features // 2], mean=0.0, std=initializer_range / math.sqrt(2 * n_layer) * 2) - - -class LMBackbone(nn.Module): - - def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, - process_group=None, layer=None, - attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, - resid_dropout: float = 0.0, embed_dropout: float = 0.1, - layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False, - device=None, dtype=None, **kwargs) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - self.process_group = process_group - self.residual_in_fp32 = residual_in_fp32 - self.embeddings = GPT2Embeddings(d_model, vocab_size, max_position_embeddings, - **factory_kwargs) - - - self.layers = nn.ModuleList([create_block( - d_model, d_inner=d_inner, - layer=layer, attn_layer_idx=attn_layer_idx, - attn_cfg=attn_cfg, layer_norm_epsilon=layer_norm_epsilon, - resid_dropout1=embed_dropout if i == 0 else resid_dropout, - resid_dropout2=resid_dropout, residual_in_fp32=residual_in_fp32,layer_idx=i, - **factory_kwargs, - ) for i in range(n_layer)]) - - self.drop_f = nn.Dropout(resid_dropout) - self.ln_f = nn.LayerNorm(d_model, eps=layer_norm_epsilon, **factory_kwargs) - - self.apply(partial(_init_weights, n_layer=n_layer, - **(initializer_cfg if initializer_cfg is not None else {}))) - - def forward(self, input_ids, position_ids=None): - hidden_states = self.embeddings(input_ids, position_ids=position_ids,) - residual = None - - for layer in self.layers: - hidden_states, residual = layer(hidden_states, residual) - - dropped = self.drop_f(hidden_states) - residual = (dropped + residual) if residual is not None else dropped - hidden_states = self.ln_f(residual.to(dtype=self.ln_f.weight.dtype)) - - return hidden_states - - -class SimpleLMHeadModel(nn.Module): - - def __init__(self, d_model: int, n_layer: int, d_inner: int, vocab_size: int, - layer=None, - attn_layer_idx=None, attn_cfg=None, max_position_embeddings=0, - resid_dropout: float = 0.0, 
embed_dropout: float = 0.1, - layer_norm_epsilon: float = 1e-5, initializer_cfg=None,residual_in_fp32=False, - pad_vocab_size_multiple: int = 1, - device=None, dtype=None, **kwargs) -> None: - factory_kwargs = {'device': device, 'dtype': dtype} - super().__init__() - if vocab_size % pad_vocab_size_multiple != 0: - vocab_size += pad_vocab_size_multiple - (vocab_size % pad_vocab_size_multiple) - self.backbone = LMBackbone( - d_model=d_model, n_layer=n_layer, d_inner=d_inner, vocab_size=vocab_size, - layer=layer, attn_layer_idx=attn_layer_idx, attn_cfg=attn_cfg, - max_position_embeddings=max_position_embeddings, - resid_dropout=resid_dropout, embed_dropout=embed_dropout, - layer_norm_epsilon=layer_norm_epsilon, - initializer_cfg=initializer_cfg, residual_in_fp32=residual_in_fp32, - **factory_kwargs, **kwargs - ) - self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs) - - # Initialize weights and apply final processing - self.apply(partial(_init_weights, n_layer=n_layer, - **(initializer_cfg if initializer_cfg is not None else {}))) - self.tie_weights() - - def tie_weights(self): - self.lm_head.weight = self.backbone.embeddings.word_embeddings.weight - - def forward(self, input_ids, position_ids=None, state=None): # state for the repo interface - hidden_states = self.backbone(input_ids, position_ids=position_ids) - lm_logits = self.lm_head(hidden_states) - CausalLMOutput = namedtuple('CausalLMOutput', ['logits']) - return CausalLMOutput(logits=lm_logits), None diff --git a/src/clm/src/models/sequence/ssm/dplr.py b/src/clm/src/models/sequence/ssm/dplr.py deleted file mode 100644 index a817cc07..00000000 --- a/src/clm/src/models/sequence/ssm/dplr.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copied from https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/dplr.py - -"""Initializations of structured state space models""" -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange, repeat - -from clm.src.models.sequence.ssm import hippo - - -def dplr(scaling='linear', N=64, rank=1, H=1, dtype=torch.float, real_scale=1.0, imag_scale=1.0, random_real=False, random_imag=False, normalize=False, diagonal=True, random_B=False): - assert dtype == torch.float or dtype == torch.double - dtype = torch.cfloat if dtype == torch.float else torch.cdouble - - pi = torch.tensor(math.pi) - if random_real: - real_part = torch.rand(H, N//2) - else: - real_part = .5 * torch.ones(H, N//2) - if random_imag: - imag_part = N//2 * torch.rand(H, N//2) - else: - imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H) - - real_part = real_scale * real_part - if scaling == 'random': - imag_part = torch.randn(H, N//2) - elif scaling == 'real': - imag_part = 0 * imag_part - real_part = 1 + repeat(torch.arange(N//2), 'n -> h n', h=H) - elif scaling in ['linear', 'lin']: - imag_part = pi * imag_part - elif scaling in ['inverse', 'inv']: # Based on asymptotics of the default HiPPO matrix - imag_part = 1/pi * N * (N/(1+2*imag_part)-1) - elif scaling in ['inverse2', 'inv2']: - imag_part = 1/pi * N * (N/(1+imag_part)-1) - elif scaling in ['quadratic', 'quad']: - imag_part = 1/pi * (1+2*imag_part)**2 - elif scaling in ['legs', 'hippo']: - w, _, _, _ = hippo.nplr('legsd', N) - imag_part = w.imag - - else: raise NotImplementedError - imag_part = imag_scale * imag_part - w = -real_part + 1j * imag_part - - # Initialize B - if random_B: - B = torch.randn(H, N//2, dtype=dtype) - else: - B = torch.ones(H, N//2, 
dtype=dtype) - - if normalize: - norm = -B/w # (H, N) # Result if you integrate the kernel with constant 1 function - zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True) # Variance with a random C vector - B = B / zeta**.5 - - P = torch.randn(rank, H, N//2, dtype=dtype) - if diagonal: P = P * 0.0 - V = torch.eye(N, dtype=dtype)[:, :N//2] # Only used in testing - V = repeat(V, 'n m -> h n m', h=H) - - return w, P, B, V - -def ssm(measure, N, R, H, **ssm_args): - """Dispatcher to create single SSM initialization - N: state size - R: rank (for DPLR parameterization) - H: number of independent SSM copies - """ - - if measure == "dplr": - w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args) - elif measure.startswith("diag"): - args = measure.split("-") - assert args[0] == "diag" and len(args) > 1 - scaling = args[1] - w, P, B, V = dplr(scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args) - else: - w, P, B, V = hippo.nplr(measure, N, R, **ssm_args) - w = repeat(w, 'n -> s n', s=H) - P = repeat(P, 'r n -> r s n', s=H) - B = repeat(B, 'n -> s n', s=H) - V = repeat(V, 'n m -> s n m', s=H) - return w, P, B, V - -combinations = { - 'hippo': ['legs', 'fourier'], - 'diag': ['diag-inv', 'diag-lin'], - 'all': ['legs', 'fourier', 'diag-inv', 'diag-lin'], -} - -def combination(measures, N, R, S, **ssm_args): - if isinstance(measures, str): - measures = combinations[measures] if measures in combinations else [measures] - - assert S % len(measures) == 0, f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures" - w, P, B, V = zip( - *[ssm(measure, N, R, S // len(measures), **ssm_args) for measure in measures] - ) - w = torch.cat(w, dim=0) # (S N) - P = torch.cat(P, dim=1) # (R S N) - B = torch.cat(B, dim=0) # (S N) - V = torch.cat(V, dim=0) # (S N N) - return w, P, B, V diff --git a/src/clm/src/models/sequence/ssm/hippo.py b/src/clm/src/models/sequence/ssm/hippo.py deleted file mode 100644 index 07707b65..00000000 --- a/src/clm/src/models/sequence/ssm/hippo.py +++ /dev/null @@ -1,259 +0,0 @@ -# Copied from https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/hippo/hippo.py - -""" Definitions of A and B matrices for various HiPPO operators. """ - -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from scipy import special as ss -from einops import rearrange, repeat -from opt_einsum import contract - -def embed_c2r(A): - A = rearrange(A, '... m n -> ... 
m () n ()') - A = np.pad(A, ((0, 0), (0, 1), (0, 0), (0, 1))) + \ - np.pad(A, ((0, 0), (1, 0), (0, 0), (1,0))) - return rearrange(A, 'm x n y -> (m x) (n y)') - -# TODO take in 'torch' option to return torch instead of numpy, and converts the shape of B from (N, 1) to (N) -def transition(measure, N, **measure_args): - """ A, B transition matrices for different measures - measure: the type of measure - legt - Legendre (translated) - legs - Legendre (scaled) - glagt - generalized Laguerre (translated) - lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization - """ - # Laguerre (translated) - if measure == 'lagt': - b = measure_args.get('beta', 1.0) - A = np.eye(N) / 2 - np.tril(np.ones((N, N))) - B = b * np.ones((N, 1)) - # Generalized Laguerre - # alpha 0, beta small is most stable (limits to the 'lagt' measure) - # alpha 0, beta 1 has transition matrix A = [lower triangular 1] - elif measure == 'glagt': - alpha = measure_args.get('alpha', 0.0) - beta = measure_args.get('beta', 0.01) - A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) - B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] - - L = np.exp(.5 * (ss.gammaln(np.arange(N)+alpha+1) - ss.gammaln(np.arange(N)+1))) - A = (1./L[:, None]) * A * L[None, :] - B = (1./L[:, None]) * B * np.exp(-.5 * ss.gammaln(1-alpha)) * beta**((1-alpha)/2) - # Legendre (translated) - elif measure == 'legt': - Q = np.arange(N, dtype=np.float64) - R = (2*Q + 1) ** .5 - j, i = np.meshgrid(Q, Q) - A = R[:, None] * np.where(i < j, (-1.)**(i-j), 1) * R[None, :] - B = R[:, None] - A = -A - - # Halve again for timescale correctness - A *= 0.5 - B *= 0.5 - # LMU: equivalent to LegT up to normalization - elif measure == 'lmu': - Q = np.arange(N, dtype=np.float64) - R = (2*Q + 1)[:, None] # / theta - j, i = np.meshgrid(Q, Q) - A = np.where(i < j, -1, (-1.)**(i-j+1)) * R - B = (-1.)**Q[:, None] * R - # Legendre (scaled) - elif measure == 'legs': - q = np.arange(N, dtype=np.float64) - col, row = np.meshgrid(q, q) - r = 2 * q + 1 - M = -(np.where(row >= col, r, 0) - np.diag(q)) - T = np.sqrt(np.diag(2 * q + 1)) - A = T @ M @ np.linalg.inv(T) - B = np.diag(T)[:, None] - B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) - elif measure == 'legsd': - q = np.arange(N, dtype=np.float64) - col, row = np.meshgrid(q, q) - r = 2 * q + 1 - M = -(np.where(row >= col, r, 0) - np.diag(q)) - T = np.sqrt(np.diag(2 * q + 1)) - A = T @ M @ np.linalg.inv(T) - B = np.diag(T)[:, None] - B = B.copy() # Otherwise "UserWarning: given NumPY array is not writeable..." 
after torch.as_tensor(B) - A += .5 * B*B[None, :, 0] - B = B / 2.0 - elif measure in ['fourier_diag', 'foud']: - freqs = np.arange(N//2) - d = np.stack([freqs, np.zeros(N//2)], axis=-1).reshape(-1)[:-1] - A = 2*np.pi*(-np.diag(d, 1) + np.diag(d, -1)) - A = A - .5 * np.eye(N) - B = np.zeros(N) - B[0::2] = 2**.5 - B[0] = 1 - B = B[:, None] - elif measure in ['fourier', 'fout']: - freqs = np.arange(N//2) - d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:] - A = np.pi*(-np.diag(d, 1) + np.diag(d, -1)) - B = np.zeros(N) - B[0::2] = 2**.5 - B[0] = 1 - - # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case - A = A - B[:, None] * B[None, :] - B = B[:, None] - elif measure == 'fourier_decay': - freqs = np.arange(N//2) - d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:] - A = np.pi*(-np.diag(d, 1) + np.diag(d, -1)) - B = np.zeros(N) - B[0::2] = 2**.5 - B[0] = 1 - - # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case - A = A - .5 * B[:, None] * B[None, :] - B = .5 * B[:, None] - elif measure == 'fourier2': # Double everything: orthonormal on [0, 1] - freqs = 2*np.arange(N//2) - d = np.stack([np.zeros(N//2), freqs], axis=-1).reshape(-1)[1:] - A = np.pi*(-np.diag(d, 1) + np.diag(d, -1)) - B = np.zeros(N) - B[0::2] = 2**.5 - B[0] = 1 - - # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case - A = A - B[:, None] * B[None, :] * 2 - B = B[:, None] * 2 - elif measure == 'random': - A = np.random.randn(N, N) / N - B = np.random.randn(N, 1) - elif measure == 'diagonal': - A = -np.diag(np.exp(np.random.randn(N))) - B = np.random.randn(N, 1) - else: - raise NotImplementedError - - return A, B - -def rank_correction(measure, N, rank=1, dtype=torch.float): - """ Return low-rank matrix L such that A + L is normal """ - - if measure == 'legs': - assert rank >= 1 - P = torch.sqrt(.5+torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) - elif measure == 'legt': - assert rank >= 2 - P = torch.sqrt(1+2*torch.arange(N, dtype=dtype)) # (N) - P0 = P.clone() - P0[0::2] = 0. - P1 = P.clone() - P1[1::2] = 0. - P = torch.stack([P0, P1], dim=0) # (2 N) - P *= 2**(-0.5) # Halve the rank correct just like the original matrix was halved - elif measure == 'lagt': - assert rank >= 1 - P = .5**.5 * torch.ones(1, N, dtype=dtype) - elif measure in ['fourier', 'fout']: - P = torch.zeros(N) - P[0::2] = 2**.5 - P[0] = 1 - P = P.unsqueeze(0) - elif measure == 'fourier_decay': - P = torch.zeros(N) - P[0::2] = 2**.5 - P[0] = 1 - P = P.unsqueeze(0) - P = P / 2**.5 - elif measure == 'fourier2': - P = torch.zeros(N) - P[0::2] = 2**.5 - P[0] = 1 - P = 2**.5 * P.unsqueeze(0) - elif measure in ['fourier_diag', 'foud', 'legsd']: - P = torch.zeros(1, N, dtype=dtype) - else: raise NotImplementedError - - d = P.size(0) - if rank > d: - P = torch.cat([P, torch.zeros(rank-d, N, dtype=dtype)], dim=0) # (rank N) - return P - -def initial_C(measure, N, dtype=torch.float): - """ Return C that captures the other endpoint in the HiPPO approximation """ - - if measure == 'legt': - C = (torch.arange(N, dtype=dtype)*2+1)**.5 * (-1)**torch.arange(N) - elif measure == 'fourier': - C = torch.zeros(N) - C[0::2] = 2**.5 - C[0] = 1 - else: - C = torch.zeros(N, dtype=dtype) # (N) - - return C - - -def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True): - """ Return w, p, q, V, B such that - (w - p q^*, B) is unitarily equivalent to the original HiPPO A, B by the matrix V - i.e. 
A = V[w - p q^*]V^*, B = V B - """ - assert dtype == torch.float or dtype == torch.double - cdtype = torch.cfloat if dtype == torch.float else torch.cdouble - - A, B = transition(measure, N) - A = torch.as_tensor(A, dtype=dtype) # (N, N) - B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) - - P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N) - AP = A + torch.sum(P.unsqueeze(-2)*P.unsqueeze(-1), dim=-3) - - # We require AP to be nearly skew-symmetric - _A = AP + AP.transpose(-1, -2) - if (err := torch.sum((_A - _A[0,0]*torch.eye(N))**2) / N) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5): - print("WARNING: HiPPO matrix not skew symmetric", err) - - - # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately - # Imaginary part can use eigh instead of eig - w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True) - - # Diagonalize in double precision - if diagonalize_precision: AP = AP.to(torch.double) - # w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) - w_im, V = torch.linalg.eigh(AP*-1j) # (..., N) (..., N, N) - if diagonalize_precision: w_im, V = w_im.to(cdtype), V.to(cdtype) - w = w_re + 1j * w_im - # Check: V w V^{-1} = A - # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) - - - # Only keep half of each conjugate pair - _, idx = torch.sort(w.imag) - w_sorted = w[idx] - V_sorted = V[:, idx] - - # There is an edge case when eigenvalues can be 0, which requires some machinery to handle - # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case) - V = V_sorted[:, :N//2] - w = w_sorted[:N//2] - assert w[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A" - if w[-1].abs() < 1e-4: - V[:, -1] = 0. - V[0, -1] = 2**-0.5 - V[1, -1] = 2**-0.5 * 1j - - _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2) - if ((err := torch.sum((2*_AP.real-AP)**2)/N) > 1e-5): - print("Warning: Diagonalization of A matrix not numerically precise - error", err) - # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) - - V_inv = V.conj().transpose(-1, -2) - - # C = initial_C(measure, N, dtype=dtype) - B = contract('ij, j -> i', V_inv, B.to(V)) # V^* B - # C = contract('ij, j -> i', V_inv, C.to(V)) # V^* C - P = contract('ij, ...j -> ...i', V_inv, P.to(V)) # V^* P - - # return w, P, B, C, V - return w, P, B, V diff --git a/src/clm/src/models/sequence/ssm/s4_simple.py b/src/clm/src/models/sequence/ssm/s4_simple.py deleted file mode 100644 index 2176b481..00000000 --- a/src/clm/src/models/sequence/ssm/s4_simple.py +++ /dev/null @@ -1,262 +0,0 @@ -import torch -import torch.nn as nn -from clm.src.models.nn import LinearActivation, Activation, DropoutNd -from einops import rearrange, repeat -import opt_einsum as oe - -import math -class OurModule(nn.Module): - def __init__(self): super().__init__() - - def register(self, name, tensor, trainable=False, lr=None, wd=None): - """Utility method: register a tensor as a buffer or trainable parameter""" - - if trainable: - self.register_parameter(name, nn.Parameter(tensor)) - else: - self.register_buffer(name, tensor) - - optim = {} - if trainable and lr is not None: optim["lr"] = lr - if trainable and wd is not None: optim["weight_decay"] = wd - if len(optim) > 0: setattr(getattr(self, name), "_optim", optim) - -# -# This is intended to match np.convolve(x,w)[:len(w)] -# That is, (u \ast v)[k] = sum_{j} u[k-j]v[j] -# Here y = (u \ask v) on return. 
-# We assume the inputs are: -# u (B H L) -# v (C H L) -# and we want to produce y that is (B C H L) -# - - -def fft_conv(u,v): - L = u.shape[-1] - u_f = torch.fft.rfft(u, n=2*L) # (B H L) - v_f = torch.fft.rfft(v, n=2*L) # (C H L) - - y_f = oe.contract('bhl,chl->bchl', u_f, v_f) - y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L) - return y - -def normalize_param(a, method, norm_const=None): - if method == "l1": - if norm_const is not None: - return a/((1+norm_const)*torch.linalg.norm(a,ord=1,dim=2).unsqueeze(2)) - return a/torch.linalg.norm(a,ord=1,dim=2).unsqueeze(2) - if method == "l2": - return a/torch.linalg.norm(a,ord=2,dim=2).unsqueeze(2) - if method == "max": - return 0.1*a/torch.max(a,dim=2)[0].unsqueeze(2) - if method == "none": - return a - raise ValueError(f"{method} normalization not implemented") - -class SimpleS4(OurModule): - def __init__(self, - nHippos, - d_state=64, - channels=1, - use_initial=True, # Use the initial state? - zero_order_hold=False, # Use zero-order hold approximation - trap_rule=True, - dt_min=0.001, - dt_max=0.1, - lr=None, # Hook to set LR of SSM parameters differently - learn_a=True, - learn_theta=True, - learn_dt=False, # whether to learn separate dt for each hippo - theta_scale=False, - skip_connection=True, - repr='cont', # representation to use: ['cont','disc','comp'] - param_norm = 'none', # for normalizing parameters for stability - **kernel_args,): # Use the trapezoid rule - super().__init__() - # H is number of hippos - # D is the dimension (also shockingly n other places) - # B is the batch - # L is the length - self.h = nHippos - self.d = d_state // 2 - self.channels = channels - self.use_initial = use_initial - self.zero_order_hold = zero_order_hold - # - # Use the trapezoid rule correct or just do zero-order hold. - self.trap_rule = trap_rule - self.repr = repr - self.learn_dt = learn_dt - self.shift = 'shift' in self.repr - self.param_norm = param_norm - - _fp = (self.channels, self.h, self.d) - - # Chebyshev initialization - h_scale = torch.exp(torch.arange(self.h)/self.h * math.log(dt_max/dt_min)) - angles = torch.arange(self.d)*torch.pi - t_scale = h_scale if theta_scale else torch.ones(self.h) - theta = oe.contract('c,h,d->chd', torch.ones(self.channels), t_scale, angles) - if self.repr == 'disc': - # discrete diagonal representation - a = torch.randn(*_fp).abs() - #a = 2*torch.rand(*_fp)-1 # init randomly from [-1,1] - else: - # default continuous diagonal representation - a = -repeat(h_scale, 'h -> c h d', c=self.channels, d=self.d) - - self.register("theta", theta,learn_theta,lr=lr, wd=None) - self.register("a", a, learn_a,lr=lr, wd=None) - - if self.learn_dt: - log_dt = torch.rand(self.h) * ( - math.log(dt_max) - math.log(dt_min) - ) + math.log(dt_min) - self.register("log_dt", log_dt, True,lr=lr, wd=None) - - # The other maps - if not skip_connection: - self.register("D", torch.zeros((channels, self.h)), False) - else: - self.D = nn.Parameter(torch.randn(channels, self.h)) - - if use_initial or 'comp' in self.repr: - if self.shift: - b = torch.zeros(*_fp) - b[:,:,0] = 1 - self.register("b", b, False) - else: - self.b = nn.Parameter(torch.randn(*_fp)) - self.c = nn.Parameter(torch.randn(*_fp)) - self.x0 = nn.Parameter(torch.randn(*_fp)) - else: - # This is an optimization that we combine q = c * b - # It's as if we're setting x0 = 0. 
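As a sanity check on the convolution identity stated in the comment above, this small self-contained snippet compares the padded-FFT product used by `fft_conv` against a direct double loop; the sizes below are toy values chosen only for the example.

# Self-contained check of the causal FFT convolution computed by fft_conv above:
# (u * v)[k] = sum_j u[k-j] v[j], truncated to the input length.
import torch

B, H, L = 2, 3, 16
u = torch.randn(B, H, L)
v = torch.randn(1, H, L)            # a single channel for simplicity

# FFT path: pad to 2L so the circular convolution contains the linear one.
y_fft = torch.fft.irfft(torch.fft.rfft(u, n=2 * L).unsqueeze(1)
                        * torch.fft.rfft(v, n=2 * L).unsqueeze(0),
                        n=2 * L)[..., :L]           # (B, C, H, L)

# Direct path for reference.
y_ref = torch.zeros_like(y_fft)
for k in range(L):
    for j in range(k + 1):
        y_ref[..., k] += u[:, None, :, k - j] * v[None, :, :, j]

assert torch.allclose(y_fft, y_ref, atol=1e-4)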
- self.q = nn.Parameter(torch.randn(*_fp)) - - - def quadrature_method(self, u, horizon): - # The input is now Batch x Hippos x Length - l = u.size(-1) - - dt = 1/(l-1) # the step size - if self.learn_dt: - dt = torch.exp(self.log_dt).view(1,-1,1, 1) - - # q and a are both C x H x D - # zk is of length l we want a C x H x L matrix - zk = dt*torch.arange(l, device=u.device).view(1,1,-1,1) - - if self.repr == 'disc': - # discrete diagonal representation - a_ = (self.a).abs() - base_term = 2 * dt * torch.pow(a_.unsqueeze(2), zk) * torch.cos(self.theta.unsqueeze(2) * zk) - else: - # continuous diagonal representation - a_ = self.a #/torch.linalg.norm(self.a,ord=1,dim=2).unsqueeze(2) - a_ = -a_.abs() - # a_ = -self.a.abs() - base_term = 2*dt*torch.exp(a_.unsqueeze(2) * zk)*torch.cos( self.theta.unsqueeze(2) * zk) - - q = self.b*self.c if self.use_initial else self.q - f = (q.unsqueeze(2)*base_term).sum(-1) - - y = fft_conv(u,f) - # Add in the skip connection with per-channel D matrix - y = y + oe.contract('bhl,ch->bchl', u, self.D) - # Add back the initial state - if self.use_initial: - y = y + (2*(self.c*self.x0).unsqueeze(2)*base_term).sum(-1) - - return rearrange(y, 'b c h l-> b (c h) l'), None # flatten the channels. - - def forward(self, u, horizon=None): - return self.quadrature_method(u, horizon) - - -# Below here are standard wrapper classes to handle -# (1) Non-linearity -# (2) Integration with the Hippo Code base -class NonLinear(nn.Module): - def __init__(self, h, channels, - ln=False, # Extra normalization - transposed=True, - dropout=0.0, - postact=None, # activation after FF - activation='gelu', # activation in between SS and FF - initializer=None, # initializer on FF - weight_norm=False, # weight normalization on FF - ): - super().__init__() - dropout_fn = DropoutNd # nn.Dropout2d bugged in PyTorch 1.11 - dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() - #norm = Normalization(h*channels, transposed=transposed) if ln else nn.Identity() - - activation_fn = Activation(activation) - - output_linear = LinearActivation( - h*channels, - h, - transposed=transposed, - initializer=initializer, - activation=postact, - activate=True, - weight_norm=weight_norm, - ) - #self.f = nn.Sequential(activation_fn, dropout, norm, output_linear) - self.f = nn.Sequential(activation_fn, dropout, output_linear) - def forward(self,x): # Always (B H L) - return self.f(x) - -class SimpleS4Wrapper(nn.Module): - def __init__( - self, - d_model, - d_state=64, - channels=1, - bidirectional=False, - dropout=0.0, - transposed=True, # axis ordering (B, L, D) or (B, D, L) - ln=True, # IGNORED: Extra normalization - postact=None, # activation after FF - activation='gelu', # activation in between SS and FF - initializer=None, # initializer on FF - weight_norm=False, # weight normalization on FF - linear=False, - # SSM Kernel arguments - **kernel_args, - ): - super().__init__() - self.h = d_model - self.d = d_state - self.channels = channels - #self.shift = shift - #self.linear = linear - self.out_d = self.h - self.transposed = transposed - self.bidirectional = bidirectional - assert not bidirectional, f"Bidirectional NYI" - self.s4 = SimpleS4(nHippos=d_model, d_state=d_state, - channels=channels, **kernel_args) - # the mapping - # We transpose if it's not in the forward. 
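For reference, a compact sketch of how `quadrature_method` above assembles its convolution filter in the continuous ('cont') representation; `a`, `theta`, `q`, and the sizes below are toy stand-ins for the registered parameters, not the module's values.

# Toy sketch of the 'cont' filter construction in quadrature_method above:
# f[c, h, k] = sum_d q[c, h, d] * 2*dt * exp(a[c, h, d] * z_k) * cos(theta[c, h, d] * z_k)
import torch

C, H, D, L = 1, 4, 8, 32
a = -torch.rand(C, H, D).abs()          # decay rates (kept negative for stability)
theta = torch.rand(C, H, D) * 3.14      # oscillation frequencies
q = torch.randn(C, H, D)                # combined b*c coefficients
dt = 1.0 / (L - 1)
z = dt * torch.arange(L).view(1, 1, L, 1)                 # sample points z_k

base = 2 * dt * torch.exp(a.unsqueeze(2) * z) * torch.cos(theta.unsqueeze(2) * z)
f = (q.unsqueeze(2) * base).sum(-1)                       # (C, H, L) convolution filter
print(f.shape)                                            # torch.Size([1, 4, 32])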
- nl = NonLinear(self.h, channels=self.channels, ln=ln, # Extra normalization - dropout=dropout, postact=postact, activation=activation, transposed=True, - initializer=initializer, weight_norm=weight_norm) - self.out = nn.Identity() if linear else nl - - def forward(self, u, *w, state=None, horizon=None): - # u: (B H L) if self.transposed else (B L H) - if not self.transposed: u = u.transpose(-1, -2) - # We only pass BHL, and it is as if transposed is True. - y, state = self.s4(u,horizon=horizon) - ret = self.out(y) - if not self.transposed: ret = ret.transpose(-1, -2) - return ret, state - - @property - def d_state(self): return self.h * self.d - - @property - def d_output(self): return self.out_d \ No newline at end of file diff --git a/src/clm/src/models/sequence/ssm/s4d.py b/src/clm/src/models/sequence/ssm/s4d.py deleted file mode 100644 index 643e1a55..00000000 --- a/src/clm/src/models/sequence/ssm/s4d.py +++ /dev/null @@ -1,404 +0,0 @@ -""" Standalone version of Structured (Sequence) State Space (S4) model. """ - - -import logging -from functools import partial -import math -import numpy as np -from scipy import special as ss -import torch -import torch.nn as nn -import torch.nn.functional as F -from pytorch_lightning.utilities import rank_zero_only -from einops import rearrange, repeat -import opt_einsum as oe - -contract = oe.contract -contract_expression = oe.contract_expression - - -_c2r = torch.view_as_real -_r2c = torch.view_as_complex -if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10): - _resolve_conj = lambda x: x.conj().resolve_conj() -else: - _resolve_conj = lambda x: x.conj() - - - -""" simple nn.Module components """ - -def Activation(activation=None, dim=-1): - if activation in [ None, 'id', 'identity', 'linear' ]: - return nn.Identity() - elif activation == 'tanh': - return nn.Tanh() - elif activation == 'relu': - return nn.ReLU() - elif activation == 'gelu': - return nn.GELU() - elif activation in ['swish', 'silu']: - return nn.SiLU() - elif activation == 'glu': - return nn.GLU(dim=dim) - elif activation == 'sigmoid': - return nn.Sigmoid() - else: - raise NotImplementedError("hidden activation '{}' is not implemented".format(activation)) - -def LinearActivation( - d_input, d_output, bias=True, - transposed=False, - activation=None, - activate=False, # Apply activation as part of this module - **kwargs, - ): - """ Returns a linear nn.Module with control over axes order, initialization, and activation """ - - # Construct core module - linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear - if activation == 'glu': d_output *= 2 - linear = linear_cls(d_input, d_output, bias=bias, **kwargs) - - if activate and activation is not None: - activation = Activation(activation, dim=-2 if transposed else -1) - linear = nn.Sequential(linear, activation) - return linear - - -""" HiPPO utilities """ - -def random_dplr(N, H=1, scaling='inverse', real_scale=1.0, imag_scale=1.0): - dtype = torch.cfloat - - pi = torch.tensor(np.pi) - real_part = .5 * torch.ones(H, N//2) - imag_part = repeat(torch.arange(N//2), 'n -> h n', h=H) - - real_part = real_scale * real_part - if scaling == 'random': - imag_part = torch.randn(H, N//2) - elif scaling == 'linear': - imag_part = pi * imag_part - elif scaling == 'inverse': # Based on asymptotics of the default HiPPO matrix - imag_part = 1/pi * N * (N/(1+2*imag_part)-1) - else: raise NotImplementedError - imag_part = imag_scale * imag_part - w = -real_part + 1j * imag_part - - - B = torch.randn(H, N//2, dtype=dtype) - 
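A tiny sketch of the eigenvalue patterns produced by `random_dplr` above for the 'linear' and 'inverse' scalings (the S4D-Lin and S4D-Inv style initializations); `H` and `N` are toy sizes and the snippet only reproduces the diagonal part, not `B`.

# Toy reproduction of the diagonal eigenvalue initializations in random_dplr above.
import math
import torch

N, H = 8, 2
n = torch.arange(N // 2).repeat(H, 1).float()       # (H, N/2)

w_lin = -0.5 + 1j * math.pi * n                                # 'linear': evenly spaced frequencies
w_inv = -0.5 + 1j * (N / math.pi) * (N / (1 + 2 * n) - 1)      # 'inverse': HiPPO-like asymptotics

print(w_lin[0])
print(w_inv[0])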
- norm = -B/w # (H, N) # Result if you integrate the kernel with constant 1 function - zeta = 2*torch.sum(torch.abs(norm)**2, dim=-1, keepdim=True) # Variance with a random C vector - B = B / zeta**.5 - - return w, B - - -class SSKernelDiag(nn.Module): - """ Version using (complex) diagonal state matrix. Note that it is slower and less memory efficient than the NPLR kernel because of lack of kernel support. - - """ - - def __init__( - self, - w, C, log_dt, - lr=None, - train_w = True, - train_dt = True, - **kwargs # For compatibility with other kernels - ): - - super().__init__() - - # Rank of low-rank correction - assert w.size(-1) == C.size(-1) - self.H = log_dt.size(-1) - self.N = w.size(-1) - assert self.H % w.size(0) == 0 - self.copies = self.H // w.size(0) - - # Broadcast everything to correct shapes - C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (H, C, N) - - # Register parameters - self.C = nn.Parameter(_c2r(_resolve_conj(C))) - self.register("log_dt", log_dt, train_dt, lr, 0.0) - - log_w_real = torch.log(-w.real + 1e-4) - w_imag = w.imag - self.register("log_w_real", log_w_real, train_w, lr, 0.0) - self.register("w_imag", w_imag, train_w, lr, 0.0) - - - def _w(self): - # Get the internal w (diagonal) parameter - w_real = -torch.exp(self.log_w_real) - w_imag = self.w_imag - w = w_real + 1j * w_imag - w = repeat(w, 't n -> (v t) n', v=self.copies) # (H N) - return w - - def forward(self, L): - """ - returns: (..., c, L) where c is number of channels (default 1) - """ - - dt = torch.exp(self.log_dt) # (H) - C = _r2c(self.C) # (C H N) - w = self._w() # (H N) - - # Incorporate dt into A - dtA = w * dt.unsqueeze(-1) # (H N) - - # Power up - K = dtA.unsqueeze(-1) * torch.arange(L, device=w.device) # (H N L) - C = C * (torch.exp(dtA)-1.) / w - K = contract('chn, hnl -> chl', C, torch.exp(K)) - K = 2*K.real - # Keops version is more memory efficient - # C = C * (torch.exp(dtA)-1.) / w - # K = log_vandermonde(C, dtA, L) # (H L) - - return K - - def setup_step(self): - dt = torch.exp(self.log_dt) # (H) - C = _r2c(self.C) # (C H N) - w = self._w() # (H N) - - # Incorporate dt into A - dtA = w * dt.unsqueeze(-1) # (H N) - self.dA = torch.exp(dtA) # (H N) - self.dC = C * (torch.exp(dtA)-1.) 
/ w # (C H N) - self.dB = self.dC.new_ones(self.H, self.N) # (H N) - - def default_state(self, *batch_shape): - C = _r2c(self.C) - state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device) - return state - - def step(self, u, state): - next_state = contract("h n, b h n -> b h n", self.dA, state) \ - + contract("h n, b h -> b h n", self.dB, u) - y = contract("c h n, b h n -> b c h", self.dC, next_state) - return 2*y.real, next_state - - - def register(self, name, tensor, trainable=False, lr=None, wd=None): - """Utility method: register a tensor as a buffer or trainable parameter""" - - if trainable: - self.register_parameter(name, nn.Parameter(tensor)) - else: - self.register_buffer(name, tensor) - - optim = {} - if trainable and lr is not None: - optim["lr"] = lr - if trainable and wd is not None: - optim["weight_decay"] = wd - if len(optim) > 0: - setattr(getattr(self, name), "_optim", optim) - -class S4DKernel(nn.Module): - """Wrapper around SSKernelDiag that generates the diagonal SSM parameters - """ - - def __init__( - self, - H, - N=64, - scaling="inverse", - channels=1, # 1-dim to C-dim map; can think of C as having separate "heads" - dt_min=0.001, - dt_max=0.1, - lr=None, # Hook to set LR of SSM parameters differently - n_ssm=1, # Copies of the ODE parameters A and B. Must divide H - **kernel_args, - ): - super().__init__() - self.N = N - self.H = H - dtype = torch.float - cdtype = torch.cfloat - self.channels = channels - self.n_ssm = n_ssm - - # Generate dt - log_dt = torch.rand(self.H, dtype=dtype) * ( - math.log(dt_max) - math.log(dt_min) - ) + math.log(dt_min) - - # Compute the preprocessed representation - # Generate low rank correction p for the measure - w, B = random_dplr(self.N, H=n_ssm, scaling=scaling) - - C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) - - # Broadcast tensors to n_ssm copies - # These will be the parameters, so make sure tensors are materialized and contiguous - B = repeat(B, 't n -> (v t) n', v=self.n_ssm // B.size(-2)).clone().contiguous() - w = repeat(w, 't n -> (v t) n', v=self.n_ssm // w.size(-2)).clone().contiguous() - - # Combine B and C using structure of diagonal SSM - C = C * repeat(B, 't n -> (v t) n', v=H//self.n_ssm) - self.kernel = SSKernelDiag( - w, C, log_dt, - lr=lr, - **kernel_args, - ) - - def forward(self, L=None): - k = self.kernel(L=L) - return k.float() - - def setup_step(self): - self.kernel.setup_step() - - def step(self, u, state, **kwargs): - u, state = self.kernel.step(u, state, **kwargs) - return u.float(), state - - def default_state(self, *args, **kwargs): - return self.kernel.default_state(*args, **kwargs) - - -class S4D(nn.Module): - - def __init__( - self, - d_model, - d_state=64, - channels=1, # maps 1-dim to C-dim - bidirectional=False, - # Arguments for FF - activation='gelu', # activation in between SS and FF - postact=None, # activation after FF - dropout=0.0, - transposed=True, # axis ordering (B, L, D) or (B, D, L) - return_state=True, # return state in addition to output - # SSM Kernel arguments - **kernel_args, - ): - """ - d_state: the dimension of the state, also denoted by N - channels: can be interpreted as a number of "heads" - bidirectional: bidirectional - dropout: standard dropout argument - transposed: choose backbone axis ordering of (B, L, H) or (B, H, L) [B=batch size, L=sequence length, H=hidden dimension] - - Other options are all experimental and should not need to be configured - """ - - super().__init__() - - self.h = d_model - self.n = d_state - 
self.bidirectional = bidirectional - self.channels = channels - self.transposed = transposed - self.return_state = return_state - - self.D = nn.Parameter(torch.randn(channels, self.h)) - - if self.bidirectional: - channels *= 2 - - # SSM Kernel - self.kernel = S4DKernel(self.h, N=self.n, channels=channels, **kernel_args) - - # Pointwise - self.activation = Activation(activation) - dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout - self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() - - # position-wise output transform to mix features - self.output_linear = LinearActivation( - self.h*self.channels, - self.h, - transposed=self.transposed, - activation=postact, - activate=True, - ) - - - def forward(self, u, **kwargs): # absorbs return_output and transformer src mask - """ - u: (B H L) if self.transposed else (B L H) - state: (H N) never needed unless you know what you're doing - - Returns: same shape as u - """ - if not self.transposed: u = u.transpose(-1, -2) - L = u.size(-1) - - # Compute SS Kernel - k = self.kernel(L=L) # (C H L) (B C H L) - - # Convolution - if self.bidirectional: - k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2) - k = F.pad(k0, (0, L)) \ - + F.pad(k1.flip(-1), (L, 0)) \ - - k_f = torch.fft.rfft(k, n=2*L) # (C H L) - u_f = torch.fft.rfft(u, n=2*L) # (B H L) - y_f = contract('bhl,chl->bchl', u_f, k_f) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) - y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L) - - - # Compute D term in state space equation - essentially a skip connection - y = y + contract('bhl,ch->bchl', u, self.D) # u.unsqueeze(-3) * self.D.unsqueeze(-1) - - # Reshape to flatten channels - y = rearrange(y, '... c h l -> ... (c h) l') - - y = self.dropout(self.activation(y)) - - if not self.transposed: y = y.transpose(-1, -2) - - y = self.output_linear(y) - - if self.return_state: - return y, None # Return a None to satisfy this repo's interface, but this can be modified - else: - return y - - def setup_step(self): - self.kernel.setup_step() - - def step(self, u, state): - """ Step one time step as a recurrent model. Intended to be used during validation. - - u: (B H) - state: (B H N) - Returns: output (B H), state (B H N) - """ - assert not self.training - - y, next_state = self.kernel.step(u, state) # (B C H) - y = y + u.unsqueeze(-2) * self.D - y = rearrange(y, '... c h -> ... (c h)') - y = self.activation(y) - if self.transposed: - y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) - else: - y = self.output_linear(y) - return y, next_state - - def default_state(self, *batch_shape, device=None): - return self.kernel.default_state(*batch_shape) - - @property - def d_state(self): - return self.h * self.n - - @property - def d_output(self): - return self.h - - @property - def state_to_tensor(self): - return lambda state: rearrange('... h n -> ... (h n)', state) diff --git a/src/clm/src/models/sequence/ssm/ss_kernel.py b/src/clm/src/models/sequence/ssm/ss_kernel.py deleted file mode 100644 index b0079898..00000000 --- a/src/clm/src/models/sequence/ssm/ss_kernel.py +++ /dev/null @@ -1,180 +0,0 @@ -# TD: [2023-01-05]: Extracted the SSKernel class from -# https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/kernel.py -# We add option to use the shift kernel, and remove the option of SSKernelNPLR - -"""SSM convolution kernels. -SSKernel wraps different kernels with common options and handles the initialization. 
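The `setup_step`/`step` methods above run the same model as a recurrence over time; below is a minimal sketch of that diagonal update, with toy `dA`, `dB`, `dC` standing in for the discretized parameters that `setup_step` actually computes.

# Minimal sketch of the diagonal recurrence behind setup_step()/step() above:
#   x_{k+1} = dA * x_k + dB * u_k ,   y_k = 2 * Re(dC x_{k+1})
import torch

B, C, H, N = 2, 1, 4, 8                            # batch, channels, model dim, state size
dA = 0.9 * torch.ones(H, N, dtype=torch.cfloat)    # toy discretized diagonal A
dB = torch.ones(H, N, dtype=torch.cfloat)          # toy discretized B
dC = torch.randn(C, H, N, dtype=torch.cfloat)      # toy output map

state = torch.zeros(B, H, N, dtype=torch.cfloat)
ys = []
for u in torch.randn(16, B, H):                    # 16 time steps of input
    state = dA * state + dB * u.unsqueeze(-1).to(torch.cfloat)
    y = torch.einsum('chn,bhn->bch', dC, state)
    ys.append(2 * y.real)                          # conjugate-pair trick: keep 2*Re(.)
y = torch.stack(ys, dim=-1)                        # (B, C, H, L)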
-""" - -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import rearrange, repeat -from opt_einsum import contract - -from clm.src.models.sequence.ssm.ss_kernel_diag import SSKernelDiag, EMAKernel -from clm.src.models.sequence.ssm.ss_kernel_shift import SSKernelShift - -from clm.src.models.sequence.ssm import hippo -from clm.src.models.sequence.ssm import dplr -from clm.src.ops.krylov import power - -from clm.src.utils.train import get_logger - -log = get_logger(__name__) - - -_conj = lambda x: torch.cat([x, x.conj()], dim=-1) - - -class SSKernel(nn.Module): - """Wrapper around SSKernel parameterizations. - - The SSKernel is expected to support the interface - forward() - default_state() - _setup_step() - step() - """ - - def __init__( - self, - H, - N=64, - L=None, - measure="diag-lin", - rank=1, - channels=1, - dt_min=0.001, - dt_max=0.1, - deterministic=False, - lr=None, - mode="diag", - n_ssm=None, - verbose=False, - measure_args={}, - **kernel_args, - ): - """State Space Kernel which computes the convolution kernel $\\bar{K}$ - - H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config. - N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much. - L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known. - measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin) - rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt" - channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead - dt_min, dt_max: min and max values for the step size dt (\Delta) - mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing - n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H - lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. 
- """ - super().__init__() - self.N = N - self.H = H - dtype, cdtype = torch.float, torch.cfloat - self.channels = channels - self.n_ssm = n_ssm if n_ssm is not None else H - self.mode = mode - self.verbose = verbose - self.kernel_args = kernel_args - - # Generate dt - if deterministic: - log_dt = torch.exp(torch.linspace(math.log(dt_min), math.log(dt_max), H)) - else: - log_dt = torch.rand(self.H, dtype=dtype) * ( - math.log(dt_max) - math.log(dt_min) - ) + math.log(dt_min) - - # Compute the preprocessed representation - if mode == "ema": - self.kernel = EMAKernel(H, N=N, channels=channels, **kernel_args) - else: - w, P, B, V = dplr.combination(measure, self.N, rank, self.n_ssm, **measure_args) - - # Broadcast C to have H channels - if deterministic: - C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype) - C[:, :, :1] = 1. - C = contract('hmn, chn -> chm', V.conj().transpose(-1, -2), C) # V^* C - C = repeat(C, 'c t n -> c (v t) n', v=self.n_ssm // C.size(-2)).clone().contiguous() - else: - C = torch.randn(channels, self.H, self.N//2, dtype=cdtype) - - # Broadcast other parameters to have n_ssm copies - assert self.n_ssm % B.size(-2) == 0 \ - and self.n_ssm % P.size(-2) == 0 \ - and self.n_ssm % w.size(-2) == 0 - # Broadcast tensors to n_ssm copies - # These will be the parameters, so make sure tensors are materialized and contiguous - B = repeat(B, 't n -> (v t) n', v=self.n_ssm // B.size(-2)).clone().contiguous() - P = repeat(P, 'r t n -> r (v t) n', v=self.n_ssm // P.size(-2)).clone().contiguous() - w = repeat(w, 't n -> (v t) n', v=self.n_ssm // w.size(-2)).clone().contiguous() - - if mode == "diag": - if not measure.startswith("diag"): - log.warning("Diagonal kernel (S4D) activated but initialization is not intended for S4D. Set `measure` to 'diag-lin', 'diag-inv', or 'diag-legs' for the main variants, or 'diag' for a combination of S4D-Lin and S4D-Inv.") - C = C * repeat(B, 't n -> (v t) n', v=H//self.n_ssm) - self.kernel = SSKernelDiag( - w, B, C, log_dt, L=L, - lr=lr, - **kernel_args, - ) - elif mode == 'shift': - # Initializing B to be e_1 - B = torch.zeros(self.H, self.N) - B[..., 0] = 1.0 - # Match torch.Conv1d init - C = torch.randn(self.H, self.channels, self.N) - nn.init.kaiming_uniform_(C, a=math.sqrt(5)) - C = rearrange(C, 'h c n -> c h n') - self.kernel = SSKernelShift(B, C, L=L, lr=lr, **kernel_args) - else: - raise NotImplementedError(f"{mode=} is not valid") - - def forward(self, state=None, L=None, rate=None): - return self.kernel(state=state, L=L, rate=rate) - - @torch.no_grad() - def forward_state(self, u, state): - """ Forward the state through a sequence, i.e. 
computes the state after passing chunk through SSM - - state: (B, H, N) - u: (B, H, L) - - Returns: (B, H, N) - """ - - if hasattr(self.kernel, "forward_state"): - return self.kernel.forward_state(u, state) - - dA, dB = self.kernel._setup_state() # Construct dA, dB matrices - # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N) - - conj = state.size(-1) != dA.size(-1) - if conj: state = _conj(state) - - v = contract('h n, b h l -> b h n l', dB, u.flip(-1)) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2) - AL, v = power(u.size(-1), dA, v) - next_state = contract("h m n, b h n -> b h m", AL, state) - next_state = next_state + v - - if conj: next_state = next_state[..., : next_state.size(-1) // 2] - return next_state - - def _setup_step(self, **kwargs): - # This method is intended to be private so that setting up an S4 module with - # ``` - # if hasattr(module, 'setup_step'): module.setup_step() - # ``` - # will not trigger this method multiple times - self.kernel._setup_step(**kwargs) - - def step(self, u, state, **kwargs): - y, state = self.kernel.step(u, state, **kwargs) - return y, state - - def default_state(self, *args, **kwargs): - return self.kernel.default_state(*args, **kwargs) diff --git a/src/clm/src/models/sequence/ssm/ss_kernel_diag.py b/src/clm/src/models/sequence/ssm/ss_kernel_diag.py deleted file mode 100644 index 49ab0118..00000000 --- a/src/clm/src/models/sequence/ssm/ss_kernel_diag.py +++ /dev/null @@ -1,331 +0,0 @@ -# TD: [2023-01-05]: Extracted the SSKernelDiag class from -# https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/kernel.py -# We make a small change to use the log_vandermonde CUDA code. - -"""SSKernelDiag is the S4D kernel, a simpler algorithm for computing the kernel for the case of diagonal state matrices A. -""" -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import rearrange, repeat -from opt_einsum import contract - -from clm.src.utils.train import OptimModule - -from clm.src.utils.train import get_logger - -log = get_logger(__name__) - -# This could be None if the CUDA import fails -from clm.src.ops.vandermonde import log_vandermonde_fast -try: - import pykeops - from clm.src.ops.vandermonde import log_vandermonde, log_vandermonde_transpose - has_pykeops = True - log.info("Pykeops installation found.") -except ImportError: - has_pykeops = False - from clm.src.ops.vandermonde import log_vandermonde_naive as log_vandermonde - from clm.src.ops.vandermonde import log_vandermonde_transpose_naive as log_vandermonde_transpose - log.warning( - "Falling back on slow Vandermonde kernel. Install pykeops for improved memory efficiency." 
- ) - - -_c2r = torch.view_as_real -_r2c = torch.view_as_complex - -if tuple(map(int, torch.__version__.split('.')[:2])) >= (1, 10): - _resolve_conj = lambda x: x.conj().resolve_conj() -else: - _resolve_conj = lambda x: x.conj() - - -class SSKernelDiag(OptimModule): - """Version using (complex) diagonal state matrix (S4D)""" - - def __init__( - self, - A, B, C, log_dt, - L=None, - disc='bilinear', - real_type='exp', - lr=None, - bandlimit=None, - force_real=False, - ): - - super().__init__() - self.L = L - self.disc = disc - self.bandlimit = bandlimit - self.real_type = real_type - self.force_real = force_real - - # Rank of low-rank correction - assert A.size(-1) == C.size(-1) - self.H = log_dt.size(-1) - self.N = A.size(-1) - assert A.size(-2) == B.size(-2) # Number of independent SSMs trained - assert self.H % A.size(-2) == 0 - self.n_ssm = A.size(-2) - self.repeat = self.H // A.size(0) - - self.channels = C.shape[0] - self.C = nn.Parameter(_c2r(_resolve_conj(C))) - - # Register parameters - if lr is None or isinstance(lr, float): lr_dict = {} - else: lr_dict, lr = lr, None - - self.register("log_dt", log_dt, lr_dict.get('dt', lr)) - self.register("B", _c2r(B), lr_dict.get('B', lr)) - self.register("inv_A_real", self._A_init(A.real), lr_dict.get('A', lr)) - self.register("A_imag", A.imag, lr_dict.get('A', lr)) - - def _A_init(self, A_real): - A_real = torch.clamp(A_real, max=-1e-4) - if self.real_type == 'none': - return -A_real - elif self.real_type == 'exp': - return torch.log(-A_real) # Some of the HiPPO methods have real part 0 - elif self.real_type == 'relu': - return -A_real - elif self.real_type == 'sigmoid': - return torch.logit(-A_real) - elif self.real_type == 'softplus': - return torch.log(torch.exp(-A_real)-1) - else: raise NotImplementedError - - def _A(self): - # Get the internal A (diagonal) parameter - if self.real_type == 'none': - A_real = -self.inv_A_real - elif self.real_type == 'exp': - A_real = -torch.exp(self.inv_A_real) - elif self.real_type == 'relu': - # JAX version seems to NaN if you alloA 0's, although this code Aas fine Aithout it - A_real = -F.relu(self.inv_A_real)-1e-4 - elif self.real_type == 'sigmoid': - A_real = -F.sigmoid(self.inv_A_real) - elif self.real_type == 'softplus': - A_real = -F.softplus(self.inv_A_real) - else: raise NotImplementedError - A = A_real + 1j * self.A_imag - return A - - def forward(self, L, state=None, rate=1.0, u=None): - """ - state: (B, H, N) initial state - rate: sampling rate factor - L: target length - returns: - (C, H, L) convolution kernel (generally C=1) - (B, H, L) output from initial state - """ - - dt = torch.exp(self.log_dt) * rate # (H) - C = _r2c(self.C) # (C H N) - A = self._A() # (H N) - - B = _r2c(self.B) - B = repeat(B, 't n -> 1 (v t) n', v=self.repeat) - - # Force A to be real valued, so the whole kernel can be interpreted as a "multi-head EMA" - if self.force_real: - A = A.real + 0j - - if self.bandlimit is not None: - freqs = dt[:, None] / rate * A.imag.abs() / (2*math.pi) # (H, N) - mask = torch.where(freqs < self.bandlimit * .5, 1, 0) - C = C * mask - - # Incorporate dt into A - A = repeat(A, 't n -> (v t) n', v=self.repeat) - dtA = A * dt.unsqueeze(-1) # (H N) - - - # Augment B with state - if state is not None: - s = state / dt.unsqueeze(-1) - if self.disc == 'bilinear': - s = s * (1. + dtA/2) - elif self.disc == 'zoh': - s = s * dtA * dtA.exp() / (dtA.exp() - 1.) 
- B = torch.cat([s, B], dim=-3) # (1+B H N) - - C = (B[:, None, :, :] * C).view(-1, self.H, self.N) - if self.disc == 'zoh': - # Power up - C = C * (torch.exp(dtA)-1.) / A - # TODO (TD): make it work for C.shape[0] > 1 - if log_vandermonde_fast is not None and C.shape[0] == 1: - K = log_vandermonde_fast(C.squeeze(0), dtA, L).unsqueeze(0) # (H L) - else: - K = log_vandermonde(C, dtA, L) # (H L) - elif self.disc == 'bilinear': - C = C * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A - dA = (1. + dtA/2) / (1. - dtA/2) - if log_vandermonde_fast is not None: - dA_log = repeat(dA.log(), 'h d -> (c h) d', c=C.shape[0]) - K = rearrange(log_vandermonde_fast(rearrange(C, 'c h d -> (c h) d'), dA_log, L), - '(c h) d -> c h d', c=C.shape[0]) - else: - K = log_vandermonde(C, dA.log(), L) - elif self.disc == 'dss': - # Implementation from DSS meant for case when real eigenvalues can be positive - P = dtA.unsqueeze(-1) * torch.arange(L, device=C.device) # [H N L] - A_gt_0 = A.real > 0 # [N] - if A_gt_0.any(): - with torch.no_grad(): - P_max = dtA * (A_gt_0 * (L-1)) # [H N] - P = P - P_max.unsqueeze(-1) # [H N L] - S = P.exp() # [H N L] - - dtA_neg = dtA * (1 - 2*A_gt_0) # [H N] - num = dtA_neg.exp() - 1 # [H N] - den = (dtA_neg * L).exp() - 1 # [H N] - - # Inline reciprocal function for DSS logic - x = den * A - x_conj = _resolve_conj(x) - r = x_conj / (x*x_conj + 1e-7) - - C = C * num * r # [C H N] - K = contract('chn,hnl->chl', C, S).float() - else: assert False, f"{self.disc} not supported" - - K = K.view(-1, self.channels, self.H, L) # (1+B C H L) - if state is not None: - K_state = K[:-1, :, :, :] # (B C H L) - else: - K_state = None - K = K[-1, :, :, :] # (C H L) - return K, K_state - - def _setup_step(self): - # These methods are organized like this to be compatible with the NPLR kernel interface - dt = torch.exp(self.log_dt) # (H) - B = _r2c(self.B) # (H N) - C = _r2c(self.C) # (C H N) - self.dC = C - A = self._A() # (H N) - - A = repeat(A, 't n -> (v t) n', v=self.repeat) - B = repeat(B, 't n -> (v t) n', v=self.repeat) - - # Incorporate dt into A - dtA = A * dt.unsqueeze(-1) # (H N) - if self.disc == 'zoh': - self.dA = torch.exp(dtA) # (H N) - self.dB = B * (torch.exp(dtA)-1.) / A # (C H N) - elif self.disc == 'bilinear': - self.dA = (1. + dtA/2) / (1. - dtA/2) - self.dB = B * (1. - dtA/2).reciprocal() * dt.unsqueeze(-1) # or * dtA / A - - - def default_state(self, *batch_shape): - C = _r2c(self.C) - state = torch.zeros(*batch_shape, self.H, self.N, dtype=C.dtype, device=C.device) - return state - - def step(self, u, state): - next_state = contract("h n, b h n -> b h n", self.dA, state) \ - + contract("h n, b h -> b h n", self.dB, u) - y = contract("c h n, b h n -> b c h", self.dC, next_state) - return 2*y.real, next_state - - def forward_state(self, u, state): - self._setup_step() - AL = self.dA ** u.size(-1) - u = u.flip(-1).to(self.dA).contiguous() # (B H L) - v = log_vandermonde_transpose(u, self.dB, self.dA.log(), u.size(-1)) - next_state = AL * state + v - return next_state - - -class EMAKernel(OptimModule): - """Translation of Mega's MultiHeadEMA. - This is a minimal implementation of the convolution kernel part of the module. - This module, together with the main S4 block in clm.src.models.sequence.ss.s4 - (which is really just a fft-conv wrapper around any convolution kernel, - such as this one), should be exactly equivalent to using the original Mega - EMA module in clm.src.models.sequence.ss.ema. 
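For clarity, a standalone toy comparison of the 'zoh' and 'bilinear' discretizations applied in `forward`/`_setup_step` above for a diagonal state matrix; `A`, `B`, and `dt` here are random toy values, not the module's parameters.

# Toy comparison of the 'zoh' and 'bilinear' discretizations used above for a
# diagonal state matrix A: both map the continuous pair (A, B) to (dA, dB).
import torch

H, N = 2, 4
A = -(torch.rand(H, N) + 1e-2) + 1j * torch.randn(H, N)   # diagonal A with Re(A) < 0
B = torch.ones(H, N, dtype=torch.cfloat)
dt = torch.full((H, 1), 0.01)

dtA = A * dt
# Zero-order hold
dA_zoh = torch.exp(dtA)
dB_zoh = B * (torch.exp(dtA) - 1.0) / A
# Bilinear (Tustin) transform
dA_bil = (1 + dtA / 2) / (1 - dtA / 2)
dB_bil = B * (1 - dtA / 2).reciprocal() * dt

print(torch.allclose(dA_zoh, dA_bil, atol=1e-3))   # the two agree closely for small dt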
- Two additional flags have been provided to resolve discrepencies in parameter - count between S4(D) and EMA - - `dt_tie` makes the shape of the step size \Delta (H, 1) instead of (H, N) - - `efficient_bidirectional` ties the A/B/dt parameters for the conv kernels - in both forwards and backwards directions. This should have exactly the same - speed, slightly more parameter efficiency, and unchanged performance. - """ - - def __init__( - self, - H, - N=2, - channels=1, - l_max=None, - dt_tie=False, - efficient_bidirectional=False, - ): - super().__init__() - - self.H = H - self.N = N - self.channels = channels - self.l_max = l_max - self.scale = math.sqrt(1.0 / self.N) - - # Exactly match the parameter count of S4(D) when bididirectional is on - self.efficient_bidirectional = efficient_bidirectional - if self.efficient_bidirectional: - H_C = H * channels - else: - H *= channels - H_C = H - - self.delta = nn.Parameter(torch.Tensor(H, 1 if dt_tie else N, 1)) - self.alpha = nn.Parameter(torch.Tensor(H, N, 1)) - self.beta = nn.Parameter(torch.Tensor(H, N, 1)) - self.gamma = nn.Parameter(torch.Tensor(H_C, N)) - # self.omega = nn.Parameter(torch.Tensor(H)) # D skip connection handled by outside class - - self.reset_parameters() - - def reset_parameters(self): - with torch.no_grad(): - nn.init.normal_(self.delta, mean=0.0, std=0.2) - nn.init.normal_(self.alpha, mean=0.0, std=0.2) - # Mega comment: beta [1, -1, 1, -1, ...] seems more stable. - val = torch.ones(self.N, 1) - if self.N > 1: - idx = torch.tensor(list(range(1, self.N, 2))) - val.index_fill_(0, idx, -1.0) - self.beta.normal_(mean=0.0, std=0.02).add_(val) - nn.init.normal_(self.gamma, mean=0.0, std=1.0) - # nn.init.normal_(self.omega, mean=0.0, std=1.0) - - def coeffs(self): # Same as discretize - p = torch.sigmoid(self.delta) # (H N 1) - alpha = torch.sigmoid(self.alpha) - q = 1.0 - p * alpha - return p, q - - def forward(self, L=None, state=None, rate=1.0): - L = L if self.l_max is None else min(self.l_max, L) - p, q = self.coeffs() # (H N 1) - vander = torch.arange(L).to(p).view(1, 1, L) * torch.log(q) # (H N L) - kernel = (p * self.beta) * torch.exp(vander) - if self.efficient_bidirectional: - C = rearrange(self.gamma * self.scale, '(c h) n -> c h n', c=self.channels) - kernel = torch.einsum('dnl,cdn->cdl', kernel, C) - # kernel = rearrange(kernel, 'c d l -> (c d) l') - else: - kernel = torch.einsum('dnl,dn->dl', kernel, self.gamma * self.scale) - kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels) - - kernel = kernel[..., :L] - # kernel = rearrange(kernel, '(c h) l -> c h l', c=self.channels) - return kernel, None # k_state diff --git a/src/clm/src/models/sequence/ssm/ss_kernel_shift.py b/src/clm/src/models/sequence/ssm/ss_kernel_shift.py deleted file mode 100644 index b926297a..00000000 --- a/src/clm/src/models/sequence/ssm/ss_kernel_shift.py +++ /dev/null @@ -1,83 +0,0 @@ -# TD: [2023-01-05]: Extracted the SSKernelDiag class from -# https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/sequence/ss/kernel.py -# We make a small change to use the log_vandermonde CUDA code. - -"""SSKernelDiag is the S4D kernel, a simpler algorithm for computing the kernel for the case of diagonal state matrices A. 
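When B is initialized to e_1 (as in the 'shift' mode of SSKernel above), the shift SSM's state is simply a sliding window of the last N inputs, so the convolution kernel reduces to the learned C itself. This small self-contained check mirrors the FFT cross-correlation used in `SSKernelShift.forward` below; all sizes are toy values for illustration.

# With B = e_1 the shift SSM's kernel is just C; check the FFT cross-correlation.
import torch

H, N, channels = 3, 8, 1
B = torch.zeros(H, N); B[:, 0] = 1.0                 # e_1 initialization
C = torch.randn(channels, H, N)

B_f = torch.fft.rfft(B, n=2 * N)
C_f = torch.fft.rfft(C, n=2 * N)
k = torch.fft.irfft(B_f.conj() * C_f, n=2 * N)[..., :N]
assert torch.allclose(k, C, atol=1e-5)               # kernel == C when B = e_1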
-""" -import math -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import rearrange, repeat -from opt_einsum import contract - -from clm.src.utils.train import OptimModule - - -class SSKernelShift(OptimModule): - - def __init__(self, B, C, L=None, lr=None, **kwargs): - """ - B: (H, d), real - C: (channel, H, d), real - """ - super().__init__() - self.L = L - self.N = B.size(-1) - self.H = B.shape[0] - - # Register parameters - if lr is None or isinstance(lr, float): lr_dict = {} - else: lr_dict, lr = lr, None - self.register("B", B, lr_dict.get('B', lr)) - self.C = nn.Parameter(C) - - def forward(self, state=None, rate=1.0, L=None): - if L is None: - L = self.L - # This class doesn't support variable length functionalities, since it's a discrete SSM - assert rate == 1.0 and L is not None - - # Augment B with state - B = self.B - if state is not None: - B = rearrange(torch.cat([rearrange(B, 'h n -> 1 h n'), state], dim=-3), - 'bp1 h n -> bp1 1 h n') # (1 + B, 1, H, N) - B_f = torch.fft.rfft(B, n=2*self.N) - C_f = torch.fft.rfft(self.C, n=2*self.N) - k = torch.fft.irfft(B_f.conj() * C_f, n=2*self.N)[..., :min(self.N, L)] - # If self.N < L, need to pad with zeros to reach length L - if self.N < L: - k = F.pad(k, (0, L - self.N)) - k = k.float() # Otherwise it could be dtype half - if state is not None: - k, k_state = k[0], k[1:] - else: - k_state = None - return k, k_state - - def _setup_step(self): - # Just here to conform to the interface, eventually we should refactor out - pass - - def default_state(self, *batch_shape): - return torch.zeros(*batch_shape, self.H, self.N, dtype=self.C.dtype, device=self.C.device) - - def step(self, u, state): - """u: (B, H), state: (B, H, N)""" - next_state = F.pad(state, (1, -1)) + contract("h n, b h -> b h n", self.B, u) - y = contract("c h n, b h n -> b c h", self.C, next_state) - return y, next_state - - def forward_state(self, u, state): - """u: (B, H, L), state: (B, H, N)""" - L = u.shape[-1] - B_f = torch.fft.rfft(self.B, n=2 * self.N) - u_f = torch.fft.rfft(u[..., -self.N:].flip(-1).to(dtype=self.B.dtype), n=2 * self.N) - v = torch.fft.irfft(B_f * u_f, n=2 * self.N)[..., :self.N] - if L < self.N: - next_state = F.pad(state, (L, -L)) + v - else: - next_state = v - return next_state diff --git a/src/clm/src/ops/fftconv.py b/src/clm/src/ops/fftconv.py deleted file mode 100644 index b5d2749b..00000000 --- a/src/clm/src/ops/fftconv.py +++ /dev/null @@ -1,103 +0,0 @@ -import math - -import torch -import torch.nn.functional as F - -from einops import rearrange - -from fftconv import fftconv_fwd, fftconv_bwd - -@torch.jit.script -def _mul_sum(y, q): - return (y * q).sum(dim=1) - -# reference convolution with residual connection -def fftconv_ref(u, k, D, dropout_mask, gelu=True, k_rev=None): - seqlen = u.shape[-1] - fft_size = 2 * seqlen - k_f = torch.fft.rfft(k, n=fft_size) / fft_size - if k_rev is not None: - k_rev_f = torch.fft.rfft(k_rev, n=fft_size) / fft_size - k_f = k_f + k_rev_f.conj() - u_f = torch.fft.rfft(u.to(dtype=k.dtype), n=fft_size) - y = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :seqlen] - out = y + u * D.unsqueeze(-1) - if gelu: - out = F.gelu(out) - if dropout_mask is not None: - return (out * rearrange(dropout_mask, 'b H -> b H 1')).to(dtype=u.dtype) - else: - return out.to(dtype=u.dtype) - -# reference H3 forward pass -def fftconv_h3_ref(k, ssm_kernel, D, q, v, head_dim=1, ssm_kernel_rev=None): - seqlen = k.shape[-1] - fft_size = 2 * seqlen - kv = (rearrange(k, 'b (h d1) l -> b d1 1 h 
l', d1=head_dim) - * rearrange(v, 'b (h d2) l -> b 1 d2 h l', d2=head_dim)) # b d1 d2 h l - kv_f = torch.fft.rfft(kv.to(dtype=ssm_kernel.dtype), n=fft_size) / fft_size - ssm_kernel_f = torch.fft.rfft(ssm_kernel, n=fft_size) # h L+1 - if ssm_kernel_rev is not None: - ssm_kernel_rev_f = torch.fft.rfft(ssm_kernel_rev, n=fft_size) # h L+1 - ssm_kernel_f = ssm_kernel_f + ssm_kernel_rev_f.conj() - y = torch.fft.irfft(kv_f * ssm_kernel_f, n=fft_size, norm='forward')[..., :seqlen] # b d1 d2 h l - out = y + kv * D.unsqueeze(-1) # b d1 d2 h l - q = rearrange(q, 'b (h d1) l -> b d1 1 h l', d1=head_dim) - if head_dim > 1: - out = _mul_sum(out, q) - return rearrange(out, 'b d2 h l -> b (h d2) l').to(dtype=k.dtype) - else: - return rearrange(out * q, 'b 1 1 h l -> b h l').to(dtype=k.dtype) - - -class FFTConvFunc(torch.autograd.Function): - - @staticmethod - def forward(ctx, u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False, - output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None): - seqlen = u.shape[-1] - fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) - k_f = torch.fft.rfft(k, n=fft_size) - if k_rev is not None: - k_f = k_f + torch.fft.rfft(k_rev, n=fft_size).conj() - if u.stride(-1) != 1: - u = u.contiguous() - k_f = k_f.contiguous() - D = D.contiguous() - if v is not None and v.stride(-1) != 1: - v = v.contiguous() - if q is not None and q.stride(-1) != 1: - q = q.contiguous() - if dropout_mask is not None: - dropout_mask = dropout_mask.contiguous() - ctx.save_for_backward(u, k_f, D, dropout_mask, v, q) - ctx.output_hbl_layout = output_hbl_layout - ctx.head_dim = head_dim - ctx.gelu = gelu - ctx.fftfp16 = fftfp16 - ctx.has_k_rev = k_rev is not None - out = fftconv_fwd(u, k_f, D, v, head_dim, q, dropout_mask, gelu, False, False, fft_size, force_fp16_output, output_hbl_layout, fftfp16) - return out - - @staticmethod - def backward(ctx, dout): - if ctx.output_hbl_layout: - dout = rearrange(rearrange(dout, 'b h l -> h b l').contiguous(), 'h b l -> b h l') - else: - dout = dout.contiguous() - u, k_f, D, dropout_mask, v, q = ctx.saved_tensors - seqlen = u.shape[-1] - fft_size = max(2 * 2 ** int(math.ceil(math.log2(seqlen))), 16) - du, dk_f, dD, dv, dq = fftconv_bwd(dout, u, k_f, D, v, ctx.head_dim, q, dropout_mask, ctx.gelu, False, False, fft_size, - ctx.output_hbl_layout, ctx.fftfp16) - dk = torch.fft.irfft(dk_f, n=fft_size, norm='forward')[..., :seqlen] - dk_rev = (None if not ctx.has_k_rev - else torch.fft.irfft(dk_f.conj(), n=fft_size, norm='forward')[..., :seqlen]) - if v is not None: - dv = dv.to(dtype=v.dtype) # We do atomicAdd in fp32 so might need to convert to fp16 - return du, dk, dD, None, None, None, None, dv if v is not None else None, None, dq if q is not None else None, None, dk_rev - -def fftconv_func(u, k, D, dropout_mask=None, gelu=True, force_fp16_output=False, - output_hbl_layout=False, v=None, head_dim=1, q=None, fftfp16=False, k_rev=None): - return FFTConvFunc.apply(u, k, D, dropout_mask, gelu, force_fp16_output, - output_hbl_layout, v, head_dim, q, fftfp16, k_rev) diff --git a/src/clm/src/ops/krylov.py b/src/clm/src/ops/krylov.py deleted file mode 100644 index 34544252..00000000 --- a/src/clm/src/ops/krylov.py +++ /dev/null @@ -1,198 +0,0 @@ -""" Compute a Krylov function efficiently. 
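# ---- Illustrative sketch (editorial addition, not part of the diff) ----
# What fftconv_ref above computes (ignoring the D skip, GELU and dropout):
# a causal convolution of u with kernel k via zero-padded FFTs, checked here
# against the direct O(L^2) sum. Shapes and values are hypothetical.
import torch

B, H, L = 2, 3, 8
u = torch.randn(B, H, L)
k = torch.randn(H, L)

fft_size = 2 * L
k_f = torch.fft.rfft(k, n=fft_size) / fft_size
u_f = torch.fft.rfft(u, n=fft_size)
y_fft = torch.fft.irfft(u_f * k_f, n=fft_size, norm='forward')[..., :L]

# Direct causal convolution: y[t] = sum_{s <= t} k[s] * u[t - s]
y_ref = torch.zeros_like(u)
for t in range(L):
    for s in range(t + 1):
        y_ref[..., t] += k[:, s] * u[..., t - s]
assert torch.allclose(y_fft, y_ref, atol=1e-4)
# ---- end sketch ----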
(S4 renames the Krylov function to a "state space kernel") - -A : (N, N) -b : (N,) -c : (N,) -Return: [c^T A^i b for i in [L]] -""" - -import torch -import torch.nn.functional as F -from einops import rearrange, repeat - -from clm.src.ops.toeplitz import causal_convolution - -def krylov_sequential(L, A, b, c=None): - """ Constant matrix A - - A : (..., N, N) - b : (..., N) - c : (..., N) - - Returns - if c: - x : (..., L) - x[i, l] = c[i] @ A^l @ b[i] - - else: - x : (..., N, L) - x[i, l] = A^l @ b[i] - """ - - # Check which of dim b and c is smaller to save memory - if c is not None and c.numel() < b.numel(): - return krylov_sequential(L, A.transpose(-1, -2), c, b) - - b_ = b - x = [] - for _ in range(L): - if c is not None: - x_ = torch.sum(c*b_, dim=-1) # (...) # could be faster with matmul or einsum? - else: - x_ = b_ - x.append(x_) - b_ = (A @ b_.unsqueeze(-1)).squeeze(-1) - - x = torch.stack(x, dim=-1) - return x - - -def krylov(L, A, b, c=None, return_power=False): - """ - Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick. - - If return_power=True, return A^{L-1} as well - """ - # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises - - x = b.unsqueeze(-1) # (..., N, 1) - A_ = A - - AL = None - if return_power: - AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device) - _L = L-1 - - done = L == 1 - # loop invariant: _L represents how many indices left to compute - while not done: - if return_power: - if _L % 2 == 1: AL = A_ @ AL - _L //= 2 - - # Save memory on last iteration - l = x.shape[-1] - if L - l <= l: - done = True - _x = x[..., :L-l] - else: _x = x - - _x = A_ @ _x - x = torch.cat([x, _x], dim=-1) # there might be a more efficient way of ordering axes - if not done: A_ = A_ @ A_ - - assert x.shape[-1] == L - - if c is not None: - x = torch.einsum('...nl, ...n -> ...l', x, c) - x = x.contiguous() # WOW!! - if return_power: - return x, AL - else: - return x - -@torch.no_grad() -def power(L, A, v=None): - """ Compute A^L and the scan sum_i A^i v_i - - A: (..., N, N) - v: (..., N, L) - """ - - I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) - - powers = [A] - l = 1 - while True: - if L % 2 == 1: I = powers[-1] @ I - L //= 2 - if L == 0: break - l *= 2 - if v is None: - powers = [powers[-1] @ powers[-1]] - else: - powers.append(powers[-1] @ powers[-1]) - - if v is None: return I - - # Invariants: - # powers[-1] := A^l - # l := largest po2 at most L - - # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A - # We do this reverse divide-and-conquer for efficiency reasons: - # 1) it involves fewer padding steps for non-po2 L - # 2) it involves more contiguous arrays - - # Take care of edge case for non-po2 arrays - # Note that this initial step is a no-op for the case of power of 2 (l == L) - k = v.size(-1) - l - v_ = powers.pop() @ v[..., l:] - v = v[..., :l] - v[..., :k] = v[..., :k] + v_ - - # Handle reduction for power of 2 - while v.size(-1) > 1: - v = rearrange(v, '... (z l) -> ... 
z l', z=2) - v = v[..., 0, :] + powers.pop() @ v[..., 1, :] - return I, v.squeeze(-1) - -def krylov_toeplitz(L, A, b, c=None): - """ Specializes to lower triangular Toeplitz matrix A represented by its diagonals - - A : (..., N) - b : (..., N) - c : (..., N) - - Returns - x : (..., N, L) - x[i, l] = A^l @ b[i] - """ - x = b.unsqueeze(0) # (1, ..., N) - A_ = A - while x.shape[0] < L: - xx = causal_convolution(A_, x) - x = torch.cat([x, xx], dim=0) # there might be a more efficient way of ordering axes - A_ = causal_convolution(A_, A_) - x = x[:L, ...] # (L, ..., N) - if c is not None: - x = torch.einsum('l...n, ...n -> ...l', x, c) - else: - x = rearrange(x, 'l ... n -> ... n l') - x = x.contiguous() - return x - -def krylov_toeplitz_(L, A, b, c=None): - """ Padded version of krylov_toeplitz that saves some fft's - - TODO currently not faster than original version, not sure why - """ - N = A.shape[-1] - - x = b.unsqueeze(0) # (1, ..., N) - x = F.pad(x, (0, N)) - A = F.pad(A, (0, N)) - done = L == 1 - while not done: - l = x.shape[0] - # Save memory on last iteration - if L - l <= l: - done = True - _x = x[:L-l] - else: _x = x - Af = torch.fft.rfft(A, n=2*N, dim=-1) - xf = torch.fft.rfft(_x, n=2*N, dim=-1) - xf_ = Af * xf - x_ = torch.fft.irfft(xf_, n=2*N, dim=-1) - x_[..., N:] = 0 - x = torch.cat([x, x_], dim=0) # there might be a more efficient way of ordering axes - if not done: - A = torch.fft.irfft(Af*Af, n=2*N, dim=-1) - A[..., N:] = 0 - x = x[:L, ..., :N] # (L, ..., N) - if c is not None: - x = torch.einsum('l...n, ...n -> ...l', x, c) - else: - x = rearrange(x, 'l ... n -> ... n l') - x = x.contiguous() - return x diff --git a/src/clm/src/ops/toeplitz.py b/src/clm/src/ops/toeplitz.py deleted file mode 100644 index af007390..00000000 --- a/src/clm/src/ops/toeplitz.py +++ /dev/null @@ -1,157 +0,0 @@ -""" Utilities for computing convolutions. - -There are 3 equivalent views: - 1. causal convolution - 2. multiplication of (lower) triangular Toeplitz matrices - 3. polynomial multiplication (mod x^N) -""" - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -def construct_toeplitz(v, f=0.0): - """Explicit construction of Krylov matrix [v A @ v A^2 @ v ... A^{n-1} @ v] - where A = Z_f. This uses vectorized indexing and cumprod so it's much - faster than using the Krylov function. - Parameters: - v: the starting vector of size n or (rank, n). - f: real number - Returns: - K: Krylov matrix of size (n, n) or (rank, n, n). - """ - n = v.shape[-1] - a = torch.arange(n, device=v.device) - b = -a - indices = a[:, None] + b[None] - K = v[..., indices] - K[..., indices < 0] *= f - return K - -def triangular_toeplitz_multiply_(u, v, sum=None): - n = u.shape[-1] - u_expand = F.pad(u, (0, n)) - v_expand = F.pad(v, (0, n)) - u_f = torch.fft.rfft(u_expand, n=2*n, dim=-1) - v_f = torch.fft.rfft(v_expand, n=2*n, dim=-1) - uv_f = u_f * v_f - if sum is not None: - uv_f = uv_f.sum(dim=sum) - output = torch.fft.irfft(uv_f, n=2*n, dim=-1)[..., :n] - return output - -def triangular_toeplitz_multiply_padded_(u, v): - """ Same as triangular_toeplitz_multiply but inputs and output assume to be 0-padded already. 
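# ---- Illustrative sketch (editorial addition, not part of the diff) ----
# The "3 equivalent views" named in the toeplitz.py docstring above, checked
# on a tiny example: multiplying by the lower-triangular Toeplitz matrix built
# from v (as in construct_toeplitz) equals the FFT-based causal convolution
# used by triangular_toeplitz_multiply_. Values are hypothetical.
import torch

n = 6
u, v = torch.randn(n), torch.randn(n)

# View 2: explicit lower-triangular Toeplitz matrix, K[i, j] = v[i - j] for i >= j
idx = torch.arange(n)[:, None] - torch.arange(n)[None, :]
K = v[idx.clamp(min=0)] * (idx >= 0)

# Views 1/3: causal convolution == polynomial multiplication mod x^n, via FFT
u_f = torch.fft.rfft(u, n=2 * n)
v_f = torch.fft.rfft(v, n=2 * n)
conv = torch.fft.irfft(u_f * v_f, n=2 * n)[:n]

assert torch.allclose(K @ u, conv, atol=1e-5)
# ---- end sketch ----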
""" - n = u.shape[-1] - assert n % 2 == 0 - u_f = torch.fft.rfft(u, n=n, dim=-1) - v_f = torch.fft.rfft(v, n=n, dim=-1) - uv_f = u_f * v_f - output = torch.fft.irfft(uv_f, n=n, dim=-1) - output[..., n:] = 0 - return output - -class TriangularToeplitzMult(torch.autograd.Function): - @staticmethod - def forward(ctx, u, v): - ctx.save_for_backward(u, v) - return triangular_toeplitz_multiply_(u, v) - - @staticmethod - def backward(ctx, grad): - u, v = ctx.saved_tensors - d_u = triangular_toeplitz_multiply_(grad.flip(-1), v).flip(-1) - d_v = triangular_toeplitz_multiply_(grad.flip(-1), u).flip(-1) - return d_u, d_v - -class TriangularToeplitzMultFast(torch.autograd.Function): - @staticmethod - def forward(ctx, u, v): - n = u.shape[-1] - u_expand = F.pad(u, (0, n)) - v_expand = F.pad(v, (0, n)) - u_f = torch.fft.rfft(u_expand, n=2*n, dim=-1) - v_f = torch.fft.rfft(v_expand, n=2*n, dim=-1) - - ctx.save_for_backward(u_f, v_f) - - uv_f = u_f * v_f - output = torch.fft.irfft(uv_f, n=2*n, dim=-1)[..., :n] - return output - - @staticmethod - def backward(ctx, grad): - u_f, v_f = ctx.saved_tensors - n = grad.shape[-1] - g_expand = F.pad(grad.flip(-1), (0, n)) - g_f = torch.fft.rfft(g_expand, n=2*n, dim=-1) - gu_f = g_f * u_f - gv_f = g_f * v_f - d_u = torch.fft.irfft(gv_f, n=2*n, dim=-1)[..., :n] - d_v = torch.fft.irfft(gu_f, n=2*n, dim=-1)[..., :n] - d_u = d_u.flip(-1) - d_v = d_v.flip(-1) - return d_u, d_v - -class TriangularToeplitzMultPadded(torch.autograd.Function): - @staticmethod - def forward(ctx, u, v): - ctx.save_for_backward(u, v) - output = triangular_toeplitz_multiply_(u, v) - return output - - @staticmethod - def backward(ctx, grad): - u, v = ctx.saved_tensors - d_u = triangular_toeplitz_multiply_padded_(grad.flip(-1), v).flip(-1) - d_v = triangular_toeplitz_multiply_padded_(grad.flip(-1), u).flip(-1) - return d_u, d_v - -class TriangularToeplitzMultPaddedFast(torch.autograd.Function): - """ Trade off speed (20-25% faster) for more memory (20-25%) """ - - @staticmethod - def forward(ctx, u, v): - n = u.shape[-1] - u_f = torch.fft.rfft(u, n=n, dim=-1) - v_f = torch.fft.rfft(v, n=n, dim=-1) - - ctx.save_for_backward(u_f, v_f) - - uv_f = u_f * v_f - output = torch.fft.irfft(uv_f, n=n, dim=-1) - output[..., n//2:].zero_() - return output - - @staticmethod - def backward(ctx, grad): - u_f, v_f = ctx.saved_tensors - n = grad.shape[-1] - g_expand = F.pad(grad[..., :n//2].flip(-1), (0, n//2)) - g_f = torch.fft.rfft(g_expand, n=n, dim=-1) - gu_f = g_f * u_f - gv_f = g_f * v_f - d_u = torch.fft.irfft(gv_f, n=n, dim=-1) - d_v = torch.fft.irfft(gu_f, n=n, dim=-1) - d_u[..., n//2:].zero_() - d_v[..., n//2:].zero_() - d_u[..., :n//2] = d_u[..., :n//2].flip(-1) # TODO - d_v[..., :n//2] = d_v[..., :n//2].flip(-1) # TODO - return d_u, d_v - -# triangular_toeplitz_multiply = triangular_toeplitz_multiply_ -triangular_toeplitz_multiply = TriangularToeplitzMult.apply -triangular_toeplitz_multiply_fast = TriangularToeplitzMultFast.apply -triangular_toeplitz_multiply_padded = TriangularToeplitzMultPadded.apply -triangular_toeplitz_multiply_padded_fast = TriangularToeplitzMultPaddedFast.apply - -def causal_convolution(u, v, fast=True, pad=False): - if not pad and not fast: - return triangular_toeplitz_multiply(u, v) - if not pad and fast: - return triangular_toeplitz_multiply_fast(u, v) - if pad and not fast: - return triangular_toeplitz_multiply_padded(u, v) - if pad and fast: - return triangular_toeplitz_multiply_padded_fast(u, v) diff --git a/src/clm/src/ops/unroll.py b/src/clm/src/ops/unroll.py deleted file 
mode 100644 index b8f8c8db..00000000 --- a/src/clm/src/ops/unroll.py +++ /dev/null @@ -1,421 +0,0 @@ -""" Old utilities for parallel scan implementation of Linear RNNs. """ -# TODO this file could use much cleanup - -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -import math - -from clm.src.models.functional.toeplitz import triangular_toeplitz_multiply, triangular_toeplitz_multiply_padded -from clm.src.utils.permutations import bitreversal_po2, bitreversal_permutation - - - -### Utilities - - -def shift_up(a, s=None, drop=True, dim=0): - assert dim == 0 - if s is None: - s = torch.zeros_like(a[0, ...]) - s = s.unsqueeze(dim) - if drop: - a = a[:-1, ...] - return torch.cat((s, a), dim=dim) - -def interleave(a, b, uneven=False, dim=0): - """ Interleave two tensors of same shape """ - # assert(a.shape == b.shape) - assert dim == 0 # TODO temporary to make handling uneven case easier - if dim < 0: - dim = N + dim - if uneven: - a_ = a[-1:, ...] - a = a[:-1, ...] - c = torch.stack((a, b), dim+1) - out_shape = list(a.shape) - out_shape[dim] *= 2 - c = c.view(out_shape) - if uneven: - c = torch.cat((c, a_), dim=dim) - return c - -def batch_mult(A, u, has_batch=None): - """ Matrix mult A @ u with special case to save memory if u has additional batch dim - - The batch dimension is assumed to be the second dimension - A : (L, ..., N, N) - u : (L, [B], ..., N) - has_batch: True, False, or None. If None, determined automatically - - Output: - x : (L, [B], ..., N) - A @ u broadcasted appropriately - """ - - if has_batch is None: - has_batch = len(u.shape) >= len(A.shape) - - if has_batch: - u = u.permute([0] + list(range(2, len(u.shape))) + [1]) - else: - u = u.unsqueeze(-1) - v = (A @ u) - if has_batch: - v = v.permute([0] + [len(u.shape)-1] + list(range(1, len(u.shape)-1))) - else: - v = v[..., 0] - return v - - - -### Main unrolling functions - -def unroll(A, u): - """ - A : (..., N, N) # TODO I think this can't take batch dimension? - u : (L, ..., N) - output : x (..., N) # TODO a lot of these shapes are wrong - x[i, ...] = A^{i} @ u[0, ...] + ... + A @ u[i-1, ...] + u[i, ...] - """ - - m = u.new_zeros(u.shape[1:]) - outputs = [] - for u_ in torch.unbind(u, dim=0): - m = F.linear(m, A) + u_ - outputs.append(m) - - output = torch.stack(outputs, dim=0) - return output - - -def parallel_unroll_recursive(A, u): - """ Bottom-up divide-and-conquer version of unroll. """ - - # Main recursive function - def parallel_unroll_recursive_(A, u): - if u.shape[0] == 1: - return u - - u_evens = u[0::2, ...] - u_odds = u[1::2, ...] - - # u2 = F.linear(u_evens, A) + u_odds - u2 = (A @ u_evens.unsqueeze(-1)).squeeze(-1) + u_odds - A2 = A @ A - - x_odds = parallel_unroll_recursive_(A2, u2) - # x_evens = F.linear(shift_up(x_odds), A) + u_evens - x_evens = (A @ shift_up(x_odds).unsqueeze(-1)).squeeze(-1) + u_evens - - x = interleave(x_evens, x_odds, dim=0) - return x - - # Pad u to power of 2 - n = u.shape[0] - m = int(math.ceil(math.log(n)/math.log(2))) - N = 1 << m - u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0) - - return parallel_unroll_recursive_(A, u)[:n, ...] - - - -def parallel_unroll_recursive_br(A, u): - """ Same as parallel_unroll_recursive but uses bit reversal for locality. """ - - # Main recursive function - def parallel_unroll_recursive_br_(A, u): - n = u.shape[0] - if n == 1: - return u - - m = n//2 - u_0 = u[:m, ...] - u_1 = u[m:, ...] 
- - u2 = F.linear(u_0, A) + u_1 - A2 = A @ A - - x_1 = parallel_unroll_recursive_br_(A2, u2) - x_0 = F.linear(shift_up(x_1), A) + u_0 - - # x = torch.cat((x_0, x_1), dim=0) # is there a way to do this with cat? - x = interleave(x_0, x_1, dim=0) - return x - - # Pad u to power of 2 - n = u.shape[0] - m = int(math.ceil(math.log(n)/math.log(2))) - N = 1 << m - u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0) - - # Apply bit reversal - br = bitreversal_po2(N) - u = u[br, ...] - - x = parallel_unroll_recursive_br_(A, u) - return x[:n, ...] - -def parallel_unroll_iterative(A, u): - """ Bottom-up divide-and-conquer version of unroll, implemented iteratively """ - - # Pad u to power of 2 - n = u.shape[0] - m = int(math.ceil(math.log(n)/math.log(2))) - N = 1 << m - u = torch.cat((u, u.new_zeros((N-u.shape[0],) + u.shape[1:] )), dim=0) - - # Apply bit reversal - br = bitreversal_po2(N) - u = u[br, ...] - - # Main recursive loop, flattened - us = [] # stores the u_0 terms in the recursive version - N_ = N - As = [] # stores the A matrices - for l in range(m): - N_ = N_ // 2 - As.append(A) - u_0 = u[:N_, ...] - us.append(u_0) - u = F.linear(u_0, A) + u[N_:, ...] - A = A @ A - x_0 = [] - x = u # x_1 - for l in range(m-1, -1, -1): - x_0 = F.linear(shift_up(x), As[l]) + us[l] - x = interleave(x_0, x, dim=0) - - return x[:n, ...] - - -def variable_unroll_sequential(A, u, s=None, variable=True): - """ Unroll with variable (in time/length) transitions A. - - A : ([L], ..., N, N) dimension L should exist iff variable is True - u : (L, [B], ..., N) updates - s : ([B], ..., N) start state - output : x (..., N) - x[i, ...] = A[i]..A[0] @ s + A[i..1] @ u[0] + ... + A[i] @ u[i-1] + u[i] - """ - - if s is None: - s = torch.zeros_like(u[0]) - - if not variable: - A = A.expand((u.shape[0],) + A.shape) - has_batch = len(u.shape) >= len(A.shape) - - outputs = [] - for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)): - # s = F.linear(s, A_) + u_ - s = batch_mult(A_.unsqueeze(0), s.unsqueeze(0), has_batch)[0] - s = s + u_ - outputs.append(s) - - output = torch.stack(outputs, dim=0) - return output - - - -def variable_unroll(A, u, s=None, variable=True, recurse_limit=16): - """ Bottom-up divide-and-conquer version of variable_unroll. """ - - if u.shape[0] <= recurse_limit: - return variable_unroll_sequential(A, u, s, variable) - - if s is None: - s = torch.zeros_like(u[0]) - - uneven = u.shape[0] % 2 == 1 - has_batch = len(u.shape) >= len(A.shape) - - u_0 = u[0::2, ...] - u_1 = u[1::2, ...] - - if variable: - A_0 = A[0::2, ...] - A_1 = A[1::2, ...] - else: - A_0 = A - A_1 = A - - u_0_ = u_0 - A_0_ = A_0 - if uneven: - u_0_ = u_0[:-1, ...] - if variable: - A_0_ = A_0[:-1, ...] - - u_10 = batch_mult(A_1, u_0_, has_batch) - u_10 = u_10 + u_1 - A_10 = A_1 @ A_0_ - - # Recursive call - x_1 = variable_unroll(A_10, u_10, s, variable, recurse_limit) - - x_0 = shift_up(x_1, s, drop=not uneven) - x_0 = batch_mult(A_0, x_0, has_batch) - x_0 = x_0 + u_0 - - - x = interleave(x_0, x_1, uneven, dim=0) # For some reason this interleave is slower than in the (non-multi) unroll_recursive - return x - -def variable_unroll_general_sequential(A, u, s, op, variable=True): - """ Unroll with variable (in time/length) transitions A with general associative operation - - A : ([L], ..., N, N) dimension L should exist iff variable is True - u : (L, [B], ..., N) updates - s : ([B], ..., N) start state - output : x (..., N) - x[i, ...] = A[i]..A[0] s + A[i..1] u[0] + ... 
+ A[i] u[i-1] + u[i] - """ - - if not variable: - A = A.expand((u.shape[0],) + A.shape) - - outputs = [] - for (A_, u_) in zip(torch.unbind(A, dim=0), torch.unbind(u, dim=0)): - s = op(A_, s) - s = s + u_ - outputs.append(s) - - output = torch.stack(outputs, dim=0) - return output - -def variable_unroll_matrix_sequential(A, u, s=None, variable=True): - if s is None: - s = torch.zeros_like(u[0]) - - if not variable: - A = A.expand((u.shape[0],) + A.shape) - # has_batch = len(u.shape) >= len(A.shape) - - # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0] - op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0))[0] - - return variable_unroll_general_sequential(A, u, s, op, variable=True) - -def variable_unroll_toeplitz_sequential(A, u, s=None, variable=True, pad=False): - if s is None: - s = torch.zeros_like(u[0]) - - if not variable: - A = A.expand((u.shape[0],) + A.shape) - # has_batch = len(u.shape) >= len(A.shape) - - # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0] - # op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0))[0] - - if pad: - n = A.shape[-1] - A = F.pad(A, (0, n)) - u = F.pad(u, (0, n)) - s = F.pad(s, (0, n)) - ret = variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply_padded, variable=True) - ret = ret[..., :n] - return ret - - return variable_unroll_general_sequential(A, u, s, triangular_toeplitz_multiply, variable=True) - - - -### General parallel scan functions with generic binary composition operators - -def variable_unroll_general(A, u, s, op, compose_op=None, sequential_op=None, variable=True, recurse_limit=16): - """ Bottom-up divide-and-conquer version of variable_unroll. - - compose is an optional function that defines how to compose A without multiplying by a leaf u - """ - - if u.shape[0] <= recurse_limit: - if sequential_op is None: - sequential_op = op - return variable_unroll_general_sequential(A, u, s, sequential_op, variable) - - if compose_op is None: - compose_op = op - - uneven = u.shape[0] % 2 == 1 - # has_batch = len(u.shape) >= len(A.shape) - - u_0 = u[0::2, ...] - u_1 = u[1::2, ...] - - if variable: - A_0 = A[0::2, ...] - A_1 = A[1::2, ...] - else: - A_0 = A - A_1 = A - - u_0_ = u_0 - A_0_ = A_0 - if uneven: - u_0_ = u_0[:-1, ...] - if variable: - A_0_ = A_0[:-1, ...] 
- - u_10 = op(A_1, u_0_) # batch_mult(A_1, u_0_, has_batch) - u_10 = u_10 + u_1 - A_10 = compose_op(A_1, A_0_) - - # Recursive call - x_1 = variable_unroll_general(A_10, u_10, s, op, compose_op, sequential_op, variable=variable, recurse_limit=recurse_limit) - - x_0 = shift_up(x_1, s, drop=not uneven) - x_0 = op(A_0, x_0) # batch_mult(A_0, x_0, has_batch) - x_0 = x_0 + u_0 - - - x = interleave(x_0, x_1, uneven, dim=0) # For some reason this interleave is slower than in the (non-multi) unroll_recursive - return x - -def variable_unroll_matrix(A, u, s=None, variable=True, recurse_limit=16): - if s is None: - s = torch.zeros_like(u[0]) - has_batch = len(u.shape) >= len(A.shape) - op = lambda x, y: batch_mult(x, y, has_batch) - sequential_op = lambda x, y: batch_mult(x.unsqueeze(0), y.unsqueeze(0), has_batch)[0] - matmul = lambda x, y: x @ y - return variable_unroll_general(A, u, s, op, compose_op=matmul, sequential_op=sequential_op, variable=variable, recurse_limit=recurse_limit) - -def variable_unroll_toeplitz(A, u, s=None, variable=True, recurse_limit=8, pad=False): - """ Unroll with variable (in time/length) transitions A with general associative operation - - A : ([L], ..., N) dimension L should exist iff variable is True - u : (L, [B], ..., N) updates - s : ([B], ..., N) start state - output : x (L, [B], ..., N) same shape as u - x[i, ...] = A[i]..A[0] s + A[i..1] u[0] + ... + A[i] u[i-1] + u[i] - """ - # Add the batch dimension to A if necessary - A_batch_dims = len(A.shape) - int(variable) - u_batch_dims = len(u.shape)-1 - if u_batch_dims > A_batch_dims: - # assert u_batch_dims == A_batch_dims + 1 - if variable: - while len(A.shape) < len(u.shape): - A = A.unsqueeze(1) - # else: - # A = A.unsqueeze(0) - - if s is None: - s = torch.zeros_like(u[0]) - - if pad: - n = A.shape[-1] - A = F.pad(A, (0, n)) - u = F.pad(u, (0, n)) - s = F.pad(s, (0, n)) - op = triangular_toeplitz_multiply_padded - ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit) - ret = ret[..., :n] - return ret - - op = triangular_toeplitz_multiply - ret = variable_unroll_general(A, u, s, op, compose_op=op, variable=variable, recurse_limit=recurse_limit) - return ret diff --git a/src/clm/src/ops/vandermonde.py b/src/clm/src/ops/vandermonde.py deleted file mode 100644 index 0325b4ec..00000000 --- a/src/clm/src/ops/vandermonde.py +++ /dev/null @@ -1,167 +0,0 @@ -"""pykeops implementations of the Vandermonde matrix multiplication kernel used in the S4D kernel.""" -import math -import torch - -from einops import rearrange, repeat -from opt_einsum import contract - -import os - -try: - import pykeops - from pykeops.torch import LazyTensor, Genred -except: - pass - -try: - from cauchy_mult import vand_log_mult_sym_fwd, vand_log_mult_sym_bwd -except: - vand_log_mult_sym_fwd, vand_log_mult_sym_bwd = None, None - -_conj = lambda x: torch.cat([x, x.conj()], dim=-1) -def _broadcast_dims(*tensors): - max_dim = max([len(tensor.shape) for tensor in tensors]) - tensors = [tensor.view((1,)*(max_dim-len(tensor.shape))+tensor.shape) for tensor in tensors] - return tensors - -def _c2r(x): return torch.view_as_real(x) -def _r2c(x): return torch.view_as_complex(x) - -def vandermonde_naive(v, x, L, conj=True): - """ - v: (..., N) - x: (..., N) - returns: (..., L) \sum v x^l - """ - if conj: - x = _conj(x) - v = _conj(v) - vandermonde_matrix = x.unsqueeze(-1) ** torch.arange(L).to(x) # (... N L) - vandermonde_prod = torch.sum(v.unsqueeze(-1) * vandermonde_matrix, dim=-2) # (... 
L) - return vandermonde_prod - -def log_vandermonde_naive(v, x, L, conj=True): - """ - v: (..., N) - x: (..., N) - returns: (..., L) \sum v x^l - """ - vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) - vandermonde_prod = contract('... n, ... n l -> ... l', v, vandermonde_matrix) # (... L) - if conj: - return 2*vandermonde_prod.real - else: - return vandermonde_prod - -def log_vandermonde_lazy(v, x, L, conj=True): - if conj: - v = _conj(v) - x = _conj(x) - l = torch.arange(L).to(x) - v, x, l = _broadcast_dims(v, x, l) - v_l = LazyTensor(rearrange(v, '... N -> ... N 1 1')) - x_l = LazyTensor(rearrange(x, '... N -> ... N 1 1')) - l_l = LazyTensor(rearrange(l, '... L -> ... 1 L 1')) - # exp - vand = (x_l * l_l).exp() - s = (v_l*vand).sum(dim=len(v_l.shape)-2) - return s.squeeze(-1) - -def log_vandermonde(v, x, L, conj=True): - expr = 'ComplexMult(v, ComplexExp(ComplexMult(x, l)))' - vandermonde_mult = Genred( - expr, - [ - 'v = Vj(2)', - 'x = Vj(2)', - 'l = Vi(2)', - ], - reduction_op='Sum', - axis=1, - ) - - l = torch.arange(L).to(x) - v, x, l = _broadcast_dims(v, x, l) - v = _c2r(v) - x = _c2r(x) - l = _c2r(l) - - r = vandermonde_mult(v, x, l, backend='GPU') - if conj: - return 2*_r2c(r).real - else: - return _r2c(r) - -def log_vandermonde_transpose_naive(u, v, x, L): - vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... N L) - vandermonde_prod = contract('... l, ... n, ... n l -> ... n', u.to(x), v.to(x), vandermonde_matrix) # (... L) - return vandermonde_prod - -def log_vandermonde_transpose(u, v, x, L): - """ - u: ... H L - v: ... H N - x: ... H N - Returns: ... H N - - V = Vandermonde(a, L) : (H N L) - contract_L(V * u * v) - """ - expr = 'ComplexMult(ComplexMult(v, u), ComplexExp(ComplexMult(x, l)))' - vandermonde_mult = Genred( - expr, - [ - 'u = Vj(2)', - 'v = Vi(2)', - 'x = Vi(2)', - 'l = Vj(2)', - ], - reduction_op='Sum', - axis=1, - ) - - l = torch.arange(L).to(x) - u, v, x, l = _broadcast_dims(u, v, x, l) - u = _c2r(u) - v = _c2r(v) - x = _c2r(x) - l = _c2r(l) - - r = vandermonde_mult(u, v, x, l, backend='GPU') - return _r2c(r) - -def _log_vandermonde_matmul(x, L): - vandermonde_matrix = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x)) # (... 
N L) - return vandermonde_matrix - -def log_vandermonde_matmul(v, K): - prod = contract('...n, ...nl -> ...l', v, K) - return 2*prod.real - -class LogVandMultiplySymmetric(torch.autograd.Function): - - @staticmethod - def forward(ctx, v, x, L): - batch, N = v.shape - supported_N_values = [1 << log_n for log_n in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]] - if not N in supported_N_values: - raise NotImplementedError(f'Only support N values in {supported_N_values}') - max_L_value = 32 * 1024 * 64 * 1024 - if L > max_L_value: - raise NotImplementedError(f'Only support L values <= {max_L_value}') - if not v.is_cuda and x.is_cuda: - raise NotImplementedError(f'Only support CUDA tensors') - ctx.save_for_backward(v, x) - return vand_log_mult_sym_fwd(v, x, L) - - @staticmethod - def backward(ctx, dout): - v, x = ctx.saved_tensors - dv, dx = vand_log_mult_sym_bwd(v, x, dout) - return dv, dx, None - - -if vand_log_mult_sym_fwd and vand_log_mult_sym_bwd is not None: - log_vandermonde_fast = LogVandMultiplySymmetric.apply -else: - log_vandermonde_fast = None \ No newline at end of file diff --git a/src/clm/src/retnet/__init__.py b/src/clm/src/retnet/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/clm/src/retnet/complex/retention.py b/src/clm/src/retnet/complex/retention.py deleted file mode 100644 index 9a61a2d9..00000000 --- a/src/clm/src/retnet/complex/retention.py +++ /dev/null @@ -1,177 +0,0 @@ -import math - -import torch -import torch.nn as nn - -from util import ComplexGroupNorm - -class SimpleRetention(nn.Module): - def __init__(self, hidden_size, gamma, precision="single"): - """ - Simple retention mechanism based on the paper - "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] - """ - super(SimpleRetention, self).__init__() - - if precision == "half": - raise NotImplementedError("batchmm does not support half precision complex yet.") - self.complex_type = torch.complex32 - self.real_type = torch.float16 - elif precision == "single": - self.complex_type = torch.complex64 - self.real_type = torch.float32 - - self.precision = precision - self.hidden_size = hidden_size - self.gamma = gamma - - self.i = torch.complex(torch.tensor(0.0), torch.tensor(1.0)) - - self.W_Q = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.real_type) / hidden_size) - self.W_K = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.real_type) / hidden_size) - self.W_V = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.real_type) / hidden_size) - - - self.theta = torch.randn(hidden_size) / hidden_size - self.theta = nn.Parameter(self.theta) - - - - def forward(self, X): - """ - Parallel (default) representation of the retention mechanism. 
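# ---- Illustrative sketch (editorial addition, not part of the diff) ----
# The quantity the (log-)Vandermonde routines above compute, in its naive form
# (log_vandermonde_naive): a kernel K[l] = 2 * Re(sum_n v_n * exp(x_n * l)).
# Small hypothetical shapes, CPU only, no pykeops/CUDA dependency.
import torch

H, N, L = 2, 4, 8
v = torch.randn(H, N, dtype=torch.cfloat)
x = -torch.rand(H, N) + 1j * torch.randn(H, N)            # decaying complex poles

V = torch.exp(x.unsqueeze(-1) * torch.arange(L).to(x))    # Vandermonde matrix (H, N, L)
K = 2 * torch.einsum('hn,hnl->hl', v, V).real             # kernel (H, L)
print(K.shape)
# ---- end sketch ----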
- X: (batch_size, sequence_length, hidden_size) - """ - sequence_length = X.shape[1] - D = self._get_D(sequence_length).to(X.device) - - if X.dtype != self.complex_type: - X = torch.complex(X, torch.zeros_like(X)).to(self.complex_type) - - i = self.i.to(X.device) - ns = torch.arange(1, sequence_length + 1, dtype=self.real_type, device=X.device) - ns = torch.complex(ns, torch.zeros_like(ns)).to(self.complex_type) - Theta = [] - - for n in ns: - Theta.append(torch.exp(i * n * self.theta)) - - Theta = torch.stack(Theta, dim=0) - - Theta_bar = Theta.conj() - - Q = (X @ self.W_Q.to(self.complex_type)) * Theta.unsqueeze(0) - K = (X @ self.W_K.to(self.complex_type)) * Theta_bar.unsqueeze(0) - V = X @ self.W_V.to(self.complex_type) - att = (Q @ K.permute(0, 2, 1)) * D.unsqueeze(0) - - return att @ V - - def forward_recurrent(self, x_n, s_n_1, n): - """ - Recurrent representation of the retention mechanism. - x_n: (batch_size, hidden_size) - s_n_1: (batch_size, hidden_size) - """ - if x_n.dtype != self.complex_type: - x_n = torch.complex(x_n, torch.zeros_like(x_n)).to(self.complex_type) - - n = torch.tensor(n, dtype=self.complex_type, device=x_n.device) - - Theta = torch.exp(self.i * n * self.theta) - Theta_bar = Theta.conj() - - Q = (x_n @ self.W_Q.to(self.complex_type)) * Theta - K = (x_n @ self.W_K.to(self.complex_type)) * Theta_bar - V = x_n @ self.W_V.to(self.complex_type) - - # K: (batch_size, hidden_size) - # V: (batch_size, hidden_size) - # s_n_1: (batch_size, hidden_size, hidden_size) - # s_n = gamma * s_n_1 + K^T @ V - - s_n = self.gamma * s_n_1 + K.unsqueeze(2) @ V.unsqueeze(1) - - return (Q.unsqueeze(1) @ s_n).squeeze(1), s_n - - def _get_D(self, sequence_length): - n = torch.arange(sequence_length).unsqueeze(1) - m = torch.arange(sequence_length).unsqueeze(0) - - # Broadcast self.gamma ** (n - m) with appropriate masking to set values where n < m to 0 - D = (self.gamma ** (n - m)) * (n >= m).float() #this results in some NaN when n is much larger than m - # fill the NaN with 0 - D[D != D] = 0 - - return D - -class MultiScaleRetention(nn.Module): - def __init__(self, hidden_size, heads, precision="single"): - """ - Multi-scale retention mechanism based on the paper - "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] - """ - super(MultiScaleRetention, self).__init__() - self.hidden_size = hidden_size - self.heads = heads - self.precision = precision - assert hidden_size % heads == 0, "hidden_size must be divisible by heads" - self.head_size = hidden_size // heads - - if precision == "half": - raise NotImplementedError("batchmm does not support half precision complex yet.") - self.complex_type = torch.complex32 - self.real_type = torch.float16 - elif precision == "single": - self.complex_type = torch.complex64 - self.real_type = torch.float32 - - self.gammas = (1 - torch.exp(torch.linspace(math.log(1/32), math.log(1/512), heads, dtype=self.real_type))).detach().cpu().tolist() - - self.swish = lambda x: x * torch.sigmoid(x) - self.W_G = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.complex_type) / hidden_size) - self.W_O = nn.Parameter(torch.randn(hidden_size, hidden_size, dtype=self.complex_type) / hidden_size) - self.group_norm = ComplexGroupNorm(heads, hidden_size) - - self.retentions = nn.ModuleList([ - SimpleRetention(self.head_size, gamma) for gamma in self.gammas - ]) - - def forward(self, X): - """ - parallel representation of the multi-scale retention mechanism - """ - if X.dtype != self.complex_type: - 
X = torch.complex(X, torch.zeros_like(X)).to(self.complex_type) - - # apply each individual retention mechanism to a slice of X - Y = [] - for i in range(self.heads): - Y.append(self.retentions[i](X[:, :, i*self.head_size:(i+1)*self.head_size])) - - Y = torch.cat(Y, dim=2) - Y = self.group_norm(Y.reshape(-1, self.hidden_size)).reshape(X.shape) - - return (self.swish(X @ self.W_G.to(self.complex_type)) * Y) @ self.W_O.to(self.complex_type) - - def forward_recurrent(self, x_n, s_n_1s, n): - """ - recurrent representation of the multi-scale retention mechanism - """ - if x_n.dtype != self.complex_type: - x_n = torch.complex(x_n, torch.zeros_like(x_n)).to(self.complex_type) - n = torch.tensor(n, dtype=self.complex_type, device=x_n.device) - - # apply each individual retention mechanism to a slice of X - Y = [] - s_ns = [] - for i in range(self.heads): - y, s_n = self.retentions[i].forward_recurrent( - x_n[:, i*self.head_size:(i+1)*self.head_size], s_n_1s[i], n - ) - Y.append(y) - s_ns.append(s_n) - - Y = torch.cat(Y, dim=1) - Y = self.group_norm(Y) - return (self.swish(x_n @ self.W_G.to(self.complex_type)) * Y) @ self.W_O.to(self.complex_type), s_ns diff --git a/src/clm/src/retnet/complex/retnet.py b/src/clm/src/retnet/complex/retnet.py deleted file mode 100644 index 4582c859..00000000 --- a/src/clm/src/retnet/complex/retnet.py +++ /dev/null @@ -1,118 +0,0 @@ -import torch -import torch.nn as nn - -from retention import MultiScaleRetention -from util import ComplexFFN, ComplexGroupNorm, ComplexLayerNorm - -class RetNet(nn.Module): - def __init__(self, layers, hidden_dim, ffn_size, heads): - super(RetNet, self).__init__() - self.layers = layers - self.hidden_dim = hidden_dim - self.ffn_size = ffn_size - self.heads = heads - - self.retentions = nn.ModuleList([ - MultiScaleRetention(hidden_dim, heads) - for _ in range(layers) - ]) - self.ffns = nn.ModuleList([ - ComplexFFN(hidden_dim, ffn_size) - for _ in range(layers) - ]) - self.layer_norm = ComplexLayerNorm(hidden_dim) - - def forward(self, X): - """ - X: (batch_size, sequence_length, hidden_size) - """ - for i in range(self.layers): - Y = self.retentions[i](self.layer_norm(X)) + X - X = self.ffns[i](self.layer_norm(Y)) + Y - - return X - - def forward_recurrent(self, x_n, s_n_1s, n): - """ - X: (batch_size, sequence_length, hidden_size) - s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) - - """ - s_ns = [] - for i in range(self.layers): - o_n, s_n = self.retentions[i].forward_recurrent(self.layer_norm(x_n), s_n_1s[i], n) - y_n = o_n + x_n - s_ns.append(s_n) - x_n = self.ffns[i](self.layer_norm(y_n)) + y_n - - return x_n, s_ns - -class RetNetCLM(nn.Module): - def __init__(self, layers, hidden_dim, ffn_size, heads, vocab_size): - """ - NOTE: softmax not included! 
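# ---- Illustrative sketch (editorial addition, not part of the diff) ----
# The identity behind the parallel/recurrent equivalence of retention used
# above: o_n = Q_n @ s_n with state s_n = gamma * s_{n-1} + K_n^T V_n matches
# the parallel form (Q K^T * D) V. Bare Q/K/V without projections, xpos or
# normalization; shapes and values are hypothetical.
import torch

L, d, gamma = 5, 3, 0.9
Q, K, V = torch.randn(L, d), torch.randn(L, d), torch.randn(L, d)

# Parallel form with the causal decay matrix D[n, m] = gamma**(n - m) for n >= m
n = torch.arange(L).unsqueeze(1)
m = torch.arange(L).unsqueeze(0)
D = (gamma ** (n - m).float()) * (n >= m)
Y_parallel = (Q @ K.T * D) @ V

# Recurrent form
s = torch.zeros(d, d)
Y_recurrent = []
for t in range(L):
    s = gamma * s + torch.outer(K[t], V[t])
    Y_recurrent.append(Q[t] @ s)
Y_recurrent = torch.stack(Y_recurrent)

assert torch.allclose(Y_parallel, Y_recurrent, atol=1e-4)
# ---- end sketch ----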
- """ - super(RetNetCLM, self).__init__() - self.layers = layers - self.hidden_dim = hidden_dim - self.ffn_size = ffn_size - self.heads = heads - self.vocab_size = vocab_size - - self.retnet = RetNet(layers, hidden_dim, ffn_size, heads) - self.embed = nn.Embedding(vocab_size, hidden_dim) - self.proj = nn.Parameter(torch.randn(hidden_dim, vocab_size, dtype=torch.float32) / hidden_dim) - - def forward(self, input_ids): - """ - input_ids: (batch_size, sequence_length) - """ - X = self.embed(input_ids) - X = self.retnet(X) - X = X @ self.proj.to(X.dtype) - - return X.real - - def forward_recurrent(self, input_ids, s_n_1s, n): - """ - input_ids: (batch_size) - s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) - """ - X = self.embed(input_ids) - X, s_ns = self.retnet.forward_recurrent(X, s_n_1s, n) - X = X @ self.proj.to(X.dtype) - - return X.real, s_ns - - def sample(self, input_ids, sample_length, temperature=1.0): - """ - input_ids: (batch_size, sequence_length) - s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) - """ - s_n_1s = [ - [ - torch.zeros(self.hidden_dim // self.heads, self.hidden_dim // self.heads, dtype=torch.complex64).unsqueeze(0).repeat(input_ids.shape[0], 1, 1) - for _ in range(self.heads) - ] for _ in range(self.layers) - ] - for i in range(input_ids.shape[1]): - X, s_n_1s = self.forward_recurrent(input_ids[:, i], s_n_1s, i+1) - - # get softmax of x (real part only) - X = X.real / temperature - X = torch.softmax(X, dim=-1) - X = torch.multinomial(X, num_samples=1) - next_char = X[:, -1] - output_ids = [] - # now start sampling! - for i in range(sample_length): - X, s_n_1s = self.forward_recurrent(next_char, s_n_1s, i+1) - X = X.real / temperature - X = torch.softmax(X, dim=-1) - X = torch.multinomial(X, num_samples=1) - next_char = X[:, -1] - output_ids.append(next_char) - - output_ids = torch.stack(output_ids, dim=1) - - return output_ids \ No newline at end of file diff --git a/src/clm/src/retnet/complex/test_retention.py b/src/clm/src/retnet/complex/test_retention.py deleted file mode 100644 index 07e30d6a..00000000 --- a/src/clm/src/retnet/complex/test_retention.py +++ /dev/null @@ -1,119 +0,0 @@ -import unittest -import torch -from retention import SimpleRetention, MultiScaleRetention - -class TestSimpleRetention(unittest.TestCase): - def test_simple_retention_parallel(self): - batch_size = 4 - hidden_size = 8 - sequence_length = 16 - gamma = 0.9 - - X = torch.rand(batch_size, sequence_length, hidden_size) - retention = SimpleRetention(hidden_size, gamma) - - Y = retention(X) - self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) - - def test_simple_retention_recurrent(self): - batch_size = 4 - hidden_size = 8 - sequence_length = 16 - gamma = 0.9 - - X = torch.rand(batch_size, sequence_length, hidden_size) - retention = SimpleRetention(hidden_size, gamma) - - s_n_1 = torch.zeros(hidden_size, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) - Y = [] - for i in range(sequence_length): - y_n, s_n = retention.forward_recurrent(X[:, i, :], s_n_1, i+1) - Y.append(y_n) - s_n_1 = s_n - Y = torch.stack(Y, dim=1) - self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) - - def test_paradigms_identical(self): - """ - check that the parallel and recurrent paradigms have identical outputs - """ - batch_size = 1 - hidden_size = 8 - sequence_length = 4 - gamma = 0.90 - - X = torch.rand(batch_size, sequence_length, hidden_size) - retention = 
SimpleRetention(hidden_size, gamma) - - Y_parallel = retention(X) - - s_n_1 = torch.zeros(hidden_size, hidden_size, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) - Y_recurrent = [] - for i in range(sequence_length): - y_n, s_n = retention.forward_recurrent(X[:, i, :], s_n_1, i+1) - Y_recurrent.append(y_n) - s_n_1 = s_n - Y_recurrent = torch.stack(Y_recurrent, dim=1) - - self.assertTrue(torch.allclose(Y_parallel, Y_recurrent)) - -class TestMultiScaleRetention(unittest.TestCase): - def test_multiscale_retention_parallel(self): - batch_size = 4 - sequence_length = 5 - hidden_size = 32 - heads = 4 - retention = MultiScaleRetention(hidden_size, heads) - - X = torch.rand(batch_size, sequence_length, hidden_size) - Y = retention(X) - self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) - - def test_multiscale_retention_recurrent(self): - batch_size = 4 - sequence_length = 5 - hidden_size = 32 - heads = 4 - retention = MultiScaleRetention(hidden_size, heads) - - X = torch.rand(batch_size, sequence_length, hidden_size) - s_n_1s = [ - torch.zeros(hidden_size // heads, hidden_size // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) - for _ in range(heads) - ] - Y = [] - for i in range(sequence_length): - y_n, s_ns = retention.forward_recurrent(X[:, i, :], s_n_1s, i) - Y.append(y_n) - s_n_1s = s_ns - Y = torch.stack(Y, dim=1) - self.assertEqual(Y.shape, (batch_size, sequence_length, hidden_size)) - - def test_multiscale_paradigms_identical(self): - """ - check that the parallel and recurrent paradigms have identical outputs - """ - batch_size = 2 - hidden_size = 36 - sequence_length = 5 - heads = 3 - - X = torch.rand(batch_size, sequence_length, hidden_size) - retention = MultiScaleRetention(hidden_size, heads) - - Y_parallel = retention(X) - - s_n_1s = [ - torch.zeros(hidden_size // heads, hidden_size // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) - for _ in range(heads) - ] - Y_recurrent = [] - for i in range(sequence_length): - y_n, s_ns = retention.forward_recurrent(X[:, i, :], s_n_1s, i) - Y_recurrent.append(y_n) - s_n_1s = s_ns - Y_recurrent = torch.stack(Y_recurrent, dim=1) - - self.assertTrue(torch.allclose(Y_parallel, Y_recurrent)) - -unittest.main() \ No newline at end of file diff --git a/src/clm/src/retnet/complex/test_retnet.py b/src/clm/src/retnet/complex/test_retnet.py deleted file mode 100644 index a2b1d2cd..00000000 --- a/src/clm/src/retnet/complex/test_retnet.py +++ /dev/null @@ -1,102 +0,0 @@ -import unittest -import torch -from retnet import RetNet, RetNetCLM - -class TestRetNet(unittest.TestCase): - - def test_paradigms_equivalent(self): - batch_size = 2 - layers = 2 - hidden_dim = 8 - heads = 4 - sequence_length = 4 - ffn_size = 16 - - X = torch.rand(batch_size, sequence_length, hidden_dim) - - retnet = RetNet(layers, hidden_dim, ffn_size, heads) - Y_parallel = retnet(X) - - s_n_1s = [ - [ - torch.zeros(hidden_dim // heads, hidden_dim // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) - for _ in range(heads) - ] for _ in range(layers) - ] - - Y_recurrent = [] - for i in range(sequence_length): - Y, s_ns = retnet.forward_recurrent(X[:, i, :], s_n_1s, i+1) - Y_recurrent.append(Y) - s_n_1s = s_ns - - Y_recurrent = torch.stack(Y_recurrent, dim=1) - - print((Y_parallel - Y_recurrent).abs().max()) - - self.assertTrue((Y_parallel - Y_recurrent).abs().max() < 1e-4) - - def test_clm(self): - batch_size = 2 - layers = 2 - hidden_dim = 16 - heads = 4 - sequence_length = 6 - ffn_size = 32 - 
vocab_size = 10 - - X = torch.randint(0, vocab_size, (batch_size, sequence_length)) - - retnet = RetNetCLM(layers, hidden_dim, ffn_size, heads, vocab_size) - Y_parallel = retnet(X) - - s_n_1s = [ - [ - torch.zeros(hidden_dim // heads, hidden_dim // heads, dtype=torch.complex64).unsqueeze(0).repeat(batch_size, 1, 1) - for _ in range(heads) - ] for _ in range(layers) - ] - - Y_recurrent = [] - for i in range(sequence_length): - Y, s_ns = retnet.forward_recurrent(X[:, i], s_n_1s, i+1) - Y_recurrent.append(Y) - s_n_1s = s_ns - - Y_recurrent = torch.stack(Y_recurrent, dim=1) - - # test sample - Y_sample = retnet.sample(X, 5) - - self.assertTrue(Y_sample.shape == (batch_size, 5)) - - self.assertTrue((Y_parallel - Y_recurrent).abs().max() < 1e-4) - - def test_training(self): - batch_size = 2 - layers = 3 - hidden_dim = 16 - heads = 4 - sequence_length = 6 - ffn_size = 32 - vocab_size = 10 - bos_idx = 0 - - data = torch.randint(0, vocab_size, (batch_size, sequence_length - 1)) - X = torch.cat([torch.ones(batch_size, 1).long() * bos_idx, data[:,:-1]], dim=1) - Y = data - - # verify we can overfit autoregressive model - model = RetNetCLM(layers, hidden_dim, ffn_size, heads, vocab_size) - - optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) - criterion = torch.nn.CrossEntropyLoss() - initial_loss = criterion(model(X).reshape(-1, 10), Y.reshape(-1)) - for i in range(10): - optimizer.zero_grad() - output = model(X) - loss = criterion(output.reshape(-1, 10), Y.reshape(-1)) - loss.backward() - optimizer.step() - self.assertTrue((loss < initial_loss).item()) -unittest.main() \ No newline at end of file diff --git a/src/clm/src/retnet/complex/util.py b/src/clm/src/retnet/complex/util.py deleted file mode 100644 index f7a89da0..00000000 --- a/src/clm/src/retnet/complex/util.py +++ /dev/null @@ -1,71 +0,0 @@ -import math -import torch -import torch.nn as nn - -class ComplexGroupNorm(nn.Module): - def __init__(self, num_groups, num_channels, eps=1e-5): - super(ComplexGroupNorm, self).__init__() - self.num_groups = num_groups - self.num_channels = num_channels - self.eps = eps - self.weight = nn.Parameter(torch.ones(num_channels, dtype=torch.float32)) - self.bias = nn.Parameter(torch.zeros(num_channels, dtype=torch.float32)) - - def forward(self, X): - """ - X: (batch_size, sequence_length, hidden_size) - X is assumed to be complex - """ - X = X.reshape(-1, self.num_groups, self.num_channels // self.num_groups) - mean = X.mean(dim=2, keepdim=True) - var = X.var(dim=2, keepdim=True) - X = (X - mean) / torch.sqrt(var + self.eps) - X = X.reshape(-1, self.num_channels) - X = X * self.weight + self.bias - - return X - -class ComplexLayerNorm(nn.Module): - def __init__(self, num_channels, eps=1e-5): - super(ComplexLayerNorm, self).__init__() - self.num_channels = num_channels - self.eps = eps - self.weight = nn.Parameter(torch.ones(num_channels, dtype=torch.float32)) - self.bias = nn.Parameter(torch.zeros(num_channels, dtype=torch.float32)) - - def forward(self, X): - """ - X: unknown shape ending in hidden_size - we treat the last dimension as the hidden_size - """ - X_shape = X.shape - X = X.reshape(-1, X_shape[-1]) - mean = X.mean(dim=1, keepdim=True) - var = X.abs().var(dim=1, keepdim=True) - X = (X - mean) / torch.sqrt(var + self.eps) - X = X * self.weight + self.bias - X = X.reshape(X_shape) - return X - - -class ComplexFFN(nn.Module): - """ - 2 linear layers with no bias - """ - def __init__(self, hidden_size, ffn_size): - super(ComplexFFN, self).__init__() - self.W1 = 
nn.Parameter(torch.randn(hidden_size, ffn_size, dtype=torch.float32) / math.sqrt(hidden_size)) - self.W2 = nn.Parameter(torch.randn(ffn_size, hidden_size, dtype=torch.float32) / math.sqrt(ffn_size)) - self.gelu = lambda x: 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - - def forward(self, X): - """ - X: (batch_size, sequence_length, hidden_size) - X is assumed to be complex - """ - # reshaping - X = X @ self.W1.to(X) - X = self.gelu(X) - X = X @ self.W2.to(X) - - return X diff --git a/src/clm/src/retnet/example.py b/src/clm/src/retnet/example.py deleted file mode 100644 index 0dcaaedb..00000000 --- a/src/clm/src/retnet/example.py +++ /dev/null @@ -1,17 +0,0 @@ -import torch -import torch.nn as nn - -import retnet - -if __name__ == "__main__": - # verify model size for hyperparameters in paper - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # 1.3B model - layers = 24 - hidden_dim = 2048 - ffn_size = 4096 - heads = 16 - - retnet = retnet.RetNet(layers, hidden_dim, ffn_size, heads, double_v_dim=True).to(device) - print("1.3B model:",sum(p.numel() for p in retnet.parameters() if p.requires_grad)) diff --git a/src/clm/src/retnet/retention.py b/src/clm/src/retnet/retention.py deleted file mode 100644 index e23e9e3c..00000000 --- a/src/clm/src/retnet/retention.py +++ /dev/null @@ -1,204 +0,0 @@ -import math - -import torch -import torch.nn as nn - -from clm.src.retnet.xpos_relative_position import XPOS - -class SimpleRetention(nn.Module): - def __init__(self, hidden_size, gamma, head_size=None, double_v_dim=False): - """ - Simple retention mechanism based on the paper - "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] - """ - super(SimpleRetention, self).__init__() - - self.hidden_size = hidden_size - if head_size is None: - head_size = hidden_size - self.head_size = head_size - - self.v_dim = head_size * 2 if double_v_dim else head_size - self.gamma = gamma - - self.W_Q = nn.Parameter(torch.randn(hidden_size, head_size) / hidden_size) - self.W_K = nn.Parameter(torch.randn(hidden_size, head_size) / hidden_size) - self.W_V = nn.Parameter(torch.randn(hidden_size, self.v_dim) / hidden_size) - - self.xpos = XPOS(head_size) - - def forward(self, X): - """ - Parallel (default) representation of the retention mechanism. - X: (batch_size, sequence_length, hidden_size) - """ - sequence_length = X.shape[1] - D = self._get_D(sequence_length).to(self.W_Q.device) - - Q = (X @ self.W_Q) - K = (X @ self.W_K) - - Q = self.xpos(Q) - K = self.xpos(K, downscale=True) - - V = X @ self.W_V - ret = (Q @ K.permute(0, 2, 1)) * D.unsqueeze(0) - - return ret @ V - - def forward_recurrent(self, x_n, s_n_1, n): - """ - Recurrent representation of the retention mechanism. - x_n: (batch_size, 1, hidden_size) - s_n_1: (batch_size, hidden_size, v_dim) - """ - - Q = (x_n @ self.W_Q) - K = (x_n @ self.W_K) - - Q = self.xpos(Q, n+1) - K = self.xpos(K, n+1, downscale=True) - - V = x_n @ self.W_V - - # K: (batch_size, 1, hidden_size) - # V: (batch_size, 1, v_dim) - # s_n = gamma * s_n_1 + K^T @ V - - s_n = self.gamma * s_n_1 + (K.transpose(-1, -2) @ V) - - return (Q @ s_n), s_n - - def forward_chunkwise(self, x_i, r_i_1, i): - """ - Chunkwise representation of the retention mechanism. 
- x_i: (batch_size, chunk_size, hidden_size) - r_i_1: (batch_size, hidden_size, v_dim) - """ - batch, chunk_size, _ = x_i.shape - D = self._get_D(chunk_size) - - Q = (x_i @ self.W_Q) - K = (x_i @ self.W_K) - - Q = self.xpos(Q, i * chunk_size) - K = self.xpos(K, i * chunk_size, downscale=True) - - V = x_i @ self.W_V - - r_i =(K.transpose(-1, -2) @ (V * D[-1].view(1, chunk_size, 1))) + (self.gamma ** chunk_size) * r_i_1 - - inner_chunk = ((Q @ K.transpose(-1, -2)) * D.unsqueeze(0)) @ V - - #e[i,j] = gamma ** (i+1) - e = torch.zeros(batch, chunk_size, 1) - - for _i in range(chunk_size): - e[:, _i, :] = self.gamma ** (_i + 1) - - cross_chunk = (Q @ r_i_1) * e - - return inner_chunk + cross_chunk, r_i - - def _get_D(self, sequence_length): - n = torch.arange(sequence_length).unsqueeze(1) - m = torch.arange(sequence_length).unsqueeze(0) - - # Broadcast self.gamma ** (n - m) with appropriate masking to set values where n < m to 0 - D = (self.gamma ** (n - m)) * (n >= m).float() #this results in some NaN when n is much larger than m - # fill the NaN with 0 - D[D != D] = 0 - - return D - - - -class MultiScaleRetention(nn.Module): - def __init__(self, hidden_size, heads, double_v_dim=False): - """ - Multi-scale retention mechanism based on the paper - "Retentive Network: A Successor to Transformer for Large Language Models"[https://arxiv.org/pdf/2307.08621.pdf] - """ - super(MultiScaleRetention, self).__init__() - self.hidden_size = hidden_size - self.v_dim = hidden_size * 2 if double_v_dim else hidden_size - self.heads = heads - assert hidden_size % heads == 0, "hidden_size must be divisible by heads" - self.head_size = hidden_size // heads - self.head_v_dim = hidden_size * 2 if double_v_dim else hidden_size - - self.gammas = (1 - torch.exp(torch.linspace(math.log(1/32), math.log(1/512), heads))).detach().cpu().tolist() - - self.swish = lambda x: x * torch.sigmoid(x) - self.W_G = nn.Parameter(torch.randn(hidden_size, self.v_dim) / hidden_size) - self.W_O = nn.Parameter(torch.randn(self.v_dim, hidden_size) / hidden_size) - self.group_norm = nn.GroupNorm(heads, self.v_dim) - - self.retentions = nn.ModuleList([ - SimpleRetention(self.hidden_size, gamma, self.head_size, double_v_dim) for gamma in self.gammas - ]) - - def forward(self, X): - """ - parallel representation of the multi-scale retention mechanism - """ - - # apply each individual retention mechanism to X - Y = [] - for i in range(self.heads): - Y.append(self.retentions[i](X)) - - Y = torch.cat(Y, dim=2) - Y_shape = Y.shape - Y = self.group_norm(Y.reshape(-1, self.v_dim)).reshape(Y_shape) - - return (self.swish(X @ self.W_G) * Y) @ self.W_O - - def forward_recurrent(self, x_n, s_n_1s, n): - """ - recurrent representation of the multi-scale retention mechanism - x_n: (batch_size, 1, hidden_size) - s_n_1s: (batch_size, heads, head_size, head_size) - - """ - - # apply each individual retention mechanism to a slice of X - Y = [] - s_ns = [] - for i in range(self.heads): - y, s_n = self.retentions[i].forward_recurrent( - x_n[:, :, :], s_n_1s[i], n - ) - Y.append(y) - s_ns.append(s_n) - - Y = torch.cat(Y, dim=2) - Y_shape = Y.shape - Y = self.group_norm(Y.reshape(-1, self.v_dim)).reshape(Y_shape) - - return (self.swish(x_n @ self.W_G) * Y) @ self.W_O, s_ns - - def forward_chunkwise(self, x_i, r_i_1s, i): - """ - chunkwise representation of the multi-scale retention mechanism - x_i: (batch_size, chunk_size, hidden_size) - r_i_1s: (batch_size, heads, head_size, head_size) - """ - batch, chunk_size, _ = x_i.shape - - # apply each individual retention 
mechanism to a slice of X - Y = [] - r_is = [] - for j in range(self.heads): - y, r_i = self.retentions[j].forward_chunkwise( - x_i[:, :, :], r_i_1s[j], i - ) - Y.append(y) - r_is.append(r_i) - - - Y = torch.cat(Y, dim=2) - Y_shape = Y.shape - Y = self.group_norm(Y.reshape(-1, self.v_dim)).reshape(Y_shape) - - return (self.swish(x_i @ self.W_G) * Y) @ self.W_O, r_is diff --git a/src/clm/src/retnet/retnet.py b/src/clm/src/retnet/retnet.py deleted file mode 100644 index dced11ec..00000000 --- a/src/clm/src/retnet/retnet.py +++ /dev/null @@ -1,76 +0,0 @@ -import torch -import torch.nn as nn - -from clm.src.retnet.retention import MultiScaleRetention - -class RetNet(nn.Module): - def __init__(self, layers, hidden_dim, ffn_size, heads, double_v_dim=False): - super(RetNet, self).__init__() - self.layers = layers - self.hidden_dim = hidden_dim - self.ffn_size = ffn_size - self.heads = heads - self.v_dim = hidden_dim * 2 if double_v_dim else hidden_dim - - self.retentions = nn.ModuleList([ - MultiScaleRetention(hidden_dim, heads, double_v_dim) - for _ in range(layers) - ]) - self.ffns = nn.ModuleList([ - nn.Sequential( - nn.Linear(hidden_dim, ffn_size), - nn.GELU(), - nn.Linear(ffn_size, hidden_dim) - ) - for _ in range(layers) - ]) - self.layer_norms_1 = nn.ModuleList([ - nn.LayerNorm(hidden_dim) - for _ in range(layers) - ]) - self.layer_norms_2 = nn.ModuleList([ - nn.LayerNorm(hidden_dim) - for _ in range(layers) - ]) - - def forward(self, X): - """ - X: (batch_size, sequence_length, hidden_size) - """ - for i in range(self.layers): - Y = self.retentions[i](self.layer_norms_1[i](X)) + X - - X = self.ffns[i](self.layer_norms_2[i](Y)) + Y - - return X - - def forward_recurrent(self, x_n, s_n_1s, n): - """ - X: (batch_size, sequence_length, hidden_size) - s_n_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) - - """ - s_ns = [] - for i in range(self.layers): - # list index out of range - o_n, s_n = self.retentions[i].forward_recurrent(self.layer_norms_1[i](x_n), s_n_1s[i], n) - y_n = o_n + x_n - s_ns.append(s_n) - x_n = self.ffns[i](self.layer_norms_2[i](y_n)) + y_n - - return x_n, s_ns - - def forward_chunkwise(self, x_i, r_i_1s, i): - """ - X: (batch_size, sequence_length, hidden_size) - r_i_1s: list of lists of tensors of shape (batch_size, hidden_size // heads, hidden_size // heads) - - """ - r_is = [] - for j in range(self.layers): - o_i, r_i = self.retentions[j].forward_chunkwise(self.layer_norms_1[j](x_i), r_i_1s[j], i) - y_i = o_i + x_i - r_is.append(r_i) - x_i = self.ffns[j](self.layer_norms_2[j](y_i)) + y_i - - return x_i, r_is diff --git a/src/clm/src/retnet/tests.py b/src/clm/src/retnet/tests.py deleted file mode 100644 index 44c8fc6d..00000000 --- a/src/clm/src/retnet/tests.py +++ /dev/null @@ -1,154 +0,0 @@ -import unittest - -import torch - -from clm.src.retnet.retention import SimpleRetention, MultiScaleRetention -from clm.src.retnet.retnet import RetNet - -class TestRetention(unittest.TestCase): - - def test_simple(self): - """ - verify that the three implementations of SimpleRetention are identical - """ - batch_size = 4 - sequence_length = 12 - hidden_size = 6 - chunk_size = 4 - - gamma = 0.9 - - X = torch.rand(batch_size, sequence_length, hidden_size) - sr = SimpleRetention(hidden_size, gamma, double_v_dim=True) - - Y_parallel = sr(X) - - s_n_1 = torch.zeros(hidden_size, sr.v_dim).unsqueeze(0).repeat(batch_size, 1, 1) - Y_recurrent = [] - for i in range(sequence_length): - y_n, s_n = sr.forward_recurrent(X[:, i:i+1, :], s_n_1, i) - 
Y_recurrent.append(y_n) - s_n_1 = s_n - - Y_recurrent = torch.concat(Y_recurrent, dim=1) - - r_n_1 = torch.zeros(hidden_size, sr.v_dim).unsqueeze(0).repeat(batch_size, 1, 1) - Y_chunkwise = [] - for i in range(sequence_length // chunk_size): - y_i, r_i = sr.forward_chunkwise(X[:, i*chunk_size:(i+1)*chunk_size, :], r_n_1, i) - Y_chunkwise.append(y_i) - r_n_1 = r_i - - - Y_chunkwise = torch.concat(Y_chunkwise, dim=1) - - - assert torch.allclose(Y_parallel, Y_recurrent, atol=1e-5) - assert torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5) - - - def test_multiscale(self): - """ - verify that the three implementations of MultiScaleRetention are identical - """ - batch_size = 2 - hidden_size = 6 - sequence_length = 12 - heads = 3 - chunk_size = 2 - - X = torch.rand(batch_size, sequence_length, hidden_size) - retention = MultiScaleRetention(hidden_size, heads, double_v_dim=False) - # print total number of parameters - print("Default v_dim:",sum(p.numel() for p in retention.parameters() if p.requires_grad)) - - retention = MultiScaleRetention(hidden_size, heads, double_v_dim=True) - print("Double v_dim:",sum(p.numel() for p in retention.parameters() if p.requires_grad)) - - Y_parallel = retention(X) - - s_n_1s = [ - torch.zeros(hidden_size // heads, retention.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) - for _ in range(heads) - ] - Y_recurrent = [] - for i in range(sequence_length): - y_n, s_ns = retention.forward_recurrent(X[:, i:i+1, :], s_n_1s, i) - Y_recurrent.append(y_n) - s_n_1s = s_ns - - Y_recurrent = torch.concat(Y_recurrent, dim=1) - - r_n_1s = [ - torch.zeros(hidden_size // heads, retention.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) - for _ in range(heads) - ] - Y_chunkwise = [] - for i in range(sequence_length // chunk_size): - y_i, r_i = retention.forward_chunkwise(X[:, i*chunk_size:(i+1)*chunk_size, :], r_n_1s, i) - Y_chunkwise.append(y_i) - r_n_1s = r_i - - Y_chunkwise = torch.concat(Y_chunkwise, dim=1) - - self.assertTrue(torch.allclose(Y_parallel, Y_recurrent, atol=1e-5)) - self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) # fails - -class TestRetNet(unittest.TestCase): - - def test_retnet(self): - """ - verify that the three implementations of RetNet are identical - """ - batch_size = 2 - hidden_size = 36 - sequence_length = 5 - heads = 3 - layers = 4 - ffn_size = 128 - - X = torch.rand(batch_size, sequence_length, hidden_size) - retnet = RetNet(layers, hidden_size, ffn_size, heads, double_v_dim=False) - # print total number of parameters - print("Default v_dim:",sum(p.numel() for p in retnet.parameters() if p.requires_grad)) - - retnet = RetNet(layers, hidden_size, ffn_size, heads, double_v_dim=True) - print("Double v_dim:",sum(p.numel() for p in retnet.parameters() if p.requires_grad)) - - Y_parallel = retnet(X) - - s_n_1s = [ - [ - torch.zeros(hidden_size // heads, retnet.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) - for _ in range(heads) - ] - for _ in range(layers) - ] - Y_recurrent = [] - for i in range(sequence_length): - y_n, s_ns = retnet.forward_recurrent(X[:, i:i+1, :], s_n_1s, i) - Y_recurrent.append(y_n) - s_n_1s = s_ns - - Y_recurrent = torch.concat(Y_recurrent, dim=1) - - r_n_1s = [ - [ - torch.zeros(hidden_size // heads, retnet.v_dim // heads).unsqueeze(0).repeat(batch_size, 1, 1) - for _ in range(heads) - ] - for _ in range(layers) - ] - Y_chunkwise = [] - for i in range(sequence_length): - y_i, r_i = retnet.forward_chunkwise(X[:, i:i+1, :], r_n_1s, i) - Y_chunkwise.append(y_i) - r_n_1s = r_i - - Y_chunkwise = 
torch.concat(Y_chunkwise, dim=1) - - self.assertTrue(torch.allclose(Y_parallel, Y_recurrent, atol=1e-5)) - self.assertTrue(torch.allclose(Y_parallel, Y_chunkwise, atol=1e-5)) - -if __name__ == "__main__": - unittest.main() diff --git a/src/clm/src/retnet/xpos_relative_position.py b/src/clm/src/retnet/xpos_relative_position.py deleted file mode 100644 index 1c445e5a..00000000 --- a/src/clm/src/retnet/xpos_relative_position.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2022 Microsoft -# Licensed under The MIT License (https://github.com/microsoft/torchscale/blob/main/LICENSE) -import torch -import torch.nn as nn - -def fixed_pos_embedding(x): - seq_len, dim = x.shape - inv_freq = 1.0 / (10000 ** (torch.arange(0, dim) / dim)) - sinusoid_inp = ( - torch.einsum("i , j -> i j", torch.arange(0, seq_len, dtype=torch.float), inv_freq).to(x) - ) - return torch.sin(sinusoid_inp), torch.cos(sinusoid_inp) - -def rotate_every_two(x): - x1 = x[:, :, ::2] - x2 = x[:, :, 1::2] - x = torch.stack((-x2, x1), dim=-1) - if x.shape[-1]%2 == 1: - # fill last dim with zero if hidden_size is odd - x2 = torch.concat((x2, torch.zeros_like(x2[:, :, :1])), dim=-1) - return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... (d j)')\ - -def duplicate_interleave(m): - """ - A simple version of `torch.repeat_interleave` for duplicating a matrix while interleaving the copy. - """ - dim0 = m.shape[0] - m = m.view(-1, 1) # flatten the matrix - m = m.repeat(1, 2) # repeat all elements into the 2nd dimension - m = m.view(dim0, -1) # reshape into a matrix, interleaving the copy - return m - -def apply_rotary_pos_emb(x, sin, cos, scale=1): - sin, cos = map(lambda t: duplicate_interleave(t * scale), (sin, cos)) - # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) - return (x * cos[:, :x.shape[-1]]) + (rotate_every_two(x) * sin)[:, :, :x.shape[-1]] - - -class XPOS(nn.Module): - def __init__( - self, head_dim, scale_base=512 - ): - super().__init__() - self.head_dim = head_dim - self.scale_base = scale_base - self.register_buffer( - "scale", (torch.arange(0, head_dim, 2) + 0.4 * head_dim) / (1.4 * head_dim) - ) - - def forward(self, x, offset=0, downscale=False): - length = x.shape[1] - min_pos = 0 - max_pos = length + offset + min_pos - scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] - sin, cos = fixed_pos_embedding(scale) - - if scale.shape[0] > length: - scale = scale[-length:] - sin = sin[-length:] - cos = cos[-length:] - - if downscale: - scale = 1 / scale - - x = apply_rotary_pos_emb(x, sin, cos, scale) - return x - - def forward_reverse(self, x, offset=0, downscale=False): - length = x.shape[1] - min_pos = -(length + offset) // 2 - max_pos = length + offset + min_pos - scale = self.scale ** torch.arange(min_pos, max_pos, 1).to(self.scale).div(self.scale_base)[:, None] - sin, cos = fixed_pos_embedding(scale) - - if scale.shape[0] > length: - scale = scale[-length:] - sin = sin[-length:] - cos = cos[-length:] - - if downscale: - scale = 1 / scale - - x = apply_rotary_pos_emb(x, -sin, cos, scale) - return x - -# test -if __name__ == "__main__": - x = torch.eye(4).unsqueeze(0) - xpos = XPOS(4) - x_rot = xpos(x) - # apply reverse - x_rot_rev = xpos.forward(x) - - print(x_rot @ x_rot_rev.transpose(-1, -2)) \ No newline at end of file diff --git a/src/clm/src/tasks/decoders.py b/src/clm/src/tasks/decoders.py deleted file mode 100644 index f95e5005..00000000 --- a/src/clm/src/tasks/decoders.py +++ /dev/null @@ 
-1,319 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange, reduce - -import clm.src.models.nn.utils as U -import clm.src.utils as utils -import clm.src.utils.config -import clm.src.utils.train - -log = clm.src.utils.train.get_logger(__name__) - - -class Decoder(nn.Module): - """This class doesn't do much but just signals the interface that Decoders are expected to adhere to - TODO: is there a way to enforce the signature of the forward method? - """ - - def forward(self, x, **kwargs): - """ - x: (batch, length, dim) input tensor - state: additional state from the model backbone - *args, **kwargs: additional info from the dataset - - Returns: - y: output tensor - *args: other arguments to pass into the loss function - """ - return x - - def step(self, x): - """ - x: (batch, dim) - """ - return self.forward(x.unsqueeze(1)).squeeze(1) - - -class SequenceDecoder(Decoder): - def __init__( - self, d_model, d_output=None, l_output=None, use_lengths=False, mode="last" - ): - super().__init__() - - self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output) - - if l_output is None: - self.l_output = None - self.squeeze = False - elif l_output == 0: - # Equivalent to getting an output of length 1 and then squeezing - self.l_output = 1 - self.squeeze = True - else: - assert l_output > 0 - self.l_output = l_output - self.squeeze = False - - self.use_lengths = use_lengths - self.mode = mode - - if mode == 'ragged': - assert not use_lengths - - def forward(self, x, state=None, lengths=None, l_output=None): - """ - x: (n_batch, l_seq, d_model) - Returns: (n_batch, l_output, d_output) - """ - - if self.l_output is None: - if l_output is not None: - assert isinstance(l_output, int) # Override by pass in - else: - # Grab entire output - l_output = x.size(-2) - squeeze = False - else: - l_output = self.l_output - squeeze = self.squeeze - - if self.mode == "last": - restrict = lambda x: x[..., -l_output:, :] - elif self.mode == "first": - restrict = lambda x: x[..., :l_output, :] - elif self.mode == "pool": - restrict = lambda x: ( - torch.cumsum(x, dim=-2) - / torch.arange( - 1, 1 + x.size(-2), device=x.device, dtype=x.dtype - ).unsqueeze(-1) - )[..., -l_output:, :] - - def restrict(x): - L = x.size(-2) - s = x.sum(dim=-2, keepdim=True) - if l_output > 1: - c = torch.cumsum(x[..., -(l_output - 1) :, :].flip(-2), dim=-2) - c = F.pad(c, (0, 0, 1, 0)) - s = s - c # (B, l_output, D) - s = s.flip(-2) - denom = torch.arange( - L - l_output + 1, L + 1, dtype=x.dtype, device=x.device - ) - s = s / denom - return s - - elif self.mode == "sum": - restrict = lambda x: torch.cumsum(x, dim=-2)[..., -l_output:, :] - # TODO use same restrict function as pool case - elif self.mode == 'ragged': - assert lengths is not None, "lengths must be provided for ragged mode" - # remove any additional padding (beyond max length of any sequence in the batch) - restrict = lambda x: x[..., : max(lengths), :] - else: - raise NotImplementedError( - "Mode must be ['last' | 'first' | 'pool' | 'sum']" - ) - - # Restrict to actual length of sequence - if self.use_lengths: - assert lengths is not None - x = torch.stack( - [ - restrict(out[..., :length, :]) - for out, length in zip(torch.unbind(x, dim=0), lengths) - ], - dim=0, - ) - else: - x = restrict(x) - - if squeeze: - assert x.size(-2) == 1 - x = x.squeeze(-2) - - x = self.output_transform(x) - - return x - - def step(self, x, state=None): - # Ignore all length logic - return self.output_transform(x) 
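For reference, a minimal standalone sketch of the cumulative-mean restriction that the "pool" mode of SequenceDecoder above computes; the tensor `x` and the value `l_output` here are made-up illustrative inputs, not attributes of the module, and this is only an approximation of the idea, not the deleted implementation itself:

import torch

# illustrative input: (batch, length, d_model)
x = torch.randn(2, 8, 4)
l_output = 3

# running mean over the length dimension, then keep the last l_output positions
cummean = torch.cumsum(x, dim=-2) / torch.arange(1, 1 + x.size(-2), dtype=x.dtype).unsqueeze(-1)
pooled = cummean[..., -l_output:, :]          # (batch, l_output, d_model)

# sanity check against an explicit mean over growing prefixes
ref = torch.stack(
    [x[:, : x.size(1) - l_output + 1 + i].mean(dim=1) for i in range(l_output)], dim=1
)
assert torch.allclose(pooled, ref, atol=1e-6)
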
- -class NDDecoder(Decoder): - """Decoder for single target (e.g. classification or regression)""" - def __init__( - self, d_model, d_output=None, mode="pool" - ): - super().__init__() - - assert mode in ["pool", "full"] - self.output_transform = nn.Identity() if d_output is None else nn.Linear(d_model, d_output) - - self.mode = mode - - def forward(self, x, state=None): - """ - x: (n_batch, l_seq, d_model) - Returns: (n_batch, l_output, d_output) - """ - - if self.mode == 'pool': - x = reduce(x, 'b ... h -> b h', 'mean') - x = self.output_transform(x) - return x - -class StateDecoder(Decoder): - """Use the output state to decode (useful for stateful models such as RNNs or perhaps Transformer-XL if it gets implemented""" - - def __init__(self, d_model, state_to_tensor, d_output): - super().__init__() - self.output_transform = nn.Linear(d_model, d_output) - self.state_transform = state_to_tensor - - def forward(self, x, state=None): - return self.output_transform(self.state_transform(state)) - - -class RetrievalHead(nn.Module): - def __init__(self, d_input, d_model, n_classes, nli=True, activation="relu"): - super().__init__() - self.nli = nli - - if activation == "relu": - activation_fn = nn.ReLU() - elif activation == "gelu": - activation_fn = nn.GELU() - else: - raise NotImplementedError - - if ( - self.nli - ): # Architecture from https://github.com/mlpen/Nystromformer/blob/6539b895fa5f798ea0509d19f336d4be787b5708/reorganized_code/LRA/model_wrapper.py#L74 - self.classifier = nn.Sequential( - nn.Linear(4 * d_input, d_model), - activation_fn, - nn.Linear(d_model, n_classes), - ) - else: # Head from https://github.com/google-research/long-range-arena/blob/ad0ff01a5b3492ade621553a1caae383b347e0c1/lra_benchmarks/models/layers/common_layers.py#L232 - self.classifier = nn.Sequential( - nn.Linear(2 * d_input, d_model), - activation_fn, - nn.Linear(d_model, d_model // 2), - activation_fn, - nn.Linear(d_model // 2, n_classes), - ) - - def forward(self, x): - """ - x: (2*batch, dim) - """ - outs = rearrange(x, "(z b) d -> z b d", z=2) - outs0, outs1 = outs[0], outs[1] # (n_batch, d_input) - if self.nli: - features = torch.cat( - [outs0, outs1, outs0 - outs1, outs0 * outs1], dim=-1 - ) # (batch, dim) - else: - features = torch.cat([outs0, outs1], dim=-1) # (batch, dim) - logits = self.classifier(features) - return logits - - -class RetrievalDecoder(Decoder): - """Combines the standard FeatureDecoder to extract a feature before passing through the RetrievalHead""" - - def __init__( - self, - d_input, - n_classes, - d_model=None, - nli=True, - activation="relu", - *args, - **kwargs - ): - super().__init__() - if d_model is None: - d_model = d_input - self.feature = SequenceDecoder( - d_input, d_output=None, l_output=0, *args, **kwargs - ) - self.retrieval = RetrievalHead( - d_input, d_model, n_classes, nli=nli, activation=activation - ) - - def forward(self, x, state=None, **kwargs): - x = self.feature(x, state=state, **kwargs) - x = self.retrieval(x) - return x - -class PackedDecoder(Decoder): - def forward(self, x, state=None): - x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True) - return x - - -# For every type of encoder/decoder, specify: -# - constructor class -# - list of attributes to grab from dataset -# - list of attributes to grab from model - -registry = { - "stop": Decoder, - "id": nn.Identity, - "linear": nn.Linear, - "sequence": SequenceDecoder, - "nd": NDDecoder, - "retrieval": RetrievalDecoder, - "state": StateDecoder, - "pack": PackedDecoder, -} -model_attrs = { - "linear": 
["d_output"], - "sequence": ["d_output"], - "nd": ["d_output"], - "retrieval": ["d_output"], - "state": ["d_state", "state_to_tensor"], - "forecast": ["d_output"], -} - -dataset_attrs = { - "linear": ["d_output"], - "sequence": ["d_output", "l_output"], - "nd": ["d_output"], - "retrieval": ["d_output"], - "state": ["d_output"], - "forecast": ["d_output", "l_output"], -} - - -def _instantiate(decoder, model=None, dataset=None): - """Instantiate a single decoder""" - if decoder is None: - return None - - if isinstance(decoder, str): - name = decoder - else: - name = decoder["_name_"] - - # Extract arguments from attribute names - dataset_args = utils.config.extract_attrs_from_obj( - dataset, *dataset_attrs.get(name, []) - ) - model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, [])) - # Instantiate decoder - obj = utils.instantiate(registry, decoder, *model_args, *dataset_args) - return obj - - -def instantiate(decoder, model=None, dataset=None): - """Instantiate a full decoder config, e.g. handle list of configs - Note that arguments are added in reverse order compared to encoder (model first, then dataset) - """ - decoder = utils.to_list(decoder) - return U.PassthroughSequential( - *[_instantiate(d, model=model, dataset=dataset) for d in decoder] - ) diff --git a/src/clm/src/tasks/encoders.py b/src/clm/src/tasks/encoders.py deleted file mode 100644 index e6eac313..00000000 --- a/src/clm/src/tasks/encoders.py +++ /dev/null @@ -1,358 +0,0 @@ -import datetime -import math -from typing import ForwardRef - -import torch -from torch import nn -import torch.nn.functional as F -from einops import rearrange, repeat - -import clm.src.models.nn.utils as U -import clm.src.utils as utils -import clm.src.utils.config -from clm.src.models.sequence.block import SequenceResidualBlock -from clm.src.models.nn.components import Normalization - -class Encoder(nn.Module): - """Encoder abstraction - Accepts a tensor and optional kwargs. Outside of the main tensor, all other arguments should be kwargs. - Returns a tensor and optional kwargs. - Encoders are combined via U.PassthroughSequential which passes these kwargs through in a pipeline. The resulting kwargs are accumulated and passed into the model backbone. - - """ - - def forward(self, x, **kwargs): - """ - x: input tensor - *args: additional info from the dataset (e.g. sequence lengths) - - Returns: - y: output tensor - *args: other arguments to pass into the model backbone - """ - return x, {} - -class PositionalIDEncoder(Encoder): - def forward(self, x): - position_ids = torch.arange(x.shape[-1], dtype=torch.long, device=x.device) - position_ids = repeat(position_ids, 'l -> b l', b=x.shape[0]) - return x, { 'position_ids': position_ids } - -# Adapted from https://github.com/pytorch/examples/blob/master/word_language_model/model.py -class PositionalEncoder(Encoder): - r"""Inject some information about the relative or absolute position of the tokens - in the sequence. The positional encodings have the same dimension as - the embeddings, so that the two can be summed. Here, we use sine and cosine - functions of different frequencies. - .. math:: - \text{PosEncoder}(pos, 2i) = sin(pos/10000^(2i/d_model)) - \text{PosEncoder}(pos, 2i+1) = cos(pos/10000^(2i/d_model)) - \text{where pos is the word position and i is the embed idx) - Args: - d_model: the embed dim (required). - dropout: the dropout value (default=0.1). - max_len: the max. length of the incoming sequence (default=5000). 
- Examples: - >>> pos_encoder = PositionalEncoder(d_model) - """ - - def __init__(self, d_model, dropout=0.1, max_len=16384, pe_init=None): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - if pe_init is not None: - self.pe = nn.Parameter(torch.empty(max_len, 1, d_model)) - nn.init.normal_(self.pe, 0, pe_init) - # self.pe = pe.unsqueeze(1) - else: - pe = torch.zeros(max_len, d_model) - position = torch.arange(0.0, max_len).unsqueeze(1) - div_term = torch.exp( - -math.log(10000.0) * torch.arange(0.0, d_model, 2.0) / d_model - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - self.register_buffer("pe", pe) - - self.attn_mask = None - - def forward(self, x): - r"""Inputs of forward function - Args: - x: the sequence fed to the positional encoder model (required). - lens: actual lengths of sequences - Shape: - x: [l_sequence, n_batch, d_model] - Returns: [l_sequence, n_batch, d_model] - attn_mask: [l_sequence, l_sequence] - padding_mask: - """ - - x = x + self.pe[: x.size(-2)] - return self.dropout(x) - - -class ClassEmbedding(Encoder): - # Should also be able to define this by subclassing Embedding - def __init__(self, n_classes, d_model): - super().__init__() - self.embedding = nn.Embedding(n_classes, d_model) - - def forward(self, x, y): - x = x + self.embedding(y).unsqueeze(-2) # (B, L, D) - return x - - -class Conv1DEncoder(Encoder): - def __init__(self, d_input, d_model, kernel_size=25, stride=1, padding='same'): - super().__init__() - self.conv = nn.Conv1d( - in_channels=d_input, - out_channels=d_model, - kernel_size=kernel_size, - stride=stride, - padding=padding, - ) - - def forward(self, x): - # BLD -> BLD - x = self.conv(x.transpose(1, 2)).transpose(1, 2) - return x - -class LayerEncoder(Encoder): - """Use an arbitary SequenceModule layer""" - - def __init__(self, d_model, prenorm=False, norm='layer', layer=None): - super().__init__() - - # Simple stack of blocks - layer["transposed"] = False - self.layer = SequenceResidualBlock( - d_input=d_model, - prenorm=prenorm, - layer=layer, - residual='R', - norm=norm, - pool=None, - ) - - def forward(self, x): - x, _ = self.layer(x) # Discard state - return x - - -class TimestampEmbeddingEncoder(Encoder): - """ - General time encoder for Pandas Timestamp objects (encoded as torch tensors). - See MonashDataset for an example of how to return time features as 'z's. 
- """ - - cardinalities = { - 'day': (1, 31), - 'hour': (0, 23), - 'minute': (0, 59), - 'second': (0, 59), - 'month': (1, 12), - 'year': (1950, 2010), # (1800, 3000) used to be (1970, datetime.datetime.now().year + 1) but was not enough for all datasets in monash - 'dayofweek': (0, 6), - 'dayofyear': (1, 366), - 'quarter': (1, 4), - 'week': (1, 53), - 'is_month_start': (0, 1), - 'is_month_end': (0, 1), - 'is_quarter_start': (0, 1), - 'is_quarter_end': (0, 1), - 'is_year_start': (0, 1), - 'is_year_end': (0, 1), - 'is_leap_year': (0, 1), - } - - def __init__(self, d_model, table=False, features=None): - super().__init__() - self.table = table - self.ranges = {k: max_val - min_val + 2 for k, (min_val, max_val) in self.cardinalities.items()} # padding for null included - - if features is None: - pass - else: - self.cardinalities = {k: v for k, v in self.cardinalities.items() if k in features} - - if table: - self.embedding = nn.ModuleDict({ - attr: nn.Embedding(maxval - minval + 2, d_model, padding_idx=0) - for attr, (minval, maxval) in self.cardinalities.items() - }) - else: - self.embedding = nn.ModuleDict({ - attr: nn.Linear(1, d_model) - for attr in self.cardinalities - }) - - - - def forward(self, x, timestamps=None): - for attr in timestamps: - mask = timestamps[attr] == -1 - timestamps[attr] = timestamps[attr] - self.cardinalities[attr][0] - timestamps[attr][mask] = 0 - if self.table: - x = x + self.embedding[attr](timestamps[attr].to(torch.long)) - else: - x = x + self.embedding[attr]((2 * timestamps[attr] / self.ranges[attr] - 1).unsqueeze(-1)) - - #x = x + self.embedding(timestamps[attr].to(torch.float)).unsqueeze(1) - return x - - -class TimeEncoder(Encoder): - def __init__(self, n_tokens_time, d_model, timeenc=0): - super().__init__() - - self.timeenc = timeenc - if self.timeenc == 0: - self.encoders = nn.ModuleList( - [nn.Embedding(v, d_model) for v in n_tokens_time] - ) - else: - self.encoders = nn.Linear(len(n_tokens_time), d_model) - self.mask_embed = nn.Embedding(2, d_model) - - def forward(self, x, mark=None, mask=None): - assert mark is not None and mask is not None, "Extra arguments should be returned by collate function" - if self.timeenc == 0: - assert mark.size(-1) == len(self.encoders) - embeddings = [ - embed(z) for embed, z in zip(self.encoders, torch.unbind(mark, dim=-1)) - ] - time_encode = torch.sum(torch.stack(embeddings), dim=0) - else: - time_encode = self.encoders(mark) - mask_encode = self.mask_embed(mask.squeeze(-1)) - return x + time_encode + mask_encode # (B, L, d_model) - - -class PackedEncoder(Encoder): - def forward(self, x, len_batch=None): - assert len_batch is not None - x = nn.utils.rnn.pack_padded_sequence( - x, len_batch.cpu(), enforce_sorted=False, batch_first=True, - ) - return x - - -class OneHotEncoder(Encoder): - def __init__(self, n_tokens, d_model): - super().__init__() - assert n_tokens <= d_model - self.d_model = d_model - - def forward(self, x): - return F.one_hot(x.squeeze(-1), self.d_model).float() - - -class Conv2DPatchEncoder(Encoder): - - """ - For encoding images into a sequence of patches. 
- """ - - def __init__(self, d_input, d_model, filter_sizes, flat=False): - - """ - d_input: dim of encoder input (data dimension) - d_model: dim of encoder output (model dimension) - filter_sizes: tuple with fh, fw - flat: if image is flattened from dataloader (like in cifar), - then we need to reshape back to 2D before conv - """ - - fh, fw = filter_sizes - self.flat = flat - - super().__init__() - assert len(filter_sizes) == 2 - - self.encoder = nn.Conv2d(d_input, d_model, kernel_size=(fh, fw), stride=(fh, fw)) - - def forward(self, x): - - """ - x shape expected = [b, h, w, c] - returns tuple with x, with new shape = [b, seq_len, c_out] - - """ - - x = rearrange(x, 'b h w c -> b c h w') - x = self.encoder(x) - x = rearrange(x, 'b c h w -> b (h w) c') - return x - - -# For every type of encoder/decoder, specify: -# - constructor class -# - list of attributes to grab from dataset -# - list of attributes to grab from model - -registry = { - "stop": Encoder, - "id": nn.Identity, - "embedding": nn.Embedding, - "linear": nn.Linear, - "position": PositionalEncoder, - "position_id": PositionalIDEncoder, - "class": ClassEmbedding, - "pack": PackedEncoder, - "time": TimeEncoder, - "onehot": OneHotEncoder, - "conv1d": Conv1DEncoder, - "patch2d": Conv2DPatchEncoder, - "timestamp_embedding": TimestampEmbeddingEncoder, - "layer": LayerEncoder, -} -dataset_attrs = { - "embedding": ["n_tokens"], - "linear": ["d_input"], # TODO make this d_data? - "class": ["n_classes"], - "time": ["n_tokens_time"], - "onehot": ["n_tokens"], - "conv1d": ["d_input"], - "patch2d": ["d_input"], -} -model_attrs = { - "embedding": ["d_model"], - "linear": ["d_model"], - "position": ["d_model"], - "class": ["d_model"], - "time": ["d_model"], - "onehot": ["d_model"], - "conv1d": ["d_model"], - "patch2d": ["d_model"], - "timestamp_embedding": ["d_model"], - "layer": ["d_model"], -} - - -def _instantiate(encoder, dataset=None, model=None): - """Instantiate a single encoder""" - if encoder is None: - return None - if isinstance(encoder, str): - name = encoder - else: - name = encoder["_name_"] - - # Extract dataset/model arguments from attribute names - dataset_args = utils.config.extract_attrs_from_obj( - dataset, *dataset_attrs.get(name, []) - ) - model_args = utils.config.extract_attrs_from_obj(model, *model_attrs.get(name, [])) - - # Instantiate encoder - obj = utils.instantiate(registry, encoder, *dataset_args, *model_args) - return obj - - -def instantiate(encoder, dataset=None, model=None): - encoder = utils.to_list(encoder) - return U.PassthroughSequential( - *[_instantiate(e, dataset=dataset, model=model) for e in encoder] - ) diff --git a/src/clm/src/tasks/metrics.py b/src/clm/src/tasks/metrics.py deleted file mode 100644 index 234547e0..00000000 --- a/src/clm/src/tasks/metrics.py +++ /dev/null @@ -1,225 +0,0 @@ -import math -import torch -import torch.nn.functional as F -from sklearn.metrics import f1_score, roc_auc_score -from functools import partial -import torchmetrics.functional as tm_f - -def _student_t_map(mu, sigma, nu): - sigma = F.softplus(sigma) - nu = 2.0 + F.softplus(nu) - return mu.squeeze(axis=-1), sigma.squeeze(axis=-1), nu.squeeze(axis=-1) - -def student_t_loss(outs, y): - mu, sigma, nu = outs[..., 0], outs[..., 1], outs[..., 2] - mu, sigma, nu = _student_t_map(mu, sigma, nu) - y = y.squeeze(axis=-1) - - nup1_half = (nu + 1.0) / 2.0 - part1 = 1.0 / nu * torch.square((y - mu) / sigma) - Z = ( - torch.lgamma(nup1_half) - - torch.lgamma(nu / 2.0) - - 0.5 * torch.log(math.pi * nu) - - torch.log(sigma) - ) 
- - ll = Z - nup1_half * torch.log1p(part1) - return -ll.mean() - -def gaussian_ll_loss(outs, y): - mu, sigma = outs[..., 0], outs[..., 1] - y = y.squeeze(axis=-1) - sigma = F.softplus(sigma) - ll = -1.0 * ( - torch.log(sigma) - + 0.5 * math.log(2 * math.pi) - + 0.5 * torch.square((y - mu) / sigma) - ) - return -ll.mean() - -def binary_cross_entropy(logits, y): - # BCE loss requires squeezing last dimension of logits so it has the same shape as y - # requires y to be float, since it's overloaded to represent a probability - return F.binary_cross_entropy_with_logits(logits.squeeze(-1), y.float()) - - -def binary_accuracy(logits, y): - return torch.eq(logits.squeeze(-1) >= 0, y).float().mean() - - -def cross_entropy(logits, y): - logits = logits.view(-1, logits.shape[-1]) - y = y.view(-1) - return F.cross_entropy(logits, y) - - -def soft_cross_entropy(logits, y, label_smoothing=0.0): - logits = logits.view(-1, logits.shape[-1]) - # target is now 2d (no target flattening) - return F.cross_entropy(logits, y, label_smoothing=label_smoothing) - - -def accuracy(logits, y): - logits = logits.view(-1, logits.shape[-1]) - if y.numel() > logits.shape[0]: - # Mixup leads to this case: use argmax class - y = y.argmax(dim=-1) - y = y.view(-1) - return torch.eq(torch.argmax(logits, dim=-1), y).float().mean() - -def accuracy_ignore_index(logits, y, ignore_index=-100): - num_classes = logits.shape[-1] - preds = torch.argmax(logits, dim=-1) - logits = logits.view(-1, logits.shape[-1]) - y = y.view(-1) - return tm_f.classification.accuracy(preds, y, 'multiclass', num_classes=num_classes, ignore_index=ignore_index, average='micro') - - -def accuracy_at_k(logits, y, k=1): - logits = logits.view(-1, logits.shape[-1]) - if y.numel() > logits.shape[0]: - # Mixup leads to this case: use argmax class - y = y.argmax(dim=-1) - y = y.view(-1) - return torch.topk(logits, k, dim=-1)[1].eq(y.unsqueeze(-1)).any(dim=-1).float().mean() - - -def f1_binary(logits, y): - logits = logits.view(-1, logits.shape[-1]) - y = y.view(-1) - y_hat = torch.argmax(logits, dim=-1) - return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="binary") - - -def f1_macro(logits, y): - logits = logits.view(-1, logits.shape[-1]) - y = y.view(-1) - y_hat = torch.argmax(logits, dim=-1) - return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="macro") - - -def f1_micro(logits, y): - logits = logits.view(-1, logits.shape[-1]) - y = y.view(-1) - y_hat = torch.argmax(logits, dim=-1) - return f1_score(y.cpu().numpy(), y_hat.cpu().numpy(), average="micro") - - -def roc_auc_macro(logits, y): - logits = logits.view( - -1, logits.shape[-1] - ).detach() # KS: had to add detach to eval while training - y = y.view(-1) - return roc_auc_score( - y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average="macro" - ) - - -def roc_auc_micro(logits, y): - logits = logits.view(-1, logits.shape[-1]) - y = y.view(-1) - return roc_auc_score( - y.cpu().numpy(), F.softmax(logits, dim=-1).cpu().numpy()[:, 1], average="micro" - ) - - -def mse(outs, y, len_batch=None): - # assert outs.shape[:-1] == y.shape and outs.shape[-1] == 1 - # outs = outs.squeeze(-1) - if len(y.shape) < len(outs.shape): - assert outs.shape[-1] == 1 - outs = outs.squeeze(-1) - if len_batch is None: - return F.mse_loss(outs, y) - else: - # Computes the loss of the first `lens` items in the batches - # TODO document the use case of this - mask = torch.zeros_like(outs, dtype=torch.bool) - for i, l in enumerate(len_batch): - mask[i, :l, :] = 1 - outs_masked = 
torch.masked_select(outs, mask) - y_masked = torch.masked_select(y, mask) - return F.mse_loss(outs_masked, y_masked) - -def forecast_rmse(outs, y, len_batch=None): - # TODO: generalize, currently for Monash dataset - return torch.sqrt(F.mse_loss(outs, y, reduction='none').mean(1)).mean() - -def mae(outs, y, len_batch=None): - # assert outs.shape[:-1] == y.shape and outs.shape[-1] == 1 - # outs = outs.squeeze(-1) - if len(y.shape) < len(outs.shape): - assert outs.shape[-1] == 1 - outs = outs.squeeze(-1) - if len_batch is None: - return F.l1_loss(outs, y) - else: - # Computes the loss of the first `lens` items in the batches - mask = torch.zeros_like(outs, dtype=torch.bool) - for i, l in enumerate(len_batch): - mask[i, :l, :] = 1 - outs_masked = torch.masked_select(outs, mask) - y_masked = torch.masked_select(y, mask) - return F.l1_loss(outs_masked, y_masked) - - -# Metrics that can depend on the loss -def loss(x, y, loss_fn): - """ This metric may be useful because the training loss may add extra regularization (e.g. weight decay implemented as L2 penalty), while adding this as a metric skips the additional losses """ - return loss_fn(x, y) - - -def bpb(x, y, loss_fn): - """ bits per byte (image density estimation, speech generation, char LM) """ - return loss_fn(x, y) / math.log(2) - - -def ppl(x, y, loss_fn): - return torch.exp(loss_fn(x, y)) - - -# should have a better way to do this -output_metric_fns = { - "binary_cross_entropy": binary_cross_entropy, - "cross_entropy": cross_entropy, - "binary_accuracy": binary_accuracy, - "accuracy": accuracy, - "accuracy_ignore_index": accuracy_ignore_index, - 'accuracy@3': partial(accuracy_at_k, k=3), - 'accuracy@5': partial(accuracy_at_k, k=5), - 'accuracy@10': partial(accuracy_at_k, k=10), - "eval_loss": loss, - "mse": mse, - "mae": mae, - "forecast_rmse": forecast_rmse, - "f1_binary": f1_binary, - "f1_macro": f1_macro, - "f1_micro": f1_micro, - "roc_auc_macro": roc_auc_macro, - "roc_auc_micro": roc_auc_micro, - "soft_cross_entropy": soft_cross_entropy, # only for pytorch 1.10+ - "student_t": student_t_loss, - "gaussian_ll": gaussian_ll_loss, -} - -try: - from segmentation_models_pytorch.utils.functional import iou - from segmentation_models_pytorch.losses.focal import focal_loss_with_logits - - def iou_with_logits(pr, gt, eps=1e-7, threshold=None, ignore_channels=None): - return iou(pr.sigmoid(), gt, eps=eps, threshold=threshold, ignore_channels=ignore_channels) - - output_metric_fns["iou"] = partial(iou, threshold=0.5) - output_metric_fns["iou_with_logits"] = partial(iou_with_logits, threshold=0.5) - output_metric_fns["focal_loss"] = focal_loss_with_logits -except ImportError: - pass - -loss_metric_fns = { - "loss": loss, - "bpb": bpb, - "ppl": ppl, -} -metric_fns = {**output_metric_fns, **loss_metric_fns} # TODO py3.9 - diff --git a/src/clm/src/tasks/tasks.py b/src/clm/src/tasks/tasks.py deleted file mode 100644 index 8f0be6ea..00000000 --- a/src/clm/src/tasks/tasks.py +++ /dev/null @@ -1,371 +0,0 @@ -from typing import Optional, List, Tuple -import math -import functools -import collections -import torch -import torch.nn as nn -import torch.nn.functional as F -from einops import rearrange -from omegaconf import ListConfig -from clm.src.models.nn.components import ReversibleInstanceNorm1dInput, ReversibleInstanceNorm1dOutput, \ - TSNormalization, TSInverseNormalization - -from clm.src.models.nn.adaptive_softmax import AdaptiveEmbedding, ProjectedAdaptiveLogSoftmax -import clm.src.tasks.metrics as M -from clm.src.tasks.torchmetrics import 
torchmetric_fns as tm_mine -import clm.src.models.nn.utils as U -import torchmetrics as tm -from clm.src.utils.config import to_list, instantiate -from torchmetrics import MetricCollection - -class BaseTask: - """ Abstract class that takes care of: - - loss function - - arbitrary metrics - - forward pass - - (optional) encoder module that interfaces with dataset (inputs) and model - - (optional) decoder module that interfaces with dataset (targets) and model - """ - encoder = None - decoder = None - - def __init__(self, dataset=None, model=None, loss=None, loss_val=None, metrics=None, torchmetrics=None): - """ This class is allowed to grab attributes directly off a constructed dataset and model object """ - self.dataset = dataset - self.model = model - if metrics is None: metrics = [] - self.metric_names = to_list(metrics) - - if torchmetrics is None: torchmetrics = [] - self.torchmetric_names = to_list(torchmetrics) - self._tracked_torchmetrics = {} - - # The decoder might pass through arguments that the loss needs (e.g. sequence lengths) - # but might also pass through extraneous arguments (e.g. sampling rate) - # Wrap loss and metrics so that they accept kwargs and - - # Create loss function - self.loss = instantiate(M.output_metric_fns, loss, partial=True) - self.loss = U.discard_kwargs(self.loss) - if loss_val is not None: - self.loss_val = instantiate(M.output_metric_fns, loss_val, partial=True) - self.loss_val = U.discard_kwargs(self.loss_val) - torchmetrics = MetricCollection(self._init_torchmetrics()) - self.train_torchmetrics = torchmetrics.clone(prefix='train/') - self.val_torchmetrics = torchmetrics.clone(prefix='val/') - self.test_torchmetrics = torchmetrics.clone(prefix='test/') - - def _init_torchmetrics(self): - """ - Instantiate torchmetrics. - """ - tracked_torchmetrics = {} - - for name in self.torchmetric_names: - if name in tm_mine: - tracked_torchmetrics[name] = tm_mine[name]().to('cuda') - elif name in ['AUROC', 'StatScores', 'Precision', 'Recall', 'F1', 'F1Score']: - tracked_torchmetrics[name] = getattr(tm, name)(average='macro', num_classes=self.dataset.d_output, compute_on_step=False).to('cuda') - elif '@' in name: - k = int(name.split('@')[1]) - mname = name.split('@')[0] - tracked_torchmetrics[name] = getattr(tm, mname)(average='macro', num_classes=self.dataset.d_output, compute_on_step=False, top_k=k).to('cuda') - else: - tracked_torchmetrics[name] = getattr(tm, name)(compute_on_step=False).to('cuda') - - return tracked_torchmetrics - - def _reset_torchmetrics(self, prefix=None): - """ - Reset torchmetrics for a prefix - associated with a particular dataloader (e.g. train, val, test). - - Generally do this at the start of an epoch. - """ - all_prefixes = [prefix] if prefix is not None else self._tracked_torchmetrics - - for prefix in all_prefixes: - if prefix in self._tracked_torchmetrics: - self._tracked_torchmetrics[prefix].reset() - - def get_torchmetrics(self, prefix): - """ - Compute torchmetrics for a prefix associated with - a particular dataloader (e.g. train, val, test). - - Generally do this at the end of an epoch. - """ - return {name: self._tracked_torchmetrics[prefix][name].compute() for name in self.torchmetric_names} - - def torchmetrics(self, x, y, prefix, loss=None): - """ - Update torchmetrics with new x, y . - Prefix corresponds to a particular dataloader (e.g. train, val, test). - - Generally call this every batch. 
- """ - if prefix not in self._tracked_torchmetrics: - self._init_torchmetrics(prefix) - self._tracked_torchmetrics[prefix](x, y, loss=loss) - - # for name in self.torchmetric_names: - # if name.startswith('Accuracy'): - # if len(x.shape) > 2: - # # Multi-dimensional, multi-class - # self._tracked_torchmetrics[prefix][name].update(x.transpose(1, 2), y.squeeze()) - # continue - # self._tracked_torchmetrics[prefix][name].update(x, y) - - def get_torchmetrics(self, prefix): - return self._tracked_torchmetrics[prefix] - - def metrics(self, x, y, **kwargs): - """ - Metrics are just functions - output metrics are a function of output and target - loss metrics are a function of loss (e.g. perplexity) - """ - output_metrics = { - name: U.discard_kwargs(M.output_metric_fns[name])(x, y, **kwargs) - for name in self.metric_names if name in M.output_metric_fns - } - loss_metrics = { - name: U.discard_kwargs(M.loss_metric_fns[name])(x, y, self.loss, **kwargs) - for name in self.metric_names if name in M.loss_metric_fns - } - return {**output_metrics, **loss_metrics} - - def forward(self, batch, encoder, model, decoder, _state): - """Passes a batch through the encoder, backbone, and decoder""" - # z holds arguments such as sequence length - x, y, *z = batch # z holds extra dataloader info such as resolution - if len(z) == 0: - z = {} - else: - assert len(z) == 1 and isinstance(z[0], dict), "Dataloader must return dictionary of extra arguments" - z = z[0] - - x, w = encoder(x, **z) # w can model-specific constructions such as key_padding_mask for transformers or state for RNNs - x, state = model(x, **w, state=_state) - self._state = state - x, w = decoder(x, state=state, **z) - return x, y, w - - -class Scalar(nn.Module): - def __init__(self, c=1): - super().__init__() - self.c = c - def forward(self, x): - return x * self.c - -class LMTask(BaseTask): - def forward(self, batch, encoder, model, decoder, _state): - """Passes a batch through the encoder, backbone, and decoder""" - # z holds arguments such as sequence length - x, y, *z = batch # z holds extra dataloader info such as resolution - if len(z) == 0: - z = {} - else: - assert len(z) == 1 and isinstance(z[0], dict), "Dataloader must return dictionary of extra arguments" - z = z[0] - x, w = encoder(x, **z) # w can model-specific constructions such as key_padding_mask for transformers or state for RNNs - x, state = model(x, **w, state=_state) - self._state = state - x, w = decoder(x, state=state, **z) - - x = x.logits - x = rearrange(x, '... C -> (...) C') - y = rearrange(y, '... 
-> (...)') - - return x, y, w - -class ForecastingTask(BaseTask): - - class DummyModule(nn.Module): - - def forward(self, *args): - return args - - def __init__(self, norm='mean', **kwargs): - super().__init__(**kwargs) - - if norm == 'revnorm': - self.encoder = ReversibleInstanceNorm1dInput(self.dataset.d_input, transposed=False) - self.decoder = ReversibleInstanceNorm1dOutput(self.encoder) - elif norm == 'mean': - self.encoder = TSNormalization(method='mean', horizon=self.dataset.dataset_train.forecast_horizon) - self.decoder = TSInverseNormalization(method='mean', normalizer=self.encoder) - elif norm == 'last': - self.encoder = TSNormalization(method='last', horizon=self.dataset.dataset_train.forecast_horizon) - self.decoder = TSInverseNormalization(method='last', normalizer=self.encoder) - else: - self.encoder = None - self.decoder = None - - try: - if hasattr(self.dataset.dataset_train, 'mean'): - self.mean = torch.tensor(self.dataset.dataset_train.mean) - self.std = torch.tensor(self.dataset.dataset_train.std) - elif hasattr(self.dataset.dataset_train, 'standardization'): - self.mean = torch.tensor(self.dataset.dataset_train.standardization['means']) - self.std = torch.tensor(self.dataset.dataset_train.standardization['stds']) - else: - self.mean = None - self.std = None - except AttributeError: - raise AttributeError('Dataset does not have mean/std attributes') - self.mean = torch.tensor(self.dataset.dataset_train.standardization['means']) - self.std = torch.tensor(self.dataset.dataset_train.standardization['stds']) - - if hasattr(self.dataset.dataset_train, 'log_transform'): - self.log_transform = self.dataset.dataset_train.log_transform - else: - self.log_transform = False - print("Log Transform", self.log_transform) - - def metrics(self, x, y, state=None, timestamps=None, ids=None): # Explicit about which arguments the decoder might pass through, but can future-proof with **kwargs - if self.mean is not None: - means = self.mean[ids].to(x.device) - stds = self.std[ids].to(x.device) - x_ = x * stds[:, None, None] + means[:, None, None] - y_ = y * stds[:, None, None] + means[:, None, None] - else: - x_ = x - y_ = y - - if self.log_transform: - x_ = torch.exp(x_) - y_ = torch.exp(y_) - - return super().metrics(x_, y_) - -class VideoTask(BaseTask): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - # self._y_to_logits = {} - self._vid_to_logits = {} - self._vid_to_label = {} - - # TODO needed to extract the first element of y, which includes the video idea; there should be a cleaner pattern to this - import copy - loss_fn = copy.deepcopy(self.loss) - self.loss = lambda x, y: loss_fn(x, y[0]) - if hasattr(self, 'loss_val'): - loss_val_fn = copy.deepcopy(self.loss_val) - self.loss_val = lambda x, y: loss_val_fn(x, y[0]) - - def metrics(self, logits, y, **kwargs): - labels, vids = y - return super().metrics(logits, labels, **kwargs) - - def torchmetrics(self, logits, y, prefix): - """ - logits: (batch, n_classes) - y = tuple of labels and video ids - labels: (batch) - vids: (batch) - """ - for _logits, _label, _vid in zip(logits, y[0], y[1]): - _vid = _vid.item() - # Check that labels are consistent per video id - assert self._vid_to_label[prefix].get(_vid, _label) == _label - self._vid_to_label[prefix][_vid] = _label - - self._vid_to_logits[prefix][_vid].append(_logits) - - def _reset_torchmetrics(self, prefix): - self._vid_to_logits[prefix] = collections.defaultdict(list) - self._vid_to_label[prefix] = {} - - def get_torchmetrics(self, prefix): - vid_to_average_logits = 
{vid: torch.mean(torch.stack(logits, dim=0), dim=0) for vid, logits in self._vid_to_logits[prefix].items()} - # y is (label, vid) pair - all_labels = torch.stack(list(self._vid_to_label[prefix].values()), dim=0) # (n_videos) - all_logits = torch.stack(list(vid_to_average_logits.values()), dim=0) # (n_videos, n_classes) - m = M.accuracy(all_logits, all_labels) - return {'aggregate_accuracy': m} - - -class AdaptiveLMTask(BaseTask): - def __init__( - self, - div_val, - cutoffs : List[int], - tie_weights : bool, - tie_projs : List[bool], - init_scale=1.0, - bias_scale=0.0, - dropemb=0.0, - dropsoft=0.0, - **kwargs, - ): - super().__init__(**kwargs) - n_tokens = self.dataset.n_tokens - d_model = self.model.d_model - d_output = self.model.d_output - - encoder = AdaptiveEmbedding( - n_tokens, - d_model, - d_model, - cutoffs=cutoffs, - div_val=div_val, - init_scale=init_scale, - dropout=dropemb, - ) - - if tie_weights: - assert d_model == d_output - emb_layers = [i.weight for i in encoder.emb_layers] - else: - emb_layers = None - - # Construct decoder/loss - emb_projs = encoder.emb_projs - loss = ProjectedAdaptiveLogSoftmax( - n_tokens, d_output, d_output, - cutoffs, div_val=div_val, - tie_projs=tie_projs, - out_projs=emb_projs, - out_layers_weights=emb_layers, - bias_scale=bias_scale, - dropout=dropsoft, - ) - - self.encoder = encoder - self.loss = loss - - -class ImageNetTask(BaseTask): - """ - Imagenet training uses mixup augmentations, which require a separate loss for train and val, - which we overide the base task here. - """ - - def __init__(self, **kwargs): - import hydra - - super().__init__( - dataset=kwargs.get("dataset", None), - model=kwargs.get("model", None), - loss=kwargs.get("loss", None), # we still create the base loss here, but will overide below - metrics=kwargs.get("metrics", None), - torchmetrics=kwargs.get("torchmetrics", None) - ) - - # if using mixup, overide loss (train) and loss_val, otherwise - # we have just one loss from the base task above - if "loss_val" in kwargs and "loss_train" in kwargs: - self.loss = hydra.utils.instantiate(kwargs.get("loss_train")) - self.loss_val = hydra.utils.instantiate(kwargs.get('loss_val')) - - -registry = { - 'base': BaseTask, - 'lm': LMTask, - 'imagenet': ImageNetTask, - 'forecasting': ForecastingTask, - 'video': VideoTask, -} diff --git a/src/clm/src/tasks/torchmetrics.py b/src/clm/src/tasks/torchmetrics.py deleted file mode 100644 index f3580b43..00000000 --- a/src/clm/src/tasks/torchmetrics.py +++ /dev/null @@ -1,120 +0,0 @@ -# Inspired by https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/common/metrics/perplexity.py -# But we compute the perplexity correctly: exp(average(nll)), not average(exp(nll)) -# Also adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/text/perplexity.py -# But we pass in the loss to avoid recomputation - -from typing import Any, Dict, Optional - -import torch -import torch.nn.functional as F -from torch import Tensor -from torchmetrics import Metric - -try: - from flash_attn.losses.cross_entropy import CrossEntropyLoss -except ImportError: - CrossEntropyLoss = torch.nn.CrossEntropyLoss - -try: - from apex.transformer import parallel_state -except ImportError: - parallel_state = None - - -class Perplexity(Metric): - r""" - Perplexity measures how well a language model predicts a text sample. It's calculated as the average number of bits - per word a model needs to represent the sample. 
- Args: - kwargs: - Additional keyword arguments, see :ref:`Metric kwargs` for more info. - Examples: - >>> import torch - >>> preds = torch.rand(2, 8, 5, generator=torch.manual_seed(22)) - >>> target = torch.randint(5, (2, 8), generator=torch.manual_seed(22)) - >>> target[0, 6:] = -100 - >>> metric = Perplexity(ignore_index=-100) - >>> metric(preds, target) - tensor(5.2545) - """ - is_differentiable = True - higher_is_better = False - full_state_update = False - total_log_probs: Tensor - count: Tensor - - def __init__(self, **kwargs: Dict[str, Any]): - super().__init__(**kwargs) - self.add_state("total_log_probs", default=torch.tensor(0.0, dtype=torch.float64), - dist_reduce_fx="sum") - self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum") - - self.loss_fn = CrossEntropyLoss() - - def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore - """Compute and store intermediate statistics for Perplexity. - Args: - preds: - Probabilities assigned to each token in a sequence with shape [batch_size, seq_len, vocab_size]. - target: - Ground truth values with a shape [batch_size, seq_len]. - """ - count = target.numel() - if loss is None: - loss = self.loss_fn(preds, target) - self.total_log_probs += loss.double() * count - self.count += count - - def compute(self) -> Tensor: - """Compute the Perplexity. - Returns: - Perplexity - """ - return torch.exp(self.total_log_probs / self.count) - -class NumTokens(Metric): - """Keep track of how many tokens we've seen. - """ - # TODO: how do we prevent the reset between the epochs? The reset happens on the 1st batch - # of the next epoch. - # Right now the hack is that we override reset(), which would mess up the forward method. - # We then override forward to do the right thing. - - is_differentiable = False - higher_is_better = False - full_state_update = False - count: Tensor - - def __init__(self, **kwargs: Dict[str, Any]): - super().__init__(**kwargs) - self.add_state("count", default=torch.tensor(0, dtype=torch.int64), dist_reduce_fx="sum", - persistent=True) # We want the count to be saved to state-dict - if parallel_state is not None and not parallel_state.is_unitialized(): - self.tensor_parallel_world_size = parallel_state.get_tensor_model_parallel_world_size() - else: - self.tensor_parallel_world_size = 1 - - def update(self, preds: Tensor, target: Tensor, loss: Optional[Tensor] = None) -> None: # type: ignore - self.count += target.numel() // self.tensor_parallel_world_size - - def compute(self) -> Tensor: - return self.count - - def reset(self): - count = self.count - super().reset() - self.count = count - - # Adapted from https://github.com/Lightning-AI/metrics/blob/master/src/torchmetrics/metric.py - def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: - """forward computation using single call to `update` to calculate the metric value on the current batch and - accumulate global state. - This can be done when the global metric state is a sinple reduction of batch states. 
- """ - self.update(*args, **kwargs) - return self.compute() - -torchmetric_fns = { - "perplexity": Perplexity, - "num_tokens": NumTokens, -} \ No newline at end of file diff --git a/src/clm/src/utils/__init__.py b/src/clm/src/utils/__init__.py deleted file mode 100644 index 960c2b9a..00000000 --- a/src/clm/src/utils/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .config import is_list, is_dict, to_list, to_dict, get_class, instantiate diff --git a/src/clm/src/utils/config.py b/src/clm/src/utils/config.py deleted file mode 100644 index 9020037c..00000000 --- a/src/clm/src/utils/config.py +++ /dev/null @@ -1,124 +0,0 @@ -""" Utilities for dealing with collection objects (lists, dicts) and configs """ -from typing import Sequence, Mapping, Optional, Callable -import functools -import hydra -from omegaconf import ListConfig, DictConfig - -# TODO this is usually used in a pattern where it's turned into a list, so can just do that here -def is_list(x): - return isinstance(x, Sequence) and not isinstance(x, str) - - -def is_dict(x): - return isinstance(x, Mapping) - - -def to_dict(x, recursive=True): - """Convert Sequence or Mapping object to dict - - lists get converted to {0: x[0], 1: x[1], ...} - """ - if is_list(x): - x = {i: v for i, v in enumerate(x)} - if is_dict(x): - if recursive: - return {k: to_dict(v, recursive=recursive) for k, v in x.items()} - else: - return dict(x) - else: - return x - - -def to_list(x, recursive=False): - """Convert an object to list. - - If Sequence (e.g. list, tuple, Listconfig): just return it - - Special case: If non-recursive and not a list, wrap in list - """ - if is_list(x): - if recursive: - return [to_list(_x) for _x in x] - else: - return list(x) - else: - if recursive: - return x - else: - return [x] - - -def extract_attrs_from_obj(obj, *attrs): - if obj is None: - assert len(attrs) == 0 - return [] - return [getattr(obj, attr, None) for attr in attrs] - - -def auto_assign_attrs(cls, **kwargs): - for k, v in kwargs.items(): - setattr(cls, k, v) - - -def instantiate(registry, config, *args, partial=False, wrap=None, **kwargs): - """ - registry: Dictionary mapping names to functions or target paths (e.g. {'model': 'models.SequenceModel'}) - config: Dictionary with a '_name_' key indicating which element of the registry to grab, and kwargs to be passed into the target constructor - wrap: wrap the target class (e.g. ema optimizer or tasks.wrap) - *args, **kwargs: additional arguments to override the config to pass into the target constructor - """ - - # Case 1: no config - if config is None: - return None - # Case 2a: string means _name_ was overloaded - if isinstance(config, str): - _name_ = None - _target_ = registry[config] - config = {} - # Case 2b: grab the desired callable from name - else: - _name_ = config.pop("_name_") - _target_ = registry[_name_] - - # Retrieve the right constructor automatically based on type - if isinstance(_target_, str): - fn = hydra.utils.get_method(path=_target_) - elif isinstance(_target_, Callable): - fn = _target_ - else: - raise NotImplementedError("instantiate target must be string or callable") - - # Instantiate object - if wrap is not None: - fn = wrap(fn) - obj = functools.partial(fn, *args, **config, **kwargs) - - # Restore _name_ - if _name_ is not None: - config["_name_"] = _name_ - - if partial: - return obj - else: - return obj() - - -def get_class(registry, _name_): - return hydra.utils.get_class(path=registry[_name_]) - - -def omegaconf_filter_keys(d, fn=None): - """Only keep keys where fn(key) is True. 
Support nested DictConfig. - # TODO can make this inplace? - """ - if fn is None: - fn = lambda _: True - if is_list(d): - return ListConfig([omegaconf_filter_keys(v, fn) for v in d]) - elif is_dict(d): - return DictConfig( - {k: omegaconf_filter_keys(v, fn) for k, v in d.items() if fn(k)} - ) - else: - return d diff --git a/src/clm/src/utils/distributed.py b/src/clm/src/utils/distributed.py deleted file mode 100644 index 77d7ecf2..00000000 --- a/src/clm/src/utils/distributed.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -from contextlib import contextmanager - -import torch - - -def init_distributed(cuda): - """ - Initializes distributed backend. - - :param cuda: (bool) if True initializes nccl backend, if False initializes - gloo backend - """ - world_size = int(os.environ.get('WORLD_SIZE', 1)) - distributed = (world_size > 1) - if distributed: - backend = 'nccl' if cuda else 'gloo' - torch.distributed.init_process_group(backend=backend, - init_method='env://') - assert torch.distributed.is_initialized() - return distributed - - -def barrier(): - """ - Call torch.distributed.barrier() if distritubed is in use - """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - torch.distributed.barrier() - - -def get_rank(): - """ - Gets distributed rank or returns zero if distributed is not initialized. - """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - rank = torch.distributed.get_rank() - else: - rank = 0 - return rank - - -def get_world_size(): - """ - Gets total number of distributed workers or returns one if distributed is - not initialized. 
- """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - world_size = torch.distributed.get_world_size() - else: - world_size = 1 - return world_size - - -def all_reduce_item(value, op='sum'): - """ - All-reduces single scalar value if distributed is in use - """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - if op == 'sum' or op == 'mean': - dop = torch.distributed.ReduceOp.SUM - elif op == 'min': - dop = torch.distributed.ReduceOp.MIN - elif op == 'max': - dop = torch.distributed.ReduceOp.MAX - elif op == 'product': - dop = torch.distributed.ReduceOp.PRODUCT - else: - raise RuntimeError('Unsupported reduce op') - - backend = torch.distributed.get_backend() - if backend == torch.distributed.Backend.NCCL: - device = torch.device('cuda') - elif backend == torch.distributed.Backend.GLOO: - device = torch.device('cpu') - else: - raise RuntimeError('Unsupported distributed backend') - - tensor = torch.tensor(value, device=device) - torch.distributed.all_reduce(tensor, dop) - if op == 'mean': - tensor /= get_world_size() - ret = tensor.item() - else: - ret = value - return ret - - -def all_reduce_tensor(value, op='sum'): - """ - All-reduces single scalar value if distributed is in use - """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - if op == 'sum' or op == 'mean': - dop = torch.distributed.ReduceOp.SUM - elif op == 'min': - dop = torch.distributed.ReduceOp.MIN - elif op == 'max': - dop = torch.distributed.ReduceOp.MAX - elif op == 'product': - dop = torch.distributed.ReduceOp.PRODUCT - else: - raise RuntimeError('Unsupported reduce op') - - backend = torch.distributed.get_backend() - if backend == torch.distributed.Backend.NCCL: - device = torch.device('cuda') - elif backend == torch.distributed.Backend.GLOO: - device = torch.device('cpu') - else: - raise RuntimeError('Unsupported distributed backend') - - tensor = value - torch.distributed.all_reduce(tensor, dop) - if op == 'mean': - tensor /= get_world_size() - ret = tensor - else: - ret = value - return ret - - -@contextmanager -def sync_workers(): - """ - Yields distributed rank and synchronizes all workers on exit. - """ - rank = get_rank() - yield rank - barrier() diff --git a/src/clm/src/utils/optim/lamb.py b/src/clm/src/utils/optim/lamb.py deleted file mode 100644 index 8bbdf3a2..00000000 --- a/src/clm/src/utils/optim/lamb.py +++ /dev/null @@ -1,251 +0,0 @@ -# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# MIT License -# -# Copyright (c) 2019 cybertronai -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Lamb optimizer.""" - -import torch -from torch.optim import Optimizer - - -class Lamb(Optimizer): - r"""Implements Lamb algorithm. - - It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - adam (bool, optional): always use trust ratio = 1, which turns this into - Adam. Useful for comparison purposes. - - .. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: - https://arxiv.org/abs/1904.00962 - """ - - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, - weight_decay=0, adam=False): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay) - self.adam = adam - super(Lamb, self).__init__(params, defaults) - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. 
- """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - grad = p.grad.data - if grad.is_sparse: - raise RuntimeError('Lamb does not support sparse gradients.') - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] - - state['step'] += 1 - - # Decay the first and second moment running average coefficient - # m_t - exp_avg.mul_(beta1).add_(1 - beta1, grad) - # v_t - exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) - - # Paper v3 does not use debiasing. - # bias_correction1 = 1 - beta1 ** state['step'] - # bias_correction2 = 1 - beta2 ** state['step'] - # Apply bias to lr to avoid broadcast. - step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1 - - weight_norm = p.data.norm(p=2).clamp_(0, 10) - - adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps']) - if group['weight_decay'] != 0: - adam_step.add_(group['weight_decay'], p.data) - - adam_norm = adam_step.norm(p=2) - - if weight_norm == 0.0 or adam_norm == 0.0: - trust_ratio = 1 - else: - trust_ratio = weight_norm / (adam_norm + group['eps']) - - state['weight_norm'] = weight_norm - state['adam_norm'] = adam_norm - state['trust_ratio'] = trust_ratio - if self.adam: - trust_ratio = 1 - - p.data.add_(-step_size * trust_ratio, adam_step) - - return loss - - -@torch.jit.script -def lamb_kernel(param, grad, exp_avg, exp_avg_sq, beta1: float, - beta2: float, step_size: float, eps: float, weight_decay: float): - exp_avg = exp_avg * beta1 + (1 - beta1) * grad - exp_avg_sq = exp_avg_sq * beta2 + (1 - beta2) * (grad * grad) - - adam_step = exp_avg / (exp_avg_sq.sqrt() + eps) - adam_step = adam_step + weight_decay * param - - weight_norm = param.norm(p=2).clamp(0, 10) - adam_norm = adam_step.norm(p=2) - - trust_ratio = weight_norm / (adam_norm + eps) - trust_ratio = (weight_norm == 0.0) * 1.0 + (weight_norm != 0.0) * trust_ratio - trust_ratio = (adam_norm == 0.0) * 1.0 + (adam_norm != 0.0) * trust_ratio - trust_ratio = trust_ratio.float() - - param = param - step_size * trust_ratio * adam_step - return param, exp_avg, exp_avg_sq - - -class JITLamb(Optimizer): - r"""Implements Lamb algorithm. - - It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay (L2 penalty) (default: 0) - adam (bool, optional): always use trust ratio = 1, which turns this into - Adam. Useful for comparison purposes. - - .. 
_Large Batch Optimization for Deep Learning: Training BERT in 76 minutes: - https://arxiv.org/abs/1904.00962 - """ - - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, - weight_decay=0, adam=False): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - defaults = dict(lr=lr, betas=betas, eps=eps, - weight_decay=weight_decay) - self.adam = adam - super().__init__(params, defaults) - - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - loss = closure() - - for group in self.param_groups: - for p in group['params']: - if p.grad is None: - continue - grad = p.grad.data - if grad.is_sparse: - raise RuntimeError('Lamb does not support sparse gradients.') - - state = self.state[p] - - # State initialization - if len(state) == 0: - state['step'] = 0 - # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) - # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) - - exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] - beta1, beta2 = group['betas'] - - state['step'] += 1 - step_size = group['lr'] - - param, exp_avg, exp_avg_sq = lamb_kernel(p.data, grad, exp_avg, - exp_avg_sq, beta1, - beta2, step_size, - group['eps'], - group['weight_decay'], - ) - state['exp_avg'] = exp_avg - state['exp_avg_sq'] = exp_avg_sq - p.data = param - - return loss diff --git a/src/clm/src/utils/optim/schedulers.py b/src/clm/src/utils/optim/schedulers.py deleted file mode 100644 index 35e6d877..00000000 --- a/src/clm/src/utils/optim/schedulers.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Custom learning rate schedulers""" - -import math -import warnings -import torch - -from timm.scheduler import CosineLRScheduler - - -# https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html -class CosineWarmup(torch.optim.lr_scheduler.CosineAnnealingLR): - - def __init__(self, optimizer, T_max, eta_min=0, warmup_step=0, **kwargs): - self.warmup_step = warmup_step - super().__init__(optimizer, T_max - warmup_step, eta_min, *kwargs) - - # Copied from CosineAnnealingLR, but adding warmup and changing self.last_epoch to - # self.last_epoch - self.warmup_step. 
- def get_lr(self): - if not self._get_lr_called_within_step: - warnings.warn("To get the last learning rate computed by the scheduler, " - "please use `get_last_lr()`.", UserWarning) - - if self.last_epoch == self.warmup_step: # also covers the case where both are 0 - return self.base_lrs - elif self.last_epoch < self.warmup_step: - return [base_lr * (self.last_epoch + 1) / self.warmup_step for base_lr in self.base_lrs] - elif (self.last_epoch - self.warmup_step - 1 - self.T_max) % (2 * self.T_max) == 0: - return [group['lr'] + (base_lr - self.eta_min) * - (1 - math.cos(math.pi / self.T_max)) / 2 - for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)] - return [(1 + math.cos(math.pi * (self.last_epoch - self.warmup_step) / self.T_max)) / - (1 + math.cos(math.pi * (self.last_epoch - self.warmup_step - 1) / self.T_max)) * - (group['lr'] - self.eta_min) + self.eta_min - for group in self.optimizer.param_groups] - - _get_closed_form_lr = None - - -def InvSqrt(optimizer, warmup_step): - """ Originally used for Transformer (in Attention is all you need) - """ - - def lr_lambda(step): - # return a multiplier instead of a learning rate - if step == warmup_step: # also covers the case where both are 0 - return 1. - else: - return 1. / (step ** 0.5) if step > warmup_step else (step + 1) / (warmup_step ** 1.5) - - return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) - - -def Constant(optimizer, warmup_step): - - def lr_lambda(step): - if step == warmup_step: # also covers the case where both are 0 - return 1. - else: - return 1. if step > warmup_step else (step + 1) / warmup_step - - return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) - - -class TimmCosineLRScheduler(CosineLRScheduler, torch.optim.lr_scheduler._LRScheduler): - """ Wrap timm.scheduler.CosineLRScheduler so we can call scheduler.step() without passing in epoch. - It supports resuming as well. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._last_epoch = -1 - self.step(epoch=0) - - def step(self, epoch=None): - if epoch is None: - self._last_epoch += 1 - else: - self._last_epoch = epoch - # We call either step or step_update, depending on whether we're using the scheduler every - # epoch or every step. - # Otherwise, lightning will always call step (i.e., meant for each epoch), and if we set - # scheduler interval to "step", then the learning rate update will be wrong. - if self.t_in_epochs: - super().step(epoch=self._last_epoch) - else: - super().step_update(num_updates=self._last_epoch) diff --git a/src/clm/src/utils/optim_groups.py b/src/clm/src/utils/optim_groups.py deleted file mode 100644 index b935a8f3..00000000 --- a/src/clm/src/utils/optim_groups.py +++ /dev/null @@ -1,144 +0,0 @@ -"""Utilities for special optimizer hyperparameters. 
- -group_parameters_for_optimizer is a modification of timm's optimizer logic, which is currently unused -add_optimizer_hooks is an improved version that uses this codebase's _optim dictionary -""" - -import inspect - -import torch.nn as nn - -import hydra - - -def add_optimizer_hooks( - model, - bias_weight_decay=False, - normalization_weight_decay=False, -): - """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with - attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for - normalization parameters if normalization_weight_decay==False - """ - - # Separate out all parameters to those that will and won't experience regularizing weight decay - blacklist_weight_modules = (nn.Embedding, ) - if not normalization_weight_decay: - blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, - # Not compatible with Pytorch 1.8.1 - # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, - nn.GroupNorm, nn.SyncBatchNorm, - nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, - nn.LayerNorm, nn.LocalResponseNorm) - for mn, m in model.named_modules(): - for pn, p in m.named_parameters(): - if (not bias_weight_decay and pn.endswith('bias')) \ - or getattr(p, '_no_weight_decay', False) \ - or isinstance(m, blacklist_weight_modules): - setattr(p, "_optim", {"weight_decay": 0.0}) - - -def group_parameters_for_optimizer( - model, - optimizer_cfg, - bias_weight_decay=False, - normalization_weight_decay=False, -): - """Set weight_decay=0.0 for parameters in model.no_weight_decay, for parameters with - attribute _no_weight_decay==True, for bias parameters if bias_weight_decay==False, for - normalization parameters if normalization_weight_decay==False - """ - # Get the weight decay from the config, or from the default value of the optimizer constructor - # if it's not specified in the config. - if 'weight_decay' in optimizer_cfg: - weight_decay = optimizer_cfg.weight_decay - else: - # https://stackoverflow.com/questions/12627118/get-a-function-arguments-default-value - signature = inspect.signature(hydra.utils.get_class(optimizer_cfg._target_)) - if 'weight_decay' in signature.parameters: - weight_decay = signature.parameters['weight_decay'].default - if weight_decay is inspect.Parameter.empty: - weight_decay = 0.0 - else: - weight_decay = 0.0 - - # If none of the parameters have weight decay anyway, and there are no parameters with special - # optimization params - if weight_decay == 0.0 and not any(hasattr(p, '_optim') for p in model.parameters()): - return model.parameters() - - skip = model.no_weight_decay() if hasattr(model, 'no_weight_decay') else set() - skip_keywords = (model.no_weight_decay_keywords() if hasattr(model, 'no_weight_decay_keywords') - else set()) - - # Adapted from https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L134 - """ - This long function is unfortunately doing something very simple and is being very defensive: - We are separating out all parameters of the model into two buckets: those that will experience - weight decay for regularization and those that won't (biases, and layernorm/embedding weights). - We are then returning the PyTorch optimizer object. 
- """ - - # separate out all parameters to those that will and won't experience regularizing weight decay - decay = set() - no_decay = set() - special = set() - whitelist_weight_modules = (nn.Linear, ) - blacklist_weight_modules = (nn.Embedding, ) - if not normalization_weight_decay: - blacklist_weight_modules += (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, - # Not compatible with Pytorch 1.8.1 - # nn.LazyBatchNorm1d, nn.LazyBatchNorm2d, nn.LazyBatchNorm3d, - nn.GroupNorm, nn.SyncBatchNorm, - nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d, - nn.LayerNorm, nn.LocalResponseNorm) - for mn, m in model.named_modules(): - for pn, p in m.named_parameters(): - fpn = '%s.%s' % (mn, pn) if mn else pn # full param name - if not p.requires_grad: - continue # frozen weights - if hasattr(p, '_optim'): - special.add(fpn) - elif fpn in skip or any(skip_keyword in fpn for skip_keyword in skip_keywords): - no_decay.add(fpn) - elif getattr(p, '_no_weight_decay', False): - no_decay.add(fpn) - elif not bias_weight_decay and pn.endswith('bias'): - no_decay.add(fpn) - elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): - # weights of whitelist modules will be weight decayed - decay.add(fpn) - elif isinstance(m, blacklist_weight_modules): - # weights of blacklist modules will NOT be weight decayed - no_decay.add(fpn) - - param_dict = {pn: p for pn, p in model.named_parameters() if p.requires_grad} - # special case the position embedding parameter in the root GPT module as not decayed - if 'pos_emb' in param_dict: - no_decay.add('pos_emb') - - # In case of parameter sharing, some parameters show up in decay but are not in param_dict.keys() - decay &= param_dict.keys() - decay |= (param_dict.keys() - no_decay - special) - # validate that we considered every parameter - inter_params = decay & no_decay - union_params = decay | no_decay - assert len(inter_params) == 0, f"Parameters {str(inter_params)} made it into both decay/no_decay sets!" - assert len(param_dict.keys() - special - union_params) == 0, f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!" 
- - if weight_decay == 0.0 or not no_decay: - param_groups = [{"params": [param_dict[pn] for pn in sorted(list(no_decay | decay))], - "weight_decay": weight_decay}] - else: - param_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, - ] - # Add parameters with special hyperparameters - # Unique dicts - hps = [dict(s) for s in set(frozenset(param_dict[pn]._optim.items()) for pn in special)] - for hp in hps: - params = [param_dict[pn] for pn in sorted(list(special)) if param_dict[pn]._optim == hp] - param_groups.append({"params": params, **hp}) - - return param_groups diff --git a/src/clm/src/utils/permutations.py b/src/clm/src/utils/permutations.py deleted file mode 100644 index b8f6a0d7..00000000 --- a/src/clm/src/utils/permutations.py +++ /dev/null @@ -1,180 +0,0 @@ -import math -import numpy as np -import torch - - -### Bit reversal permutation - -def bitreversal_po2(n): - m = int(math.log(n)/math.log(2)) - perm = np.arange(n).reshape(n,1) - for i in range(m): - n1 = perm.shape[0]//2 - perm = np.hstack((perm[:n1],perm[n1:])) - return perm.squeeze(0) - -def bitreversal_permutation(n): - m = int(math.ceil(math.log(n)/math.log(2))) - N = 1 << m - perm = bitreversal_po2(N) - return np.extract(perm < n, perm) - -def transpose_permutation(h, w): - indices = np.arange(h*w) - indices = indices.reshape((h, w)) - indices = indices.T - indices = indices.reshape(h*w) - return indices - -def snake_permutation(h, w): - indices = np.arange(h*w) - indices = indices.reshape((h, w)) - indices[1::2, :] = indices[1::2, ::-1] - indices = indices.reshape(h*w) - return indices - -def hilbert_permutation(n): - m = int(math.log2(n)) - assert n == 2**m - inds = decode(list(range(n*n)), 2, m) - ind_x, ind_y = inds.T - indices = np.arange(n*n).reshape((n, n)) - indices = indices[ind_x, ind_y] - return(indices) - -""" Hilbert curve utilities taken from https://github.com/PrincetonLIPS/numpy-hilbert-curve """ -def decode(hilberts, num_dims, num_bits): - ''' Decode an array of Hilbert integers into locations in a hypercube. - This is a vectorized-ish version of the Hilbert curve implementation by John - Skilling as described in: - Skilling, J. (2004, April). Programming the Hilbert curve. In AIP Conference - Proceedings (Vol. 707, No. 1, pp. 381-387). American Institute of Physics. - Params: - ------- - hilberts - An ndarray of Hilbert integers. Must be an integer dtype and - cannot have fewer bits than num_dims * num_bits. - num_dims - The dimensionality of the hypercube. Integer. - num_bits - The number of bits for each dimension. Integer. - Returns: - -------- - The output is an ndarray of unsigned integers with the same shape as hilberts - but with an additional dimension of size num_dims. - ''' - - if num_dims*num_bits > 64: - raise ValueError( - ''' - num_dims=%d and num_bits=%d for %d bits total, which can't be encoded - into a uint64. Are you sure you need that many points on your Hilbert - curve? - ''' % (num_dims, num_bits) - ) - - # Handle the case where we got handed a naked integer. - hilberts = np.atleast_1d(hilberts) - - # Keep around the shape for later. - orig_shape = hilberts.shape - - # Treat each of the hilberts as a sequence of eight uint8. - # This treats all of the inputs as uint64 and makes things uniform. 
- hh_uint8 = np.reshape(hilberts.ravel().astype('>u8').view(np.uint8), (-1, 8)) - - # Turn these lists of uints into lists of bits and then truncate to the size - # we actually need for using Skilling's procedure. - hh_bits = np.unpackbits(hh_uint8, axis=1)[:,-num_dims*num_bits:] - - # Take the sequence of bits and Gray-code it. - gray = binary2gray(hh_bits) - - # There has got to be a better way to do this. - # I could index them differently, but the eventual packbits likes it this way. - gray = np.swapaxes( - np.reshape(gray, (-1, num_bits, num_dims)), - axis1=1, axis2=2, - ) - - # Iterate backwards through the bits. - for bit in range(num_bits-1, -1, -1): - - # Iterate backwards through the dimensions. - for dim in range(num_dims-1, -1, -1): - - # Identify which ones have this bit active. - mask = gray[:,dim,bit] - - # Where this bit is on, invert the 0 dimension for lower bits. - gray[:,0,bit+1:] = np.logical_xor(gray[:,0,bit+1:], mask[:,np.newaxis]) - - # Where the bit is off, exchange the lower bits with the 0 dimension. - to_flip = np.logical_and( - np.logical_not(mask[:,np.newaxis]), - np.logical_xor(gray[:,0,bit+1:], gray[:,dim,bit+1:]) - ) - gray[:,dim,bit+1:] = np.logical_xor(gray[:,dim,bit+1:], to_flip) - gray[:,0,bit+1:] = np.logical_xor(gray[:,0,bit+1:], to_flip) - - # Pad back out to 64 bits. - extra_dims = 64 - num_bits - padded = np.pad(gray, ((0,0), (0,0), (extra_dims,0)), - mode='constant', constant_values=0) - - # Now chop these up into blocks of 8. - locs_chopped = np.reshape(padded[:,:,::-1], (-1, num_dims, 8, 8)) - - # Take those blocks and turn them unto uint8s. - locs_uint8 = np.squeeze(np.packbits(locs_chopped, bitorder='little', axis=3)) - - # Finally, treat these as uint64s. - flat_locs = locs_uint8.view(np.uint64) - - # Return them in the expected shape. - return np.reshape(flat_locs, (*orig_shape, num_dims)) - -def right_shift(binary, k=1, axis=-1): - ''' Right shift an array of binary values. - Parameters: - ----------- - binary: An ndarray of binary values. - k: The number of bits to shift. Default 1. - axis: The axis along which to shift. Default -1. - Returns: - -------- - Returns an ndarray with zero prepended and the ends truncated, along - whatever axis was specified. -''' - - # If we're shifting the whole thing, just return zeros. - if binary.shape[axis] <= k: - return np.zeros_like(binary) - - # Determine the padding pattern. - padding = [(0,0)] * len(binary.shape) - padding[axis] = (k,0) - - # Determine the slicing pattern to eliminate just the last one. - slicing = [slice(None)] * len(binary.shape) - slicing[axis] = slice(None, -k) - - shifted = np.pad(binary[tuple(slicing)], padding, - mode='constant', constant_values=0) - - return shifted - -def binary2gray(binary, axis=-1): - ''' Convert an array of binary values into Gray codes. - This uses the classic X ^ (X >> 1) trick to compute the Gray code. - Parameters: - ----------- - binary: An ndarray of binary values. - axis: The axis along which to compute the gray code. Default=-1. - Returns: - -------- - Returns an ndarray of Gray codes. - ''' - shifted = right_shift(binary, axis=axis) - - # Do the X ^ (X >> 1) trick. 
- gray = np.logical_xor(binary, shifted) - - return gray diff --git a/src/clm/src/utils/registry.py b/src/clm/src/utils/registry.py deleted file mode 100644 index 7943bdcc..00000000 --- a/src/clm/src/utils/registry.py +++ /dev/null @@ -1,53 +0,0 @@ -optimizer = { - "adam": "torch.optim.Adam", - "adamw": "torch.optim.AdamW", - "rmsprop": "torch.optim.RMSprop", - "sgd": "torch.optim.SGD", - "lamb": "src.utils.optim.lamb.JITLamb", -} - -scheduler = { - "constant": "transformers.get_constant_schedule", - "plateau": "torch.optim.lr_scheduler.ReduceLROnPlateau", - "step": "torch.optim.lr_scheduler.StepLR", - "multistep": "torch.optim.lr_scheduler.MultiStepLR", - "cosine": "torch.optim.lr_scheduler.CosineAnnealingLR", - "constant_warmup": "transformers.get_constant_schedule_with_warmup", - "linear_warmup": "transformers.get_linear_schedule_with_warmup", - "cosine_warmup": "transformers.get_cosine_schedule_with_warmup", - "cosine_warmup_timm": "src.utils.optim.schedulers.TimmCosineLRScheduler", -} - -model = { - # Backbones from this repo - "model": "src.models.sequence.SequenceModel", - "lm": "src.models.sequence.long_conv_lm.ConvLMHeadModel", - "lm_simple": "src.models.sequence.simple_lm.SimpleLMHeadModel", - "vit_b_16": "src.models.baselines.vit_all.vit_base_patch16_224", -} - -layer = { - "id": "src.models.sequence.base.SequenceIdentity", - "ff": "src.models.sequence.ff.FF", - "mha": "src.models.sequence.mha.MultiheadAttention", - "s4d": "src.models.sequence.ssm.s4d.S4D", - "s4_simple": "src.models.sequence.ssm.s4_simple.SimpleS4Wrapper", - "long-conv": "src.models.sequence.long_conv.LongConv", - "h3": "src.models.sequence.h3.H3", - "h3-conv": "src.models.sequence.h3_conv.H3Conv", - "hyena": "src.models.sequence.hyena.HyenaOperator", - "hyena-filter": "src.models.sequence.hyena.HyenaFilter", - "vit": "src.models.sequence.mha.VitAttention", -} - -callbacks = { - "timer": "src.callbacks.timer.Timer", - "params": "src.callbacks.params.ParamsLog", - "learning_rate_monitor": "pytorch_lightning.callbacks.LearningRateMonitor", - "model_checkpoint": "pytorch_lightning.callbacks.ModelCheckpoint", - "early_stopping": "pytorch_lightning.callbacks.EarlyStopping", - "swa": "pytorch_lightning.callbacks.StochasticWeightAveraging", - "rich_model_summary": "pytorch_lightning.callbacks.RichModelSummary", - "rich_progress_bar": "pytorch_lightning.callbacks.RichProgressBar", - "progressive_resizing": "src.callbacks.progressive_resizing.ProgressiveResizing", -} diff --git a/src/clm/src/utils/train.py b/src/clm/src/utils/train.py deleted file mode 100644 index 12e5dbb4..00000000 --- a/src/clm/src/utils/train.py +++ /dev/null @@ -1,156 +0,0 @@ -""" Utils for the training loop. 
Copied from https://github.com/HazyResearch/transformers/blob/master/src/utils/utils.py """ -import logging -import os -import warnings -from typing import List, Sequence - -import torch.nn as nn -import pytorch_lightning as pl -import rich.syntax -import rich.tree -from omegaconf import DictConfig, OmegaConf -from pytorch_lightning.utilities import rank_zero_only - -from clm.src.utils.config import omegaconf_filter_keys - - -# Copied from https://docs.python.org/3/howto/logging-cookbook.html#using-a-context-manager-for-selective-logging -class LoggingContext: - def __init__(self, logger, level=None, handler=None, close=True): - self.logger = logger - self.level = level - self.handler = handler - self.close = close - - def __enter__(self): - if self.level is not None: - self.old_level = self.logger.level - self.logger.setLevel(self.level) - if self.handler: - self.logger.addHandler(self.handler) - - def __exit__(self, et, ev, tb): - if self.level is not None: - self.logger.setLevel(self.old_level) - if self.handler: - self.logger.removeHandler(self.handler) - if self.handler and self.close: - self.handler.close() - # implicit return of None => don't swallow exceptions - - -def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: - """Initializes multi-GPU-friendly python logger.""" - - logger = logging.getLogger(name) - logger.setLevel(level) - - # this ensures all logging levels get marked with the rank zero decorator - # otherwise logs would get multiplied for each GPU process in multi-GPU setup - for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): - setattr(logger, level, rank_zero_only(getattr(logger, level))) - - return logger - - -def process_config(config: DictConfig) -> DictConfig: # TODO because of filter_keys, this is no longer in place - """A couple of optional utilities, controlled by main config file: - - disabling warnings - - easier access to debug mode - - forcing debug friendly configuration - Modifies DictConfig in place. - Args: - config (DictConfig): Configuration composed by Hydra. - """ - log = get_logger() - - # Filter out keys that were used just for interpolation - # config = dictconfig_filter_keys(config, lambda k: not k.startswith('__')) - config = omegaconf_filter_keys(config, lambda k: not k.startswith('__')) - - # enable adding new keys to config - OmegaConf.set_struct(config, False) - - # disable python warnings if - if config.get("ignore_warnings"): - log.info("Disabling python warnings! ") - warnings.filterwarnings("ignore") - - if config.get("debug"): - log.info("Running in debug mode! ") - config.trainer.fast_dev_run = True - - # force debugger friendly configuration - log.info("Forcing debugger friendly configuration! ") - # Debuggers don't like GPUs or multiprocessing - if config.trainer.get("gpus"): - config.trainer.gpus = 0 - if config.loader.get("pin_memory"): - config.loader.pin_memory = False - if config.loader.get("num_workers"): - config.loader.num_workers = 0 - - # disable adding new keys to config - # OmegaConf.set_struct(config, True) # [21-09-17 AG] I need this for .pop(_name_) pattern among other things - - return config - -@rank_zero_only -def print_config( - config: DictConfig, - resolve: bool = True, - save_cfg=True, -) -> None: - """Prints content of DictConfig using Rich library and its tree structure. - Args: - config (DictConfig): Configuration composed by Hydra. - fields (Sequence[str], optional): Determines which main fields from config will - be printed and in what order. 
- resolve (bool, optional): Whether to resolve reference fields of DictConfig. - """ - - style = "dim" - tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) - - fields = config.keys() - for field in fields: - branch = tree.add(field, style=style, guide_style=style) - - config_section = config.get(field) - branch_content = str(config_section) - if isinstance(config_section, DictConfig): - branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) - - branch.add(rich.syntax.Syntax(branch_content, "yaml")) - - rich.print(tree) - - if save_cfg: - with open("config_tree.txt", "w") as fp: - rich.print(tree, file=fp) - -def log_optimizer(logger, optimizer, keys): - """ Log values of particular keys from the optimizer's param groups """ - keys = sorted(keys) - for i, g in enumerate(optimizer.param_groups): - group_hps = {k: g.get(k, None) for k in keys} - logger.info(' | '.join([ - f"Optimizer group {i}", - f"{len(g['params'])} tensors", - ] + [f"{k} {v}" for k, v in group_hps.items()])) - -class OptimModule(nn.Module): - """ Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters """ - - def register(self, name, tensor, lr=None, wd=0.0): - """Register a tensor with a configurable learning rate and 0 weight decay""" - - if lr == 0.0: - self.register_buffer(name, tensor) - else: - self.register_parameter(name, nn.Parameter(tensor)) - - optim = {} - if lr is not None: optim["lr"] = lr - if wd is not None: optim["weight_decay"] = wd - setattr(getattr(self, name), "_optim", optim) \ No newline at end of file From 8272469ade69337b883727f4a2febae088a9e3d1 Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Sat, 13 Dec 2025 15:09:41 -0500 Subject: [PATCH 03/21] run pre-commit fixing --- src/clm/commands/sample_molecules_RNN.py | 19 +- src/clm/commands/train_models_RNN.py | 18 +- src/clm/loggers.py | 2 +- src/clm/models.py | 384 +++++++++++----------- src/clm/module_library/README.md | 2 +- src/clm/module_library/sequence_module.py | 20 +- 6 files changed, 236 insertions(+), 209 deletions(-) diff --git a/src/clm/commands/sample_molecules_RNN.py b/src/clm/commands/sample_molecules_RNN.py index 19b2b831..5e5ecfd3 100644 --- a/src/clm/commands/sample_molecules_RNN.py +++ b/src/clm/commands/sample_molecules_RNN.py @@ -6,7 +6,12 @@ from tqdm import tqdm from clm.datasets import Vocabulary, SelfiesVocabulary -from clm.models import RNN, ConditionalRNN, Transformer, StructuredStateSpaceSequenceModel#, H3Model, H3ConvModel, HyenaModel +from clm.models import ( + RNN, + ConditionalRNN, + Transformer, + StructuredStateSpaceSequenceModel, +) # , H3Model, H3ConvModel, HyenaModel from clm.functions import load_dataset, write_to_csv_file logger = logging.getLogger(__name__) @@ -122,7 +127,7 @@ def sample_molecules_RNN( vocab = Vocabulary(vocab_file=vocab_file) heldout_dataset = None - + if rnn_type == "S4": assert ( heldout_file is not None @@ -133,7 +138,7 @@ def sample_molecules_RNN( vocab_file=vocab_file, ) model = StructuredStateSpaceSequenceModel( - vocabulary=vocab, # heldout_dataset.vocabulary + vocabulary=vocab, # heldout_dataset.vocabulary model_dim=embedding_size, state_dim=64, n_layers=n_layers, @@ -196,7 +201,7 @@ def sample_molecules_RNN( # dropout=dropout, # max_len=250, # inner_factor=1, - # ) + # ) elif rnn_type == "Transformer": assert ( @@ -216,7 +221,7 @@ def sample_molecules_RNN( exp_factor=4, bias=True, ) - + else: if conditional: assert ( @@ -268,7 +273,9 @@ def sample_molecules_RNN( descriptors = None if heldout_dataset is not 
None: # Use modulo to cycle through heldout_dataset - descriptor_indices = [(i + j) % len(heldout_dataset) for j in range(n_sequences)] + descriptor_indices = [ + (i + j) % len(heldout_dataset) for j in range(n_sequences) + ] descriptors = torch.stack( [heldout_dataset[idx][1] for idx in descriptor_indices] ) diff --git a/src/clm/commands/train_models_RNN.py b/src/clm/commands/train_models_RNN.py index 3ad3cdd2..0eb550b2 100644 --- a/src/clm/commands/train_models_RNN.py +++ b/src/clm/commands/train_models_RNN.py @@ -6,11 +6,17 @@ from torch.utils.data import DataLoader from tqdm import tqdm from rdkit import rdBase -from clm.models import RNN, ConditionalRNN, Transformer, StructuredStateSpaceSequenceModel#, H3Model, H3ConvModel, HyenaModel +from clm.models import ( + RNN, + ConditionalRNN, + Transformer, + StructuredStateSpaceSequenceModel, +) # , H3Model, H3ConvModel, HyenaModel from clm.loggers import EarlyStopping, track_loss, print_update from clm.functions import write_smiles, load_dataset import warnings + warnings.filterwarnings("ignore", category=FutureWarning) # suppress Chem.MolFromSmiles error output @@ -184,14 +190,14 @@ def train_models_RNN( if rnn_type == "S4": model = StructuredStateSpaceSequenceModel( - vocabulary=dataset.vocabulary, - model_dim=embedding_size, + vocabulary=dataset.vocabulary, + model_dim=embedding_size, state_dim=64, n_layers=n_layers, n_ssm=1, dropout=dropout, ) - + # elif rnn_type == "H3": # model = H3Model( # vocabulary=dataset.vocabulary, @@ -214,7 +220,7 @@ def train_models_RNN( # max_len=250, # use_fast_fftconv=False, # ) - + # elif rnn_type == "Hyena": # model = HyenaModel( # vocabulary=dataset.vocabulary, @@ -238,7 +244,7 @@ def train_models_RNN( exp_factor=4, bias=True, ) - + else: if conditional: model = ConditionalRNN( diff --git a/src/clm/loggers.py b/src/clm/loggers.py index 785c7f13..ebd34c3f 100644 --- a/src/clm/loggers.py +++ b/src/clm/loggers.py @@ -49,7 +49,7 @@ def __call__(self, val_loss, model, output_file, step_idx): if self.best_loss is not None: print(f"Best model (loss={self.best_loss:.4f}) already saved.") return - + # do nothing if early stopping is disabled if self.patience > 0: if self.best_loss is None: diff --git a/src/clm/models.py b/src/clm/models.py index c1c93464..865a0778 100644 --- a/src/clm/models.py +++ b/src/clm/models.py @@ -5,6 +5,7 @@ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence from einops import rearrange + # from clm.src.models.sequence.h3 import H3 # from clm.src.models.sequence.h3_conv import H3Conv # from clm.src.models.sequence.hyena_components import HyenaOperator @@ -12,7 +13,7 @@ from .module_library.sequence_model import SequenceModel -# class H3Model(nn.Module): +# class H3Model(nn.Module): # def __init__( # self, # vocabulary, @@ -25,21 +26,21 @@ # use_fast_fftconv=False, # ): # super(H3Model, self).__init__() - + # if H3 is None: # raise ImportError( # "H3 modules not found. Make sure src.models.sequence.h3 is available." 
# ) - + # # detect device # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - + # # vocabulary # self.vocabulary = vocabulary # self.vocabulary_size = len(self.vocabulary) # self.padding_idx = self.vocabulary.dictionary[""] # padding_t = torch.tensor(self.padding_idx).to(self.device) - + # # hyperparams # self.n_layers = n_layers # self.d_model = d_model @@ -48,12 +49,12 @@ # self.dropout = dropout # self.max_len = max_len # self.use_fast_fftconv = use_fast_fftconv - + # # model components # self.embedding = nn.Embedding( # self.vocabulary_size, self.d_model, padding_idx=padding_t # ) - + # # H3 layers # self.layers = nn.ModuleList([ # H3( @@ -67,68 +68,68 @@ # ) # for i in range(self.n_layers) # ]) - + # # dropout and output # self.norm = nn.LayerNorm(self.d_model) # self.dropout_layer = nn.Dropout(dropout) # self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) - + # # loss function (ignoring padding) # self.loss_fn = nn.CrossEntropyLoss( # ignore_index=self.padding_idx, reduction="none" # ) - + # # move to GPU # if torch.cuda.is_available(): # self.cuda() - + # def forward(self, x, inference_params=None): # batch_size, seq_len = x.size() - + # # Embed the input # x = self.embedding(x) # (batch_size, seq_len, d_model) - + # # Pass through H3 layers # for layer in self.layers: # x = layer(x, inference_params=inference_params) # if self.dropout > 0: # x = self.dropout_layer(x) - + # # Normalize and project to vocabulary # x = self.norm(x) # logits = self.output_projection(x) # (batch_size, seq_len, vocab_size) - + # return logits - + # def loss(self, batch): # if len(batch) == 3: # padded, lengths, _ = batch # else: # padded, lengths = batch - + # padded = padded.to(self.device) - + # # Handle different input formats # if padded.dim() == 2: # if padded.shape[0] > padded.shape[1]: # padded = padded.transpose(0, 1) - + # # Forward pass # logits = self(padded) - + # # Calculate loss # targets = padded[:, 1:] # logits = logits[:, :-1, :] - + # loss = 0.0 # actual_len = min(logits.shape[1], targets.shape[1]) - + # for char_idx in range(actual_len): # loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) - + # return loss.mean() - + # def sample( # self, # *, @@ -140,14 +141,14 @@ # ): # if max_len is None: # max_len = self.max_len - + # self.eval() - + # # Get start/stop tokens # start_token = self.vocabulary.dictionary["SOS"] # stop_token = self.vocabulary.dictionary["EOS"] # pad_token = self.vocabulary.dictionary[""] - + # # Create inference params # class InferenceParams: # def __init__(self, max_seqlen, batch_size): @@ -155,64 +156,64 @@ # self.max_batch_size = batch_size # self.sequence_len_offset = 0 # self.key_value_memory_dict = {} - + # inference_params = InferenceParams(max_len, n_sequences) - + # # Initialize with start tokens - keep only current token for recurrent stepping # current_token = torch.full( # (n_sequences, 1), start_token, dtype=torch.long, device=self.device # ) - + # loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) - + # finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) # log_probs = torch.zeros(n_sequences, device=self.device) # sequences = [] - + # with torch.no_grad(): # for step in range(max_len): # # Process only the current token in recurrent mode # logits = self(current_token, inference_params=inference_params) # logits = logits[:, -1, :] # Get last (and only) position - + # logits = torch.clamp(logits, min=-1e4, max=1e4) # prob = F.softmax(logits, dim=-1) - + # if 
torch.isnan(prob).any() or torch.isinf(prob).any(): # break - + # outputs = torch.multinomial(prob, num_samples=1) # sequences.append(outputs) - + # log_prob = F.log_softmax(logits, dim=-1) # losses = loss_fn(log_prob, outputs.squeeze(1)) # losses[finished] = 0 # log_probs += losses - + # # Update current token for next step (don't accumulate) # current_token = outputs # inference_params.sequence_len_offset += 1 - + # finished = finished | (outputs.squeeze(1) == stop_token) # if finished.all(): # break - + # seqs = torch.cat(sequences, 1) if sequences else torch.full( # (n_sequences, 1), start_token, dtype=torch.long, device=self.device # ) - + # if return_smiles: # outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] # else: # outputs = sequences - + # if return_losses: # return outputs, log_probs.detach().cpu().numpy() # else: # return outputs -# class H3ConvModel(nn.Module): +# class H3ConvModel(nn.Module): # def __init__( # self, # vocabulary, @@ -224,21 +225,21 @@ # use_fast_fftconv=False, # ): # super(H3ConvModel, self).__init__() - + # if H3Conv is None: # raise ImportError( # "H3Conv modules not found. Make sure src.models.sequence.h3_conv is available." # ) - + # # detect device # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - + # # vocabulary # self.vocabulary = vocabulary # self.vocabulary_size = len(self.vocabulary) # self.padding_idx = self.vocabulary.dictionary[""] # padding_t = torch.tensor(self.padding_idx).to(self.device) - + # # hyperparams # self.n_layers = n_layers # self.d_model = d_model @@ -246,12 +247,12 @@ # self.dropout = dropout # self.max_len = max_len # self.use_fast_fftconv = use_fast_fftconv - + # # model components # self.embedding = nn.Embedding( # self.vocabulary_size, self.d_model, padding_idx=padding_t # ) - + # # H3Conv layers # self.layers = nn.ModuleList([ # H3Conv( @@ -264,67 +265,67 @@ # ) # for i in range(self.n_layers) # ]) - + # # dropout and output # self.norm = nn.LayerNorm(self.d_model) # self.dropout_layer = nn.Dropout(dropout) # self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) - + # # loss function (ignoring padding) # self.loss_fn = nn.CrossEntropyLoss( # ignore_index=self.padding_idx, reduction="none" # ) - + # # move to GPU # if torch.cuda.is_available(): # self.cuda() - + # def forward(self, x, inference_params=None): # batch_size, seq_len = x.size() - + # # Embed the input # x = self.embedding(x) # (batch_size, seq_len, d_model) - + # # Pass through H3Conv layers # for layer in self.layers: # x = layer(x, inference_params=inference_params) # if self.dropout > 0: # x = self.dropout_layer(x) - + # # Normalize and project to vocabulary # x = self.norm(x) # logits = self.output_projection(x) - + # return logits - + # def loss(self, batch): # if len(batch) == 3: # padded, lengths, _ = batch # else: # padded, lengths = batch - + # padded = padded.to(self.device) - + # # Handle different input formats # if padded.dim() == 2: # if padded.shape[0] > padded.shape[1]: # padded = padded.transpose(0, 1) - + # # Forward pass # logits = self(padded) - + # # Calculate loss # targets = padded[:, 1:] # logits = logits[:, :-1, :] - + # loss = 0.0 # actual_len = min(logits.shape[1], targets.shape[1]) - + # for char_idx in range(actual_len): # loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) - + # return loss.mean() - + # def sample( # self, # *, @@ -336,65 +337,65 @@ # ): # if max_len is None: # max_len = self.max_len - + # self.eval() - + # start_token = 
self.vocabulary.dictionary["SOS"] # stop_token = self.vocabulary.dictionary["EOS"] # pad_token = self.vocabulary.dictionary[""] - + # # H3Conv doesn't use stateful inference, process full sequence each time # inputs = torch.full( # (n_sequences, 1), start_token, dtype=torch.long, device=self.device # ) - + # loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) - + # finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) # log_probs = torch.zeros(n_sequences, device=self.device) # sequences = [] - + # with torch.no_grad(): # for step in range(max_len): # logits = self(inputs) # logits = logits[:, -1, :] - + # logits = torch.clamp(logits, min=-1e4, max=1e4) # prob = F.softmax(logits, dim=-1) - + # if torch.isnan(prob).any() or torch.isinf(prob).any(): # break - + # outputs = torch.multinomial(prob, num_samples=1) # sequences.append(outputs) - + # log_prob = F.log_softmax(logits, dim=-1) # losses = loss_fn(log_prob, outputs.squeeze(1)) # losses[finished] = 0 # log_probs += losses - + # inputs = torch.cat([inputs, outputs], dim=1) - + # finished = finished | (outputs.squeeze(1) == stop_token) # if finished.all(): # break - + # seqs = torch.cat(sequences, 1) if sequences else torch.full( # (n_sequences, 1), start_token, dtype=torch.long, device=self.device # ) - + # if return_smiles: # outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] # else: # outputs = sequences - + # if return_losses: # return outputs, log_probs.detach().cpu().numpy() # else: # return outputs -# class HyenaModel(nn.Module): +# class HyenaModel(nn.Module): # def __init__( # self, # vocabulary, @@ -408,16 +409,16 @@ # inner_factor=1, # ): # super(HyenaModel, self).__init__() - + # # detect device # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - + # # vocabulary # self.vocabulary = vocabulary # self.vocabulary_size = len(self.vocabulary) # self.padding_idx = self.vocabulary.dictionary[""] # padding_t = torch.tensor(self.padding_idx).to(self.device) - + # # hyperparams # self.n_layers = n_layers # self.d_model = d_model @@ -427,12 +428,12 @@ # self.dropout = dropout # self.max_len = max_len # self.inner_factor = inner_factor - + # # model components # self.embedding = nn.Embedding( # self.vocabulary_size, self.d_model, padding_idx=padding_t # ) - + # # Hyena layers # self.layers = nn.ModuleList([ # HyenaOperator( @@ -446,27 +447,27 @@ # ) # for i in range(self.n_layers) # ]) - + # # dropout and output # self.norm = nn.LayerNorm(self.d_model) # self.dropout_layer = nn.Dropout(dropout) # self.output_projection = nn.Linear(self.d_model, self.vocabulary_size) - + # # loss function (ignoring padding) # self.loss_fn = nn.CrossEntropyLoss( # ignore_index=self.padding_idx, reduction="none" # ) - + # # move to GPU # if torch.cuda.is_available(): # self.cuda() - + # def forward(self, x): # batch_size, seq_len = x.size() - + # # Embed the input # x = self.embedding(x) # (batch_size, seq_len, d_model) - + # # Pass through Hyena layers # for layer in self.layers: # residual = x @@ -474,41 +475,41 @@ # x = x + residual # Residual connection # if self.dropout > 0: # x = self.dropout_layer(x) - + # # Normalize and project to vocabulary # x = self.norm(x) # logits = self.output_projection(x) - + # return logits - + # def loss(self, batch): # if len(batch) == 3: # padded, lengths, _ = batch # else: # padded, lengths = batch - + # padded = padded.to(self.device) - + # # Handle different input formats # if padded.dim() == 2: # if padded.shape[0] > padded.shape[1]: # 
padded = padded.transpose(0, 1) - + # # Forward pass # logits = self(padded) - + # # Calculate loss # targets = padded[:, 1:] # logits = logits[:, :-1, :] - + # loss = 0.0 # actual_len = min(logits.shape[1], targets.shape[1]) - + # for char_idx in range(actual_len): # loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) - + # return loss.mean() - + # def sample( # self, # *, @@ -520,59 +521,59 @@ # ): # if max_len is None: # max_len = self.max_len - + # self.eval() - + # start_token = self.vocabulary.dictionary["SOS"] # stop_token = self.vocabulary.dictionary["EOS"] # pad_token = self.vocabulary.dictionary[""] - + # # Initialize with start tokens # inputs = torch.full( # (n_sequences, 1), start_token, dtype=torch.long, device=self.device # ) - + # loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) - + # finished = torch.zeros(n_sequences, dtype=torch.bool, device=self.device) # log_probs = torch.zeros(n_sequences, device=self.device) # sequences = [] - + # with torch.no_grad(): # for step in range(max_len): # # Hyena processes full sequence each time (stateless) # logits = self(inputs) # logits = logits[:, -1, :] - + # logits = torch.clamp(logits, min=-1e4, max=1e4) # prob = F.softmax(logits, dim=-1) - + # if torch.isnan(prob).any() or torch.isinf(prob).any(): # break - + # outputs = torch.multinomial(prob, num_samples=1) # sequences.append(outputs) - + # log_prob = F.log_softmax(logits, dim=-1) # losses = loss_fn(log_prob, outputs.squeeze(1)) # losses[finished] = 0 # log_probs += losses - + # inputs = torch.cat([inputs, outputs], dim=1) - + # finished = finished | (outputs.squeeze(1) == stop_token) # if finished.all(): # break - + # seqs = torch.cat(sequences, 1) if sequences else torch.full( # (n_sequences, 1), start_token, dtype=torch.long, device=self.device # ) - + # if return_smiles: # outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] # else: # outputs = sequences - + # if return_losses: # return outputs, log_probs.detach().cpu().numpy() # else: @@ -591,16 +592,16 @@ def __init__( max_len=250, ): super(StructuredStateSpaceSequenceModel, self).__init__() - + # detect device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - + # vocabulary self.vocabulary = vocabulary self.vocabulary_size = len(self.vocabulary) self.padding_idx = self.vocabulary.dictionary[""] padding_t = torch.tensor(self.padding_idx).to(self.device) - + # hyperparams self.model_dim = model_dim self.state_dim = state_dim @@ -608,7 +609,7 @@ def __init__( self.n_ssm = n_ssm self.dropout = dropout self.max_len = max_len - + # S4 layer configuration self.layer_config = [ { @@ -624,15 +625,15 @@ def __init__( {"_name_": "ff"}, ] self.pool_config = {"_name_": "pool", "stride": 1, "expand": None} - + # model components self.embedding = nn.Embedding( self.vocabulary_size, self.model_dim, padding_idx=padding_t ) - + # Import SequenceModel from your module library from .module_library.sequence_model import SequenceModel - + self.model = SequenceModel( d_model=self.model_dim, n_layers=self.n_layers, @@ -641,57 +642,57 @@ def __init__( layer=self.layer_config, pool=self.pool_config, ) - + self.output_embedding = nn.Linear(self.model_dim, self.vocabulary_size) self.recurrent_state = None - + # loss function (ignoring padding) self.loss_fn = nn.CrossEntropyLoss( ignore_index=self.padding_idx, reduction="none" ) - + # move to GPU if torch.cuda.is_available(): self.cuda() - + def forward(self, x): batch_size, seq_len = x.size() - + # Embed the input x = 
self.embedding(x) # (batch_size, seq_len, model_dim) - + # Pass through S4 model (without state in training mode) x, _ = self.model(x, state=None) - + # Project to vocabulary logits = self.output_embedding(x) # (batch_size, seq_len, vocab_size) - + return logits - + def reset_state(self, batch_size, device=None): if device is None: device = self.device self.recurrent_state = self.model.default_state(batch_size, device=device) - + def recurrent_step(self, x_t): if x_t.dim() == 1: x_t = x_t.unsqueeze(1) - + x_t = self.embedding(x_t).squeeze(1) # (batch_size, model_dim) x_t, state = self.model.step(x_t, state=self.recurrent_state) self.recurrent_state = state x_t = self.output_embedding(x_t) # (batch_size, vocab_size) - + return x_t - + def loss(self, batch): if len(batch) == 3: padded, lengths, _ = batch else: padded, lengths = batch - + padded = padded.to(self.device) - + # Handle different input formats # RNN format is typically (seq_len, batch_size) # S4/Transformer format is typically (batch_size, seq_len) @@ -699,30 +700,30 @@ def loss(self, batch): if padded.shape[0] > padded.shape[1]: # Likely (seq_len, batch_size), transpose to (batch_size, seq_len) padded = padded.transpose(0, 1) - + batch_size = padded.shape[0] seq_len = padded.shape[1] - + # Don't use recurrent state during training - use full convolution mode self.recurrent_state = None - + # Forward pass logits = self(padded) # (batch_size, seq_len, vocab_size) - + # Calculate loss # Shift targets: predict next token targets = padded[:, 1:] # (batch_size, seq_len-1) logits = logits[:, :-1, :] # (batch_size, seq_len-1, vocab_size) - + # Reshape for loss calculation loss = 0.0 actual_len = min(logits.shape[1], targets.shape[1]) - + for char_idx in range(actual_len): loss += self.loss_fn(logits[:, char_idx, :], targets[:, char_idx]) - + return loss.mean() - + def sample( self, *, @@ -734,84 +735,83 @@ def sample( ): if max_len is None: max_len = self.max_len - + # IMPORTANT: Set model to eval mode before sampling self.eval() - + # Setup for recurrent mode for module in self.model.modules(): if hasattr(module, "setup_step"): module.setup_step() - + # Reset state self.reset_state(n_sequences, device=self.device) - + # Get start/stop tokens start_token = self.vocabulary.dictionary["SOS"] stop_token = self.vocabulary.dictionary["EOS"] pad_token = self.vocabulary.dictionary[""] - + # Create start token tensor - inputs = ( - torch.empty(n_sequences) - .fill_(start_token) - .long() - .to(self.device) - ) - + inputs = torch.empty(n_sequences).fill_(start_token).long().to(self.device) + # Setup loss function loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) - + # Sample sequences finished = torch.zeros(n_sequences).byte().to(self.device) log_probs = torch.zeros(n_sequences).to(self.device) sequences = [] - + with torch.no_grad(): # Also add no_grad for efficiency for step in range(max_len): # Get logits for current input logits = self.recurrent_step(inputs) - + # Clamp logits to prevent inf/nan logits = torch.clamp(logits, min=-1e4, max=1e4) - + # Sample from distribution prob = F.softmax(logits, dim=-1) - + # Check for invalid values if torch.isnan(prob).any() or torch.isinf(prob).any(): break - + outputs = torch.multinomial(prob, num_samples=1).squeeze(1) - + sequences.append(outputs.view(-1, 1)) - + # Calculate NLL log_prob = F.log_softmax(logits, dim=-1) losses = loss_fn(log_prob, outputs) - + # Zero losses if we are finished sampling losses[finished.bool()] = 0 log_probs += losses - + # Update inputs for next step inputs = 
outputs - + # Track whether sampling is done for all molecules finished = torch.ge(finished + (outputs == stop_token), 1) if torch.prod(finished) == 1: break - + # Concatenate sequences and decode - seqs = torch.cat(sequences, 1) if sequences else torch.empty( - n_sequences, 1, dtype=torch.long - ).fill_(start_token).to(self.device) - + seqs = ( + torch.cat(sequences, 1) + if sequences + else torch.empty(n_sequences, 1, dtype=torch.long) + .fill_(start_token) + .to(self.device) + ) + if return_smiles: outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] else: outputs = sequences - + # Optionally return losses if return_losses: return outputs, log_probs.detach().cpu().numpy() @@ -1231,21 +1231,23 @@ def loss(self, batch): padded, lengths, _ = batch else: padded, lengths = batch - + padded = padded.to(self.device) - + # Get actual sequence length from batch actual_seq_len = padded.shape[1] - + decoded = self(padded) # batch_size x seq_len x vocab_size - + loss = 0.0 targets = padded[:, 1:] # batch_size x (seq_len-1) - + # Loop only up to actual decoded sequence length minus 1 - for char_idx in range(min(actual_seq_len - 1, decoded.shape[1], targets.shape[1])): + for char_idx in range( + min(actual_seq_len - 1, decoded.shape[1], targets.shape[1]) + ): loss += self.loss_fn(decoded[:, char_idx, :], targets[:, char_idx]) - + return loss.mean() def sample( @@ -1253,7 +1255,7 @@ def sample( ): # Reset recurrent state before sampling self.reset_state(n_sequences, device=self.device) - + # get start/stop tokens start_token = self.vocabulary.dictionary["SOS"] stop_token = self.vocabulary.dictionary["EOS"] @@ -1280,11 +1282,11 @@ def sample( # Clamp logits to prevent inf/nan logits = torch.clamp(logits, min=-1e4, max=1e4) prob = F.softmax(logits, dim=-1) - + # Check for invalid values and skip if found if torch.isnan(prob).any() or torch.isinf(prob).any(): break - + outputs = torch.multinomial(prob, num_samples=1) # append to growing sequence inputs = torch.cat((inputs, outputs), dim=1) @@ -1301,7 +1303,13 @@ def sample( break # concatenate sequences and decode - seqs = torch.cat(sequences, 1) if sequences else torch.empty(n_sequences, 1, dtype=torch.long).fill_(start_token).to(self.device) + seqs = ( + torch.cat(sequences, 1) + if sequences + else torch.empty(n_sequences, 1, dtype=torch.long) + .fill_(start_token) + .to(self.device) + ) if return_smiles: outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] else: diff --git a/src/clm/module_library/README.md b/src/clm/module_library/README.md index 42b03fa3..ec077c0a 100644 --- a/src/clm/module_library/README.md +++ b/src/clm/module_library/README.md @@ -1 +1 @@ -These modules are heavily borrowed from the [original codebase for S4](https://github.com/HazyResearch/state-spaces) and empower the S4 model. Visit the original repository for more information. \ No newline at end of file +These modules are heavily borrowed from the [original codebase for S4](https://github.com/HazyResearch/state-spaces) and empower the S4 model. Visit the original repository for more information. diff --git a/src/clm/module_library/sequence_module.py b/src/clm/module_library/sequence_module.py index 4f8a4ffa..7daa2842 100644 --- a/src/clm/module_library/sequence_module.py +++ b/src/clm/module_library/sequence_module.py @@ -1,6 +1,7 @@ from torch import nn import functools + class SequenceModule(nn.Module): """Abstract sequence model class. 
All models must adhere to this interface @@ -40,7 +41,9 @@ def d_output(self): It is used by the rest of the pipeline (e.g. model backbone, decoder) to track the internal shapes of the full model. """ if getattr(self, "_d_output", None) is None: - raise NotImplementedError("SequenceModule instantiation must specify d_output for decoder") + raise NotImplementedError( + "SequenceModule instantiation must specify d_output for decoder" + ) return self._d_output @d_output.setter @@ -69,10 +72,9 @@ def state_to_tensor(self): @property def d_state(self): - """ Returns dimension of output of self.state_to_tensor """ + """Returns dimension of output of self.state_to_tensor""" return None - def default_state(self, *batch_shape, device=None): """Create initial state for a batch of inputs.""" @@ -87,6 +89,7 @@ def step(self, x, state=None, **kwargs): """ raise NotImplementedError + def TransposedModule(module): """Wrap a SequenceModule class to accept transposed parameter, handle state, absorb kwargs""" # https://stackoverflow.com/a/65470430/1980685 @@ -97,15 +100,19 @@ def __init__(self, *args, transposed=False, **kwargs): self.transposed = transposed def forward(self, x, state=None, **kwargs): - if self.transposed: x = x.transpose(-1, -2) - x, next_state = super().forward(x, state) # Don't use kwarg because nn.LSTM + if self.transposed: + x = x.transpose(-1, -2) + x, next_state = super().forward(x, state) # Don't use kwarg because nn.LSTM next_state = None if state is None else next_state - if self.transposed: x = x.transpose(-1,-2) + if self.transposed: + x = x.transpose(-1, -2) return x, next_state + # https://stackoverflow.com/questions/5352781/how-to-set-class-names-dynamically # TransposedModule.__name__ = module.__name__ # functools wraps is better solution return TransposedModule + @TransposedModule class SequenceIdentity(SequenceModule): """Simple SequenceModule for testing purposes""" @@ -120,7 +127,6 @@ def __init__(self, d_model, dropout=0.0, **kwargs): self.d_model = d_model self.d_output = d_model - def forward(self, x, state=None): return x, state From 5a2e7005930d9b156f0464a5ef498182e1d67a65 Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Sat, 13 Dec 2025 15:15:01 -0500 Subject: [PATCH 04/21] resolve some flake8 linting errors --- src/clm/module_library/s4.py | 6 +++--- src/clm/module_library/sequence_model.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/clm/module_library/s4.py b/src/clm/module_library/s4.py index cb462e26..bed7b06c 100644 --- a/src/clm/module_library/s4.py +++ b/src/clm/module_library/s4.py @@ -4,6 +4,9 @@ import opt_einsum as oe from einops import rearrange +from .kernel import SSKernel +from .util_modules import LinearActivation, Activation, DropoutNd + optimized = True if optimized: @@ -11,9 +14,6 @@ else: contract = torch.einsum -from .kernel import SSKernel -from .util_modules import LinearActivation, Activation, DropoutNd - class S4(nn.Module): def __init__( diff --git a/src/clm/module_library/sequence_model.py b/src/clm/module_library/sequence_model.py index 03949f98..803746a1 100644 --- a/src/clm/module_library/sequence_model.py +++ b/src/clm/module_library/sequence_model.py @@ -102,12 +102,12 @@ def __init__( # Instantiate layers _layers = [] d = d_model - for l, layer in enumerate(layers): + for layer_idx, layer in enumerate(layers): # Pool at the end of every n_repeat blocks - pool_cfg = pool if (l + 1) % n_repeat == 0 else None + pool_cfg = pool if (layer_idx + 1) % n_repeat == 0 else None block = SequenceResidualBlock( 
d, - l + 1, + layer_idx + 1, prenorm=prenorm, dropout=dropout, tie_dropout=tie_dropout, From 7716a7fff781604bd399fc24e11a5a66e2a537bf Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Sat, 13 Dec 2025 15:19:19 -0500 Subject: [PATCH 05/21] resolve some flake8 linting errors --- src/clm/module_library/krylov.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/clm/module_library/krylov.py b/src/clm/module_library/krylov.py index 9d2b1e4b..8122edbe 100644 --- a/src/clm/module_library/krylov.py +++ b/src/clm/module_library/krylov.py @@ -182,11 +182,11 @@ def krylov_toeplitz_(L, A, b, c=None): A = F.pad(A, (0, N)) done = L == 1 while not done: - l = x.shape[0] + length = x.shape[0] # Save memory on last iteration - if L - l <= l: + if L - length <= length: done = True - _x = x[: L - l] + _x = x[: L - length] else: _x = x Af = torch.fft.rfft(A, n=2 * N, dim=-1) From 8a49bf83707ad046fbe205158b4e568c86a1bd6a Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Sat, 13 Dec 2025 15:25:18 -0500 Subject: [PATCH 06/21] resolve some flake8 linting errors --- src/clm/module_library/krylov.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/clm/module_library/krylov.py b/src/clm/module_library/krylov.py index 8122edbe..baa785cf 100644 --- a/src/clm/module_library/krylov.py +++ b/src/clm/module_library/krylov.py @@ -100,24 +100,24 @@ def power(L, A, v=None): v: (..., N, L) """ - I = torch.eye(A.shape[-1]).to(A) # , dtype=A.dtype, device=A.device) + identity = torch.eye(A.shape[-1]).to(A) # Changed from I to identity powers = [A] - l = 1 + power_of_2 = 1 # Changed from l to power_of_2 while True: if L % 2 == 1: - I = powers[-1] @ I + identity = powers[-1] @ identity # Changed from I to identity L //= 2 if L == 0: break - l *= 2 + power_of_2 *= 2 # Changed from l to power_of_2 if v is None: powers = [powers[-1] @ powers[-1]] else: powers.append(powers[-1] @ powers[-1]) if v is None: - return I + return identity # Changed from I to identity # Invariants: # powers[-1] := A^l @@ -130,16 +130,16 @@ def power(L, A, v=None): # Take care of edge case for non-po2 arrays # Note that this initial step is a no-op for the case of power of 2 (l == L) - k = v.size(-1) - l - v_ = powers.pop() @ v[..., l:] - v = v[..., :l] + k = v.size(-1) - power_of_2 # Changed from l to power_of_2 + v_ = powers.pop() @ v[..., power_of_2:] # Changed from l to power_of_2 + v = v[..., :power_of_2] # Changed from l to power_of_2 v[..., :k] = v[..., :k] + v_ # Handle reduction for power of 2 while v.size(-1) > 1: v = rearrange(v, "... (z l) -> ... 
z l", z=2) v = v[..., 0, :] + powers.pop() @ v[..., 1, :] - return I, v.squeeze(-1) + return identity, v.squeeze(-1) # Changed from I to identity def krylov_toeplitz(L, A, b, c=None): From 4b436082e222bec22053f09dc30e7ebe23983f31 Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Sat, 13 Dec 2025 15:29:23 -0500 Subject: [PATCH 07/21] resolve some flake8 linting errors --- src/clm/module_library/krylov.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/clm/module_library/krylov.py b/src/clm/module_library/krylov.py index baa785cf..035e00d0 100644 --- a/src/clm/module_library/krylov.py +++ b/src/clm/module_library/krylov.py @@ -67,10 +67,10 @@ def krylov(L, A, b, c=None, return_power=False): _L //= 2 # Save memory on last iteration - l = x.shape[-1] - if L - l <= l: + current_length = x.shape[-1] + if L - current_length <= current_length: done = True - _x = x[..., : L - l] + _x = x[..., : L - current_length] else: _x = x @@ -184,7 +184,7 @@ def krylov_toeplitz_(L, A, b, c=None): while not done: length = x.shape[0] # Save memory on last iteration - if L - length <= length: + if L - length <= length: done = True _x = x[: L - length] else: From 10177e1a244a7ab1cc8a377d1a12da40c70babe2 Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Sat, 13 Dec 2025 15:35:59 -0500 Subject: [PATCH 08/21] resolve some flake8 linting errors --- src/clm/module_library/kernel.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/clm/module_library/kernel.py b/src/clm/module_library/kernel.py index 6f9425cf..437e3750 100644 --- a/src/clm/module_library/kernel.py +++ b/src/clm/module_library/kernel.py @@ -481,10 +481,10 @@ def _setup_step(self, mode="dense"): else: # self.C represents C_tilde dA_L = power(self.L.item(), self.dA) - I = torch.eye(self.dA.size(-1)).to(dA_L) + identity_matrix = torch.eye(self.dA. size(-1)).to(dA_L) dC = torch.linalg.solve( - I - dA_L.transpose(-1, -2), + identity_matrix - dA_L.transpose(-1, -2), C.unsqueeze(-1), ).squeeze(-1) self.dC = dC @@ -593,7 +593,7 @@ def __init__( measure_args={}, **kernel_args, ): - """State Space Kernel which computes the convolution kernel $\\bar{K}$ + """State Space Kernel which computes the convolution kernel $\\\\bar{K}$ H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config. N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much. 
From 837a0e406fa1c9fe376a283878478b7c94056268 Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Wed, 17 Dec 2025 21:43:34 -0500 Subject: [PATCH 09/21] add setup.sh for environment installation after removing src/clm/module_library and use code from s4dd git repo directly --- requirements.txt | 2 + setup.sh | 27 + src/clm/module_library/README.md | 1 - src/clm/module_library/__init__.py | 0 src/clm/module_library/cauchy.py | 18 - src/clm/module_library/dplr.py | 129 ---- src/clm/module_library/ff.py | 65 -- src/clm/module_library/hippo.py | 277 ------- src/clm/module_library/kernel.py | 718 ------------------ src/clm/module_library/krylov.py | 209 ----- src/clm/module_library/pool.py | 62 -- src/clm/module_library/residual.py | 23 - src/clm/module_library/s4.py | 290 ------- src/clm/module_library/sequence_model.py | 204 ----- src/clm/module_library/sequence_module.py | 137 ---- .../module_library/sequence_residual_block.py | 148 ---- src/clm/module_library/toeplitz.py | 156 ---- src/clm/module_library/util_modules.py | 318 -------- 18 files changed, 29 insertions(+), 2755 deletions(-) create mode 100644 setup.sh delete mode 100644 src/clm/module_library/README.md delete mode 100644 src/clm/module_library/__init__.py delete mode 100644 src/clm/module_library/cauchy.py delete mode 100644 src/clm/module_library/dplr.py delete mode 100644 src/clm/module_library/ff.py delete mode 100644 src/clm/module_library/hippo.py delete mode 100644 src/clm/module_library/kernel.py delete mode 100644 src/clm/module_library/krylov.py delete mode 100644 src/clm/module_library/pool.py delete mode 100644 src/clm/module_library/residual.py delete mode 100644 src/clm/module_library/s4.py delete mode 100644 src/clm/module_library/sequence_model.py delete mode 100644 src/clm/module_library/sequence_module.py delete mode 100644 src/clm/module_library/sequence_residual_block.py delete mode 100644 src/clm/module_library/toeplitz.py delete mode 100644 src/clm/module_library/util_modules.py diff --git a/requirements.txt b/requirements.txt index 7657e54b..89725aae 100644 --- a/requirements.txt +++ b/requirements.txt @@ -128,3 +128,5 @@ virtualenv==20.26.2 wrapt==1.16.0 yte==1.5.4 zipp==3.19.0 +einops==0.6.0 +opt_einsum==3.3.0 \ No newline at end of file diff --git a/setup.sh b/setup.sh new file mode 100644 index 00000000..90063e3c --- /dev/null +++ b/setup.sh @@ -0,0 +1,27 @@ +#!/bin/bash +# setup.sh + +set -e # Exit on any error + +# Initialize conda for bash +eval "$(conda shell.bash hook)" + +# Create and activate environment +conda create --name clm python=3.10 pip -y +conda activate clm + +# Install main requirements +conda env update --file environment.yml + +# Install s4dd from source +if [ ! -d "s4-for-de-novo-drug-design" ]; then + git clone https://github.com/molML/s4-for-de-novo-drug-design.git +fi +cd s4-for-de-novo-drug-design +pip install -e . +cd .. + +# Install CLM package +pip install -e . --no-deps + +echo "Environment setup complete!" \ No newline at end of file diff --git a/src/clm/module_library/README.md b/src/clm/module_library/README.md deleted file mode 100644 index ec077c0a..00000000 --- a/src/clm/module_library/README.md +++ /dev/null @@ -1 +0,0 @@ -These modules are heavily borrowed from the [original codebase for S4](https://github.com/HazyResearch/state-spaces) and empower the S4 model. Visit the original repository for more information. 
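A note on the setup script added in this patch: it is presumably intended to be run from the repository root with a working conda installation on PATH (for example, bash setup.sh). It assumes an environment.yml file exists alongside requirements.txt, clones the s4-for-de-novo-drug-design repository into the working directory if it is not already present, and installs both that package and the CLM package itself in editable mode.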
diff --git a/src/clm/module_library/__init__.py b/src/clm/module_library/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/clm/module_library/cauchy.py b/src/clm/module_library/cauchy.py deleted file mode 100644 index e774a534..00000000 --- a/src/clm/module_library/cauchy.py +++ /dev/null @@ -1,18 +0,0 @@ -import torch - - -_conj = lambda x: torch.cat([x, x.conj()], dim=-1) - - -def cauchy_naive(v, z, w, conj=True): - """ - v: (..., N) - z: (..., L) - w: (..., N) - returns: (..., L) \sum v/(z-w) - """ - if conj: - v = _conj(v) - w = _conj(w) - cauchy_matrix = v.unsqueeze(-1) / (z.unsqueeze(-2) - w.unsqueeze(-1)) # (... N L) - return torch.sum(cauchy_matrix, dim=-2) diff --git a/src/clm/module_library/dplr.py b/src/clm/module_library/dplr.py deleted file mode 100644 index 03ee00a6..00000000 --- a/src/clm/module_library/dplr.py +++ /dev/null @@ -1,129 +0,0 @@ -import math -import torch -from einops import repeat -from . import hippo - - -def dplr( - scaling="linear", - N=64, - rank=1, - H=1, - dtype=torch.float, - real_scale=1.0, - imag_scale=1.0, - random_real=False, - random_imag=False, - normalize=False, - diagonal=True, - random_B=False, -): - assert dtype == torch.float or dtype == torch.double - dtype = torch.cfloat if dtype == torch.float else torch.cdouble - - pi = torch.tensor(math.pi) - if random_real: - real_part = torch.rand(H, N // 2) - else: - real_part = 0.5 * torch.ones(H, N // 2) - if random_imag: - imag_part = N // 2 * torch.rand(H, N // 2) - else: - imag_part = repeat(torch.arange(N // 2), "n -> h n", h=H) - - real_part = real_scale * real_part - if scaling == "random": - imag_part = torch.randn(H, N // 2) - elif scaling == "real": - imag_part = 0 * imag_part - real_part = 1 + repeat(torch.arange(N // 2), "n -> h n", h=H) - elif scaling in ["linear", "lin"]: - imag_part = pi * imag_part - elif scaling in [ - "inverse", - "inv", - ]: # Based on asymptotics of the default HiPPO matrix - imag_part = 1 / pi * N * (N / (1 + 2 * imag_part) - 1) - elif scaling in ["inverse2", "inv2"]: - imag_part = 1 / pi * N * (N / (1 + imag_part) - 1) - elif scaling in ["quadratic", "quad"]: - imag_part = 1 / pi * (1 + 2 * imag_part) ** 2 - elif scaling in ["legs", "hippo"]: - w, _, _, _ = hippo.nplr("legsd", N) - imag_part = w.imag - - else: - raise NotImplementedError - imag_part = imag_scale * imag_part - w = -real_part + 1j * imag_part - - # Initialize B - if random_B: - B = torch.randn(H, N // 2, dtype=dtype) - else: - B = torch.ones(H, N // 2, dtype=dtype) - - if normalize: - norm = ( - -B / w - ) # (H, N) # Result if you integrate the kernel with constant 1 function - zeta = 2 * torch.sum( - torch.abs(norm) ** 2, dim=-1, keepdim=True - ) # Variance with a random C vector - B = B / zeta**0.5 - - P = torch.randn(rank, H, N // 2, dtype=dtype) - if diagonal: - P = P * 0.0 - V = torch.eye(N, dtype=dtype)[:, : N // 2] # Only used in testing - V = repeat(V, "n m -> h n m", h=H) - - return w, P, B, V - - -def ssm(measure, N, R, H, **ssm_args): - """Dispatcher to create single SSM initialization - - N: state size - R: rank (for DPLR parameterization) - H: number of independent SSM copies - """ - - if measure == "dplr": - w, P, B, V = dplr(N=N, rank=R, H=H, **ssm_args) - elif measure.startswith("diag"): - args = measure.split("-") - assert args[0] == "diag" and len(args) > 1 - scaling = args[1] - w, P, B, V = dplr(scaling=scaling, N=N, rank=R, H=H, diagonal=True, **ssm_args) - else: - w, P, B, V = hippo.nplr(measure, N, R, **ssm_args) - w = repeat(w, "n -> s n", s=H) - 
P = repeat(P, "r n -> r s n", s=H) - B = repeat(B, "n -> s n", s=H) - V = repeat(V, "n m -> s n m", s=H) - return w, P, B, V - - -combinations = { - "hippo": ["legs", "fourier"], - "diag": ["diag-inv", "diag-lin"], - "all": ["legs", "fourier", "diag-inv", "diag-lin"], -} - - -def combination(measures, N, R, S, **ssm_args): - if isinstance(measures, str): - measures = combinations[measures] if measures in combinations else [measures] - - assert ( - S % len(measures) == 0 - ), f"{S} independent trainable SSM copies must be multiple of {len(measures)} different measures" - w, P, B, V = zip( - *[ssm(measure, N, R, S // len(measures), **ssm_args) for measure in measures] - ) - w = torch.cat(w, dim=0) # (S N) - P = torch.cat(P, dim=1) # (R S N) - B = torch.cat(B, dim=0) # (S N) - V = torch.cat(V, dim=0) # (S N N) - return w, P, B, V diff --git a/src/clm/module_library/ff.py b/src/clm/module_library/ff.py deleted file mode 100644 index 16341a0b..00000000 --- a/src/clm/module_library/ff.py +++ /dev/null @@ -1,65 +0,0 @@ -from functools import partial -from torch import nn -from .sequence_module import SequenceModule -from .util_modules import LinearActivation, DropoutNd - - -class FF(SequenceModule): - def __init__( - self, - d_input, - # expand=2, # changed the default value from 2 to 4 - expand=4, # changed the default value from 2 to 4 - d_output=None, - transposed=False, - activation="gelu", - initializer=None, - dropout=0.0, - tie_dropout=False, - ): - super().__init__() - self.d_output = d_input if d_output is None else d_output - self.transposed = transposed - d_inner = expand * d_input - - linear1 = LinearActivation( - d_input, - d_inner, - transposed=transposed, - activation=activation, - initializer=initializer, - activate=True, - ) - dropout_cls = ( - partial(DropoutNd, transposed=self.transposed) - if tie_dropout - else nn.Dropout - ) - # dropout_cls = nn.Dropout2d if self.transposed else nn.Dropout - drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() - - linear2 = LinearActivation( - d_inner, - self.d_output, - transposed=transposed, - activation=None, - initializer=initializer, - activate=False, - ) - - self.ff = nn.Sequential( - linear1, - drop, - linear2, - ) - - def forward(self, x, *args, **kwargs): - return self.ff(x), None - - def step(self, x, state, **kwargs): - # x: [batch, d_input] - if self.transposed: - # expects: [batch, d_input, seq_len] - return self.ff(x.unsqueeze(-1)).squeeze(-1), state - else: - return self.ff(x), state diff --git a/src/clm/module_library/hippo.py b/src/clm/module_library/hippo.py deleted file mode 100644 index 9bd1daba..00000000 --- a/src/clm/module_library/hippo.py +++ /dev/null @@ -1,277 +0,0 @@ -import torch -import numpy as np -from scipy import special as ss -from einops import rearrange -from opt_einsum import contract - - -def embed_c2r(A): - A = rearrange(A, "... m n -> ... 
m () n ()") - A = np.pad(A, ((0, 0), (0, 1), (0, 0), (0, 1))) + np.pad( - A, ((0, 0), (1, 0), (0, 0), (1, 0)) - ) - return rearrange(A, "m x n y -> (m x) (n y)") - - -# TODO take in 'torch' option to return torch instead of numpy, and converts the shape of B from (N, 1) to (N) -def transition(measure, N, **measure_args): - """A, B transition matrices for different measures - - measure: the type of measure - legt - Legendre (translated) - legs - Legendre (scaled) - glagt - generalized Laguerre (translated) - lagt, tlagt - previous versions of (tilted) Laguerre with slightly different normalization - """ - # Laguerre (translated) - if measure == "lagt": - b = measure_args.get("beta", 1.0) - A = np.eye(N) / 2 - np.tril(np.ones((N, N))) - B = b * np.ones((N, 1)) - # Generalized Laguerre - # alpha 0, beta small is most stable (limits to the 'lagt' measure) - # alpha 0, beta 1 has transition matrix A = [lower triangular 1] - elif measure == "glagt": - alpha = measure_args.get("alpha", 0.0) - beta = measure_args.get("beta", 0.01) - A = -np.eye(N) * (1 + beta) / 2 - np.tril(np.ones((N, N)), -1) - B = ss.binom(alpha + np.arange(N), np.arange(N))[:, None] - - L = np.exp( - 0.5 * (ss.gammaln(np.arange(N) + alpha + 1) - ss.gammaln(np.arange(N) + 1)) - ) - A = (1.0 / L[:, None]) * A * L[None, :] - B = ( - (1.0 / L[:, None]) - * B - * np.exp(-0.5 * ss.gammaln(1 - alpha)) - * beta ** ((1 - alpha) / 2) - ) - # Legendre (translated) - elif measure == "legt": - Q = np.arange(N, dtype=np.float64) - R = (2 * Q + 1) ** 0.5 - j, i = np.meshgrid(Q, Q) - A = R[:, None] * np.where(i < j, (-1.0) ** (i - j), 1) * R[None, :] - B = R[:, None] - A = -A - - # Halve again for timescale correctness - A *= 0.5 - B *= 0.5 - # LMU: equivalent to LegT up to normalization - elif measure == "lmu": - Q = np.arange(N, dtype=np.float64) - R = (2 * Q + 1)[:, None] # / theta - j, i = np.meshgrid(Q, Q) - A = np.where(i < j, -1, (-1.0) ** (i - j + 1)) * R - B = (-1.0) ** Q[:, None] * R - # Legendre (scaled) - elif measure == "legs": - q = np.arange(N, dtype=np.float64) - col, row = np.meshgrid(q, q) - r = 2 * q + 1 - M = -(np.where(row >= col, r, 0) - np.diag(q)) - T = np.sqrt(np.diag(2 * q + 1)) - A = T @ M @ np.linalg.inv(T) - B = np.diag(T)[:, None] - B = ( - B.copy() - ) # Otherwise "UserWarning: given NumPY array is not writeable..." after torch.as_tensor(B) - elif measure == "legsd": - q = np.arange(N, dtype=np.float64) - col, row = np.meshgrid(q, q) - r = 2 * q + 1 - M = -(np.where(row >= col, r, 0) - np.diag(q)) - T = np.sqrt(np.diag(2 * q + 1)) - A = T @ M @ np.linalg.inv(T) - B = np.diag(T)[:, None] - B = ( - B.copy() - ) # Otherwise "UserWarning: given NumPY array is not writeable..." 
after torch.as_tensor(B) - A += 0.5 * B * B[None, :, 0] - B = B / 2.0 - elif measure in ["fourier_diag", "foud"]: - freqs = np.arange(N // 2) - d = np.stack([freqs, np.zeros(N // 2)], axis=-1).reshape(-1)[:-1] - A = 2 * np.pi * (-np.diag(d, 1) + np.diag(d, -1)) - A = A - 0.5 * np.eye(N) - B = np.zeros(N) - B[0::2] = 2**0.5 - B[0] = 1 - B = B[:, None] - elif measure in ["fourier", "fout"]: - freqs = np.arange(N // 2) - d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:] - A = np.pi * (-np.diag(d, 1) + np.diag(d, -1)) - B = np.zeros(N) - B[0::2] = 2**0.5 - B[0] = 1 - - # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case - A = A - B[:, None] * B[None, :] - B = B[:, None] - elif measure == "fourier_decay": - freqs = np.arange(N // 2) - d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:] - A = np.pi * (-np.diag(d, 1) + np.diag(d, -1)) - B = np.zeros(N) - B[0::2] = 2**0.5 - B[0] = 1 - - # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case - A = A - 0.5 * B[:, None] * B[None, :] - B = 0.5 * B[:, None] - elif measure == "fourier2": # Double everything: orthonormal on [0, 1] - freqs = 2 * np.arange(N // 2) - d = np.stack([np.zeros(N // 2), freqs], axis=-1).reshape(-1)[1:] - A = np.pi * (-np.diag(d, 1) + np.diag(d, -1)) - B = np.zeros(N) - B[0::2] = 2**0.5 - B[0] = 1 - - # Subtract off rank correction - this corresponds to the other endpoint u(t-1) in this case - A = A - B[:, None] * B[None, :] * 2 - B = B[:, None] * 2 - elif measure == "random": - A = np.random.randn(N, N) / N - B = np.random.randn(N, 1) - elif measure == "diagonal": - A = -np.diag(np.exp(np.random.randn(N))) - B = np.random.randn(N, 1) - else: - raise NotImplementedError - - return A, B - - -def rank_correction(measure, N, rank=1, dtype=torch.float): - """Return low-rank matrix L such that A + L is normal""" - - if measure == "legs": - assert rank >= 1 - P = torch.sqrt(0.5 + torch.arange(N, dtype=dtype)).unsqueeze(0) # (1 N) - elif measure == "legt": - assert rank >= 2 - P = torch.sqrt(1 + 2 * torch.arange(N, dtype=dtype)) # (N) - P0 = P.clone() - P0[0::2] = 0.0 - P1 = P.clone() - P1[1::2] = 0.0 - P = torch.stack([P0, P1], dim=0) # (2 N) - P *= 2 ** ( - -0.5 - ) # Halve the rank correct just like the original matrix was halved - elif measure == "lagt": - assert rank >= 1 - P = 0.5**0.5 * torch.ones(1, N, dtype=dtype) - elif measure in ["fourier", "fout"]: - P = torch.zeros(N) - P[0::2] = 2**0.5 - P[0] = 1 - P = P.unsqueeze(0) - elif measure == "fourier_decay": - P = torch.zeros(N) - P[0::2] = 2**0.5 - P[0] = 1 - P = P.unsqueeze(0) - P = P / 2**0.5 - elif measure == "fourier2": - P = torch.zeros(N) - P[0::2] = 2**0.5 - P[0] = 1 - P = 2**0.5 * P.unsqueeze(0) - elif measure in ["fourier_diag", "foud", "legsd"]: - P = torch.zeros(1, N, dtype=dtype) - else: - raise NotImplementedError - - d = P.size(0) - if rank > d: - P = torch.cat([P, torch.zeros(rank - d, N, dtype=dtype)], dim=0) # (rank N) - return P - - -def initial_C(measure, N, dtype=torch.float): - """Return C that captures the other endpoint in the HiPPO approximation""" - - if measure == "legt": - C = (torch.arange(N, dtype=dtype) * 2 + 1) ** 0.5 * (-1) ** torch.arange(N) - elif measure == "fourier": - C = torch.zeros(N) - C[0::2] = 2**0.5 - C[0] = 1 - else: - C = torch.zeros(N, dtype=dtype) # (N) - - return C - - -def nplr(measure, N, rank=1, dtype=torch.float, diagonalize_precision=True): - """Return w, p, q, V, B such that - (w - p q^*, B) is unitarily equivalent to the 
original HiPPO A, B by the matrix V - i.e. A = V[w - p q^*]V^*, B = V B - """ - assert dtype == torch.float or dtype == torch.double - cdtype = torch.cfloat if dtype == torch.float else torch.cdouble - - A, B = transition(measure, N) - A = torch.as_tensor(A, dtype=dtype) # (N, N) - B = torch.as_tensor(B, dtype=dtype)[:, 0] # (N,) - - P = rank_correction(measure, N, rank=rank, dtype=dtype) # (r N) - AP = A + torch.sum(P.unsqueeze(-2) * P.unsqueeze(-1), dim=-3) - - # We require AP to be nearly skew-symmetric - _A = AP + AP.transpose(-1, -2) - if ( - err := torch.sum((_A - _A[0, 0] * torch.eye(N)) ** 2) / N - ) > 1e-5: # if not torch.allclose(_A - _A[0,0]*torch.eye(N), torch.zeros(N, N), atol=1e-5): - print("WARNING: HiPPO matrix not skew symmetric", err) - - # Take advantage of identity + skew-symmetric form to calculate real and imaginary parts separately - # Imaginary part can use eigh instead of eig - w_re = torch.mean(torch.diagonal(AP), -1, keepdim=True) - - # Diagonalize in double precision - if diagonalize_precision: - AP = AP.to(torch.double) - # w, V = torch.linalg.eig(AP) # (..., N) (..., N, N) - w_im, V = torch.linalg.eigh(AP * -1j) # (..., N) (..., N, N) - if diagonalize_precision: - w_im, V = w_im.to(cdtype), V.to(cdtype) - w = w_re + 1j * w_im - # Check: V w V^{-1} = A - # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) - - # Only keep half of each conjugate pair - _, idx = torch.sort(w.imag) - w_sorted = w[idx] - V_sorted = V[:, idx] - - # There is an edge case when eigenvalues can be 0, which requires some machinery to handle - # We use a huge hack here: Assume only one pair is 0, and that it is the first row/column of A (only happens in Fourier case) - V = V_sorted[:, : N // 2] - w = w_sorted[: N // 2] - assert w[-2].abs() > 1e-4, "Only 1 zero eigenvalue allowed in diagonal part of A" - if w[-1].abs() < 1e-4: - V[:, -1] = 0.0 - V[0, -1] = 2**-0.5 - V[1, -1] = 2**-0.5 * 1j - - _AP = V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2) - if (err := torch.sum((2 * _AP.real - AP) ** 2) / N) > 1e-5: - print( - "Warning: Diagonalization of A matrix not numerically precise - error", err - ) - # print("check", V @ torch.diag_embed(w) @ V.conj().transpose(-1, -2)) - - V_inv = V.conj().transpose(-1, -2) - - # C = initial_C(measure, N, dtype=dtype) - B = contract("ij, j -> i", V_inv, B.to(V)) # V^* B - # C = contract('ij, j -> i', V_inv, C.to(V)) # V^* C - P = contract("ij, ...j -> ...i", V_inv, P.to(V)) # V^* P - - # return w, P, B, C, V - return w, P, B, V diff --git a/src/clm/module_library/kernel.py b/src/clm/module_library/kernel.py deleted file mode 100644 index 437e3750..00000000 --- a/src/clm/module_library/kernel.py +++ /dev/null @@ -1,718 +0,0 @@ -import math -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -from einops import rearrange, repeat -from opt_einsum import contract, contract_expression - -from . 
import dplr -from .krylov import krylov, power -from .cauchy import cauchy_naive - - -_conj = lambda x: torch.cat([x, x.conj()], dim=-1) -_c2r = torch.view_as_real -_r2c = torch.view_as_complex - -if tuple(map(int, torch.__version__.split(".")[:2])) >= (1, 10): - _resolve_conj = lambda x: x.conj().resolve_conj() -else: - _resolve_conj = lambda x: x.conj() - - -class OptimModule(nn.Module): - """Interface for Module that allows registering buffers/parameters with configurable optimizer hyperparameters""" - - def register(self, name, tensor, lr=None): - """Register a tensor with a configurable learning rate and 0 weight decay""" - - if lr == 0.0: - self.register_buffer(name, tensor) - else: - self.register_parameter(name, nn.Parameter(tensor)) - - optim = {"weight_decay": 0.0} - if lr is not None: - optim["lr"] = lr - setattr(getattr(self, name), "_optim", optim) - - -class SSKernelNPLR(OptimModule): - """ - Stores a representation of and computes the SSKernel function K_L(dt, A, B, C) corresponding to a discretized state space, where A is Normal + Low Rank (NPLR) - """ - - @torch.no_grad() - def _setup_C(self, L): - """Construct C~ from C - - Two modes are supported: go directly to length L if self.L is 1, or length is doubled - """ - - if self.L.item() == 0: - double_length = False - elif L > self.L.item(): # 2*int(self.L) == L: - double_length = True - L = self.L.item() # Convenience for the math below - else: - return - - C = _r2c(self.C) - dA, _ = self._setup_state() - dA_L = power(L, dA) - # Multiply C by I - dA_L - C_ = _conj(C) - prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_) - if double_length: - prod = -prod # Multiply by I + dA_L instead - C_ = C_ - prod - C_ = C_[..., : self.N] # Take conjugate pairs again - self.C.copy_(_c2r(C_)) - - self.L = 2 * self.L if double_length else self.L + L # Preserve type/device - - def _omega(self, L, dtype, device, cache=True): - """Calculate (and cache) FFT nodes and their "unprocessed" version with the bilinear transform - This should be called everytime the internal length self.L changes""" - - # Use cached if available - if cache and hasattr(self, "omega") and self.omega.size(-1) == L // 2 + 1: - return self.omega, self.z - - omega = torch.tensor( - np.exp(-2j * np.pi / (L)), dtype=dtype, device=device - ) # \omega_{2L} - omega = omega ** torch.arange(0, L // 2 + 1, device=device) - z = 2 * (1 - omega) / (1 + omega) - - # Cache if necessary - if cache: - self.omega = omega - self.z = z - return omega, z - - def __init__( - self, - w, - P, - B, - C, - log_dt, - L=None, # starting/maximum length of kernel - lr=None, - verbose=False, - keops=False, - real_type="exp", # ['none' | 'exp' | 'relu' | sigmoid'] - real_tolerance=1e-3, - bandlimit=None, - ): - """ - L: Maximum length; this module computes an SSM kernel of length L - A is represented by diag(w) - PP^* - w: (S, N) diagonal part - P: (R, S, N) low-rank part - - B: (S, N) - C: (C, H, N) - dt: (H) timescale per feature - lr: [dict | float | None] hook to set lr of special parameters (A, B, dt) - - Dimensions: - N (or d_state): state size - H (or d_model): total SSM copies - S (or n_ssm): number of trainable copies of (A, B, dt); must divide H - R (or rank): rank of low-rank part - C (or channels): system is 1-dim to C-dim - - The forward pass of this Module returns a tensor of shape (C, H, L) - - Note: tensor shape N here denotes half the true state size, because of conjugate symmetry - """ - - super().__init__() - self.verbose = verbose - self.keops = keops - 
self.bandlimit = bandlimit - self.real_type = real_type - self.real_tolerance = real_tolerance - - # Rank of low-rank correction - self.rank = P.shape[-3] - assert w.size(-1) == P.size(-1) == B.size(-1) == C.size(-1) - self.H = log_dt.size(-1) - self.N = w.size(-1) - - # Check different SSM inits - assert w.size(-2) == P.size(-2) == B.size(-2) # n_ssm - assert self.H % w.size(0) == 0 - self.n_ssm = w.size(0) - self.repeat = self.H // w.size( - 0 - ) # Each trainable SSM needs to be duplicated this many times - - # Broadcast everything to correct shapes - C = C.expand(torch.broadcast_shapes(C.shape, (1, self.H, self.N))) # (C, H, N) - B = B.unsqueeze(0) # (1, 1, N) - - # Register parameters - self.C = nn.Parameter(_c2r(_resolve_conj(C))) - if lr is None or isinstance(lr, float): - lr_dict = {} - else: - lr_dict, lr = lr, None - self.register("log_dt", log_dt, lr_dict.get("dt", lr)) - self.register("B", _c2r(B), lr_dict.get("B", lr)) - self.register("P", _c2r(P), lr_dict.get("A", lr)) - self.register("inv_w_real", self._w_init(w.real), lr_dict.get("A", lr)) - self.register("w_imag", w.imag, lr_dict.get("A", lr)) - - self.l_max = L - self.register_buffer("L", torch.tensor(0)) # Internal length - - def _w_init(self, w_real): - w_real = torch.clamp(w_real, max=-self.real_tolerance) - if self.real_type == "none": - return -w_real - elif self.real_type == "exp": - return torch.log(-w_real) # Some of the HiPPO methods have real part 0 - elif self.real_type == "relu": - return -w_real - elif self.real_type == "sigmoid": - return torch.logit(-w_real) - elif self.real_type == "softplus": - return torch.log(torch.exp(-w_real) - 1) - else: - raise NotImplementedError - - def _w(self): - # Get the internal w (diagonal) parameter - if self.real_type == "none": - w_real = -self.inv_w_real - elif self.real_type == "exp": - w_real = -torch.exp(self.inv_w_real) - elif self.real_type == "relu": - w_real = -F.relu(self.inv_w_real) - elif self.real_type == "sigmoid": - w_real = -F.sigmoid(self.inv_w_real) - elif self.real_type == "softplus": - w_real = -F.softplus(self.inv_w_real) - else: - raise NotImplementedError - w = w_real + 1j * self.w_imag - return w - - def forward(self, state=None, rate=1.0, L=None): - """ - state: (B, H, N) initial state - rate: sampling rate factor - L: target length - - returns: - (C, H, L) convolution kernel (generally C=1) - (B, H, L) output from initial state - """ - - # Initialize C~ if necessary (done in forward pass so it's on the correct device) - if self.L.item() == 0 and self.l_max is not None and self.l_max > 0: - self._setup_C(self.l_max) - - # Handle sampling rate logic - # The idea is that this kernel's length (in continuous units) is self.L, while we are asked to provide a kernel of length L at (relative) frequency rate - if L is None: - L = round(self.L.item() / rate) - - # Increase the internal length if needed - continuous_L = round(rate * L) - while continuous_L > self.L.item(): - self._setup_C(continuous_L) - discrete_L = round(self.L.item() / rate) - - dt = torch.exp(self.log_dt) * rate - B = _r2c(self.B) - C = _r2c(self.C) - P = _r2c(self.P) - Q = P.conj() - w = self._w() # (S, N) where S=n_ssm - - # Address bandlimiting - if self.bandlimit is not None: - freqs = w.imag.abs() / (2 * math.pi) # (H, N) - freqs = dt[:, None] / rate * freqs # (H, N) - mask = torch.where(freqs < self.bandlimit * 0.5, 1, 0) - C = C * mask - - # Get FFT nodes of right length - omega, z = self._omega( - discrete_L, dtype=w.dtype, device=w.device, cache=(rate == 1.0) - ) - - # Broadcast 
parameters to same hidden features H - B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat) - P = repeat(P, "r t n -> r (v t) n", v=self.repeat) - Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat) - w = repeat(w, "t n -> (v t) n", v=self.repeat) - - # Augment B - if state is not None: - # Have to "unbilinear" the state to put it into the same "type" as B - # Compute 1/dt * (I + dt/2 A) @ state - - # Can do this without expanding (maybe minor speedup using conj symmetry in theory), but it's easier to read this way - s = _conj(state) if state.size(-1) == self.N else state # (B H N) - sA = s * _conj(w) - contract( # (B H N) - "bhm, rhm, rhn -> bhn", s, _conj(Q), _conj(P) - ) - s = s / dt.unsqueeze(-1) + sA / 2 - s = s[..., : self.N] - - B = torch.cat([s, B], dim=-3) # (B+1, H, N) - - # Incorporate dt into A - w = w * dt.unsqueeze(-1) # (H N) - - # Stack B and p, C and q for convenient batching - B = torch.cat([B, P], dim=-3) # (B+1+R, H, N) - C = torch.cat([C, Q], dim=-3) # (C+R, H, N) - - # Incorporate B and C batch dimensions - v = B.unsqueeze(-3) * C.unsqueeze(-4) # (B+1+R, C+R, H, N) - - # Calculate resolvent at omega - # if has_cauchy_extension and z.dtype == torch.cfloat and not self.keops: - # r = cauchy_mult(v, z, w, symmetric=True) - # elif has_pykeops: - # r = cauchy_conj(v, z, w) - # else: - r = cauchy_naive(v, z, w) - r = r * dt[None, None, :, None] # (B+1+R, C+R, H, L) - - # Low-rank Woodbury correction - if self.rank == 1: - k_f = r[:-1, :-1, :, :] - r[:-1, -1:, :, :] * r[-1:, :-1, :, :] / ( - 1 + r[-1:, -1:, :, :] - ) - elif self.rank == 2: - r00 = r[: -self.rank, : -self.rank, :, :] - r01 = r[: -self.rank, -self.rank :, :, :] - r10 = r[-self.rank :, : -self.rank, :, :] - r11 = r[-self.rank :, -self.rank :, :, :] - det = (1 + r11[:1, :1, :, :]) * (1 + r11[1:, 1:, :, :]) - r11[ - :1, 1:, :, : - ] * r11[1:, :1, :, :] - s = ( - r01[:, :1, :, :] * (1 + r11[1:, 1:, :, :]) * r10[:1, :, :, :] - + r01[:, 1:, :, :] * (1 + r11[:1, :1, :, :]) * r10[1:, :, :, :] - - r01[:, :1, :, :] * (r11[:1, 1:, :, :]) * r10[1:, :, :, :] - - r01[:, 1:, :, :] * (r11[1:, :1, :, :]) * r10[:1, :, :, :] - ) - s = s / det - k_f = r00 - s - else: - r00 = r[: -self.rank, : -self.rank, :, :] - r01 = r[: -self.rank, -self.rank :, :, :] - r10 = r[-self.rank :, : -self.rank, :, :] - r11 = r[-self.rank :, -self.rank :, :, :] - r11 = rearrange(r11, "a b h n -> h n a b") - r11 = torch.linalg.inv(torch.eye(self.rank, device=r.device) + r11) - r11 = rearrange(r11, "h n a b -> a b h n") - k_f = r00 - torch.einsum( - "i j h n, j k h n, k l h n -> i l h n", r01, r11, r10 - ) - - # Final correction for the bilinear transform - k_f = k_f * 2 / (1 + omega) - - # Move from frequency to coefficients - k = torch.fft.irfft(k_f, n=discrete_L) # (B+1, C, H, L) - - # # Truncate to target length - k = k[..., :L] - - if state is not None: - k_state = k[:-1, :, :, :] # (B, C, H, L) - else: - k_state = None - k_B = k[-1, :, :, :] # (C H L) - - return k_B, k_state - - @torch.no_grad() - def double_length(self): - self._setup_C(2 * self.L) - - @torch.no_grad() - def _check(self): - """Check if A, B, C parameters and vanilla SSKernel construction can be recovered""" - - # assert self.L > 0, "Set up module first" - - K = self.forward(L=self.l_max)[0] - - self._setup_step() - K_ = krylov(self.l_max, self.dA, self.dB, self.dC) - - diff = K - K_ - - @torch.no_grad() - def _setup_linear(self): - """Create parameters that allow fast linear stepping of state""" - w = self._w() - B = _r2c(self.B) # (H N) - P = _r2c(self.P) - Q = P.conj() - - # 
Repeat w shape properly - B = repeat(B, "1 t n -> 1 (v t) n", v=self.repeat) - P = repeat(P, "r t n -> r (v t) n", v=self.repeat) - Q = repeat(Q, "r t n -> r (v t) n", v=self.repeat) - w = repeat(w, "t n -> (v t) n", v=self.repeat) - - # Prepare Linear stepping - dt = torch.exp(self.log_dt) - D = (2.0 / dt.unsqueeze(-1) - w).reciprocal() # (H, N) - R = ( - torch.eye(self.rank, dtype=w.dtype, device=w.device) - + 2 * contract("r h n, h n, s h n -> h r s", Q, D, P).real - ) # (H R R) - Q_D = rearrange(Q * D, "r h n -> h r n") - try: - R = torch.linalg.solve(R, Q_D) # (H R N) - except: - R = torch.tensor( - np.linalg.solve( - R.to(Q_D).contiguous().detach().cpu(), - Q_D.contiguous().detach().cpu(), - ) - ).to(Q_D) - R = rearrange(R, "h r n -> r h n") - - self.step_params = { - "D": D, # (H N) - "R": R, # (R H N) - "P": P, # (R H N) - "Q": Q, # (R H N) - "B": B, # (1 H N) - "E": 2.0 / dt.unsqueeze(-1) + w, # (H N) - } - - def _step_state_linear(self, u=None, state=None): - """ - Version of the step function that has time O(N) instead of O(N^2) per step, which takes advantage of the DPLR form and bilinear discretization. - - Unfortunately, as currently implemented it's about 2x slower because it calls several sequential operations. Perhaps a fused CUDA kernel implementation would be much faster - - u: (H) input - state: (H, N/2) state with conjugate pairs - Optionally, the state can have last dimension N - Returns: same shape as state - """ - C = _r2c(self.C) # View used for dtype/device - - if u is None: # Special case used to find dA - u = torch.zeros(self.H, dtype=C.dtype, device=C.device) - if state is None: # Special case used to find dB - state = torch.zeros(self.H, self.N, dtype=C.dtype, device=C.device) - - step_params = self.step_params.copy() - if ( - state.size(-1) == self.N - ): # Only store half of the conjugate pairs; should be true by default - # There should be a slightly faster way using conjugate symmetry - contract_fn = lambda p, x, y: contract( - "r h n, r h m, ... h m -> ... h n", _conj(p), _conj(x), _conj(y) - )[ - ..., : self.N - ] # inner outer product - else: - assert state.size(-1) == 2 * self.N - step_params = {k: _conj(v) for k, v in step_params.items()} - # TODO worth setting up a contract_expression in default_state if we want to use this at inference time for stepping - contract_fn = lambda p, x, y: contract( - "r h n, r h m, ... h m -> ... 
h n", p, x, y - ) # inner outer product - D = step_params["D"] # (H N) - E = step_params["E"] # (H N) - R = step_params["R"] # (R H N) - P = step_params["P"] # (R H N) - Q = step_params["Q"] # (R H N) - B = step_params["B"] # (1 H N) - - new_state = E * state - contract_fn(P, Q, state) # (B H N) - new_state = new_state + 2.0 * B * u.unsqueeze(-1) # (B H N) - new_state = D * (new_state - contract_fn(P, R, new_state)) - - return new_state - - def _setup_state(self): - """Construct dA and dB for discretized state equation""" - - # Construct dA and dB by using the stepping - self._setup_linear() - C = _r2c(self.C) # Just returns a view that we use for finding dtype/device - - state = torch.eye(2 * self.N, dtype=C.dtype, device=C.device).unsqueeze( - -2 - ) # (N 1 N) - dA = self._step_state_linear(state=state) - dA = rearrange(dA, "n h m -> h m n") - - u = C.new_ones(self.H) - dB = self._step_state_linear(u=u) - dB = _conj(dB) - dB = rearrange(dB, "1 h n -> h n") # (H N) - return dA, dB - - def _step_state(self, u, state): - """Must be called after self.default_state() is used to construct an initial state!""" - next_state = self.state_contraction(self.dA, state) + self.input_contraction( - self.dB, u - ) - return next_state - - def _setup_step(self, mode="dense"): - """Set up dA, dB, dC discretized parameters for stepping""" - self.dA, self.dB = self._setup_state() - - # Calculate original C - C = _conj(_r2c(self.C)) # (H C N) - if self.L.item() == 0: - dC = C - else: - # self.C represents C_tilde - dA_L = power(self.L.item(), self.dA) - identity_matrix = torch.eye(self.dA. size(-1)).to(dA_L) - - dC = torch.linalg.solve( - identity_matrix - dA_L.transpose(-1, -2), - C.unsqueeze(-1), - ).squeeze(-1) - self.dC = dC - - # Do special preprocessing for different step modes - - self._step_mode = mode - if mode == "linear": - # Linear case: special step function for the state, we need to handle output - # use conjugate symmetry by default, which affects the output projection - self.dC = 2 * self.dC[:, :, : self.N] - elif mode == "diagonal": - # Eigendecomposition of the A matrix - L, V = torch.linalg.eig(self.dA) - V_inv = torch.linalg.inv(V) - # Change the parameterization to diagonalize - self.dA = L - self.dB = contract("h n m, h m -> h n", V_inv, self.dB) - self.dC = contract("h n m, c h n -> c h m", V, self.dC) - - elif mode == "dense": - pass - else: - raise NotImplementedError( - "NPLR Kernel step mode must be {'dense' | 'linear' | 'diagonal'}" - ) - - def default_state(self, *batch_shape): - C = _r2c(self.C) - N = C.size(-1) - H = C.size(-2) - - # Cache the tensor contractions we will later do, for efficiency - # These are put in this function because they depend on the batch size - step_mode = getattr( - self, "_step_mode", "dense" - ) # Used in default_state, which is called without _setup_step() in forward_state() - if step_mode != "linear": - N *= 2 - - if step_mode == "diagonal": - self.state_contraction = contract_expression( - "h n, ... h n -> ... h n", - (H, N), - batch_shape + (H, N), - ) - else: - # Dense (quadratic) case: expand all terms - self.state_contraction = contract_expression( - "h m n, ... h n -> ... h m", - (H, N, N), - batch_shape + (H, N), - ) - - self.input_contraction = contract_expression( - "h n, ... h -> ... h n", - (H, N), # self.dB.shape - batch_shape + (H,), - ) - - self.output_contraction = contract_expression( - "c h n, ... h n -> ... 
c h", - (C.shape[0], H, N), # self.dC.shape - batch_shape + (H, N), - ) - - state = torch.zeros(*batch_shape, H, N, dtype=C.dtype, device=C.device) - return state - - def step(self, u, state): - """Must have called self._setup_step() and created state with self.default_state() before calling this""" - - if self._step_mode == "linear": - new_state = self._step_state_linear(u, state) - else: - new_state = self._step_state(u, state) - y = self.output_contraction(self.dC, new_state) - return y.real, new_state - - -class SSKernel(nn.Module): - """Wrapper around SSKernel parameterizations. - - The SSKernel is expected to support the interface - forward() - default_state() - _setup_step() - step() - """ - - def __init__( - self, - H, - N=64, - L=None, - measure="legs", - rank=1, - channels=1, - dt_min=0.001, - dt_max=0.1, - deterministic=False, - lr=None, - mode="nplr", - n_ssm=None, - verbose=False, - measure_args={}, - **kernel_args, - ): - """State Space Kernel which computes the convolution kernel $\\\\bar{K}$ - - H: Number of independent SSM copies; controls the size of the model. Also called d_model in the config. - N: State size (dimensionality of parameters A, B, C). Also called d_state in the config. Generally shouldn't need to be adjusted and doens't affect speed much. - L: Maximum length of convolution kernel, if known. Should work in the majority of cases even if not known. - measure: Options for initialization of (A, B). For NPLR mode, recommendations are "legs", "fout", "hippo" (combination of both). For Diag mode, recommendations are "diag-inv", "diag-lin", "diag-legs", and "diag" (combination of diag-inv and diag-lin) - rank: Rank of low-rank correction for NPLR mode. Needs to be increased for measure "legt" - channels: C channels turns the SSM from a 1-dim to C-dim map; can think of it having C separate "heads" per SSM. This was partly a feature to make it easier to implement bidirectionality; it is recommended to set channels=1 and adjust H to control parameters instead - dt_min, dt_max: min and max values for the step size dt (\Delta) - mode: Which kernel algorithm to use. 'nplr' is the full S4 model; 'diag' is the simpler S4D; 'slow' is a dense version for testing - n_ssm: Number of independent trainable (A, B) SSMs, e.g. n_ssm=1 means all A/B parameters are tied across the H different instantiations of C. n_ssm=None means all H SSMs are completely independent. Generally, changing this option can save parameters but doesn't affect performance or speed much. This parameter must divide H - lr: Passing in a number (e.g. 0.001) sets attributes of SSM parameers (A, B, dt). A custom optimizer hook is needed to configure the optimizer to set the learning rates appropriately for these parameters. 
- """ - super().__init__() - self.N = N - self.H = H - dtype, cdtype = torch.float, torch.cfloat - self.channels = channels - self.n_ssm = n_ssm if n_ssm is not None else H - self.mode = mode - self.verbose = verbose - self.kernel_args = kernel_args - - # Generate dt - if deterministic: - log_dt = torch.exp(torch.linspace(math.log(dt_min), math.log(dt_max), H)) - else: - log_dt = torch.rand(self.H, dtype=dtype) * ( - math.log(dt_max) - math.log(dt_min) - ) + math.log(dt_min) - - # Compute the preprocessed representation - w, P, B, V = dplr.combination(measure, self.N, rank, self.n_ssm, **measure_args) - - # Broadcast C to have H channels - if deterministic: - C = torch.zeros(channels, self.n_ssm, self.N, dtype=cdtype) - C[:, :, :1] = 1.0 - C = contract("hmn, chn -> chm", V.conj().transpose(-1, -2), C) # V^* C - C = ( - repeat(C, "c t n -> c (v t) n", v=self.n_ssm // C.size(-2)) - .clone() - .contiguous() - ) - else: - C = torch.randn(channels, self.H, self.N // 2, dtype=cdtype) - - # Broadcast other parameters to have n_ssm copies - assert ( - self.n_ssm % B.size(-2) == 0 - and self.n_ssm % P.size(-2) == 0 - and self.n_ssm % w.size(-2) == 0 - ) - # Broadcast tensors to n_ssm copies - # These will be the parameters, so make sure tensors are materialized and contiguous - B = repeat(B, "t n -> (v t) n", v=self.n_ssm // B.size(-2)).clone().contiguous() - P = ( - repeat(P, "r t n -> r (v t) n", v=self.n_ssm // P.size(-2)) - .clone() - .contiguous() - ) - w = repeat(w, "t n -> (v t) n", v=self.n_ssm // w.size(-2)).clone().contiguous() - - self.kernel = SSKernelNPLR( - w, - P, - B, - C, - log_dt, - L=L, - lr=lr, - verbose=verbose, - **kernel_args, - ) - - def forward(self, state=None, L=None, rate=None): - return self.kernel(state=state, L=L, rate=rate) - - @torch.no_grad() - def forward_state(self, u, state): - """Forward the state through a sequence, i.e. 
computes the state after passing chunk through SSM - - state: (B, H, N) - u: (B, H, L) - - Returns: (B, H, N) - """ - - if hasattr(self.kernel, "forward_state"): - return self.kernel.forward_state(u, state) - - dA, dB = self.kernel._setup_state() # Construct dA, dB matrices - # dA, dB = self.kernel.dA, self.kernel.dB # (H N N) (H N) - - conj = state.size(-1) != dA.size(-1) - if conj: - state = _conj(state) - - v = contract( - "h n, b h l -> b h n l", dB, u.flip(-1) - ) # dB.unsqueeze(-1) * u.flip(-1).unsqueeze(-2) - AL, v = power(u.size(-1), dA, v) - next_state = contract("h m n, b h n -> b h m", AL, state) - next_state = next_state + v - - if conj: - next_state = next_state[..., : next_state.size(-1) // 2] - return next_state - - def _setup_step(self, **kwargs): - # This method is intended to be private so that setting up an S4 module with - # ``` - # if hasattr(module, 'setup_step'): module.setup_step() - # ``` - # will not trigger this method multiple times - self.kernel._setup_step(**kwargs) - - def step(self, u, state, **kwargs): - y, state = self.kernel.step(u, state, **kwargs) - return y, state - - def default_state(self, *args, **kwargs): - return self.kernel.default_state(*args, **kwargs) diff --git a/src/clm/module_library/krylov.py b/src/clm/module_library/krylov.py deleted file mode 100644 index 035e00d0..00000000 --- a/src/clm/module_library/krylov.py +++ /dev/null @@ -1,209 +0,0 @@ -import torch -import torch.nn.functional as F -from einops import rearrange - -from .toeplitz import causal_convolution - - -def krylov_sequential(L, A, b, c=None): - """Constant matrix A - - A : (..., N, N) - b : (..., N) - c : (..., N) - - Returns - if c: - x : (..., L) - x[i, l] = c[i] @ A^l @ b[i] - - else: - x : (..., N, L) - x[i, l] = A^l @ b[i] - """ - - # Check which of dim b and c is smaller to save memory - if c is not None and c.numel() < b.numel(): - return krylov_sequential(L, A.transpose(-1, -2), c, b) - - b_ = b - x = [] - for _ in range(L): - if c is not None: - x_ = torch.sum( - c * b_, dim=-1 - ) # (...) # could be faster with matmul or einsum? - else: - x_ = b_ - x.append(x_) - b_ = (A @ b_.unsqueeze(-1)).squeeze(-1) - - x = torch.stack(x, dim=-1) - return x - - -def krylov(L, A, b, c=None, return_power=False): - """ - Compute the Krylov matrix (b, Ab, A^2b, ...) using the squaring trick. - - If return_power=True, return A^{L-1} as well - """ - # TODO There is an edge case if L=1 where output doesn't get broadcasted, which might be an issue if caller is expecting broadcasting semantics... can deal with it if it arises - - x = b.unsqueeze(-1) # (..., N, 1) - A_ = A - - AL = None - if return_power: - AL = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device) - _L = L - 1 - - done = L == 1 - # loop invariant: _L represents how many indices left to compute - while not done: - if return_power: - if _L % 2 == 1: - AL = A_ @ AL - _L //= 2 - - # Save memory on last iteration - current_length = x.shape[-1] - if L - current_length <= current_length: - done = True - _x = x[..., : L - current_length] - else: - _x = x - - _x = A_ @ _x - x = torch.cat( - [x, _x], dim=-1 - ) # there might be a more efficient way of ordering axes - if not done: - A_ = A_ @ A_ - - assert x.shape[-1] == L - - if c is not None: - x = torch.einsum("...nl, ...n -> ...l", x, c) - x = x.contiguous() # WOW!! 
- if return_power: - return x, AL - else: - return x - - -@torch.no_grad() -def power(L, A, v=None): - """Compute A^L and the scan sum_i A^i v_i - - A: (..., N, N) - v: (..., N, L) - """ - - identity = torch.eye(A.shape[-1]).to(A) # Changed from I to identity - - powers = [A] - power_of_2 = 1 # Changed from l to power_of_2 - while True: - if L % 2 == 1: - identity = powers[-1] @ identity # Changed from I to identity - L //= 2 - if L == 0: - break - power_of_2 *= 2 # Changed from l to power_of_2 - if v is None: - powers = [powers[-1] @ powers[-1]] - else: - powers.append(powers[-1] @ powers[-1]) - - if v is None: - return identity # Changed from I to identity - - # Invariants: - # powers[-1] := A^l - # l := largest po2 at most L - - # Note that an alternative divide and conquer to compute the reduction is possible and can be embedded into the above loop without caching intermediate powers of A - # We do this reverse divide-and-conquer for efficiency reasons: - # 1) it involves fewer padding steps for non-po2 L - # 2) it involves more contiguous arrays - - # Take care of edge case for non-po2 arrays - # Note that this initial step is a no-op for the case of power of 2 (l == L) - k = v.size(-1) - power_of_2 # Changed from l to power_of_2 - v_ = powers.pop() @ v[..., power_of_2:] # Changed from l to power_of_2 - v = v[..., :power_of_2] # Changed from l to power_of_2 - v[..., :k] = v[..., :k] + v_ - - # Handle reduction for power of 2 - while v.size(-1) > 1: - v = rearrange(v, "... (z l) -> ... z l", z=2) - v = v[..., 0, :] + powers.pop() @ v[..., 1, :] - return identity, v.squeeze(-1) # Changed from I to identity - - -def krylov_toeplitz(L, A, b, c=None): - """Specializes to lower triangular Toeplitz matrix A represented by its diagonals - - A : (..., N) - b : (..., N) - c : (..., N) - - Returns - x : (..., N, L) - x[i, l] = A^l @ b[i] - """ - x = b.unsqueeze(0) # (1, ..., N) - A_ = A - while x.shape[0] < L: - xx = causal_convolution(A_, x) - x = torch.cat( - [x, xx], dim=0 - ) # there might be a more efficient way of ordering axes - A_ = causal_convolution(A_, A_) - x = x[:L, ...] # (L, ..., N) - if c is not None: - x = torch.einsum("l...n, ...n -> ...l", x, c) - else: - x = rearrange(x, "l ... n -> ... n l") - x = x.contiguous() - return x - - -def krylov_toeplitz_(L, A, b, c=None): - """Padded version of krylov_toeplitz that saves some fft's - - TODO currently not faster than original version, not sure why - """ - N = A.shape[-1] - - x = b.unsqueeze(0) # (1, ..., N) - x = F.pad(x, (0, N)) - A = F.pad(A, (0, N)) - done = L == 1 - while not done: - length = x.shape[0] - # Save memory on last iteration - if L - length <= length: - done = True - _x = x[: L - length] - else: - _x = x - Af = torch.fft.rfft(A, n=2 * N, dim=-1) - xf = torch.fft.rfft(_x, n=2 * N, dim=-1) - xf_ = Af * xf - x_ = torch.fft.irfft(xf_, n=2 * N, dim=-1) - x_[..., N:] = 0 - x = torch.cat( - [x, x_], dim=0 - ) # there might be a more efficient way of ordering axes - if not done: - A = torch.fft.irfft(Af * Af, n=2 * N, dim=-1) - A[..., N:] = 0 - x = x[:L, ..., :N] # (L, ..., N) - if c is not None: - x = torch.einsum("l...n, ...n -> ...l", x, c) - else: - x = rearrange(x, "l ... n -> ... 
n l") - x = x.contiguous() - return x diff --git a/src/clm/module_library/pool.py b/src/clm/module_library/pool.py deleted file mode 100644 index f6f2bdf4..00000000 --- a/src/clm/module_library/pool.py +++ /dev/null @@ -1,62 +0,0 @@ -import torch.nn.functional as F -from einops import rearrange, reduce - -from .sequence_module import SequenceModule -from .util_modules import LinearActivation - - -class DownAvgPool(SequenceModule): - def __init__(self, d_input, stride=1, expand=None, transposed=True): - super().__init__() - self.d_input = d_input - self.stride = stride - self.expand = expand - self.transposed = transposed - - if self.expand is not None: - self.linear = LinearActivation( - d_input, - d_input * expand, - transposed=transposed, - ) - - def forward(self, x): - if not self.transposed: - x = rearrange(x, "b ... d -> b d ...") - - if self.stride > 1: - # einops appears slower than F - if x.ndim == 3: - x = F.avg_pool1d(x, self.stride, self.stride) - elif x.ndim == 4: - x = F.avg_pool2d(x, self.stride, self.stride) - else: - # Reduction string e.g. "b d (l1 2) (l2 2) -> b d l1 l2" - reduce_str = ( - "b d " - + " ".join([f"(l{i} {self.stride})" for i in range(x.ndim - 2)]) - + " -> b d " - + " ".join([f"l{i}" for i in range(x.ndim - 2)]) - ) - x = reduce(x, reduce_str, "mean") - - # if self.expand > 1: - # x = repeat(x, 'b d ... -> b (d e) ...', e=self.expand) - - if not self.transposed: - x = rearrange(x, "b d ... -> b ... d") - if self.expand is not None: - x = self.linear(x) - return x, None - - def step(self, x, state, **kwargs): - if self.stride > 1 or self.expand > 1: - raise NotImplementedError - return x, state - - @property - def d_output(self): - if self.expand is None: - return self.d_input - else: - return self.d_input * self.expand diff --git a/src/clm/module_library/residual.py b/src/clm/module_library/residual.py deleted file mode 100644 index 50513e25..00000000 --- a/src/clm/module_library/residual.py +++ /dev/null @@ -1,23 +0,0 @@ -from torch import nn - - -class Residual(nn.Module): - """Residual connection with constant affine weights. 
Can simulate standard residual, no residual, and "constant gates".""" - - def __init__(self, i_layer, d_input, d_model, alpha=1.0, beta=1.0): - # print("ConstantResidual extra kwargs", kwargs) - super().__init__() - assert (d_input == d_model) or alpha == 0.0 - self.i_layer = i_layer - self.d_input = d_input - self.d_model = d_model - self.alpha = alpha - self.beta = beta - - @property - def d_output(self): - return self.d_model - - def forward(self, x, y, transposed): # TODO documentation of transposed - y = self.beta * y if self.beta != 1.0 else y - return self.alpha * x + y if self.alpha else y diff --git a/src/clm/module_library/s4.py b/src/clm/module_library/s4.py deleted file mode 100644 index bed7b06c..00000000 --- a/src/clm/module_library/s4.py +++ /dev/null @@ -1,290 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import opt_einsum as oe -from einops import rearrange - -from .kernel import SSKernel -from .util_modules import LinearActivation, Activation, DropoutNd - -optimized = True - -if optimized: - contract = oe.contract -else: - contract = torch.einsum - - -class S4(nn.Module): - def __init__( - self, - d_model, - d_state=64, - l_max=None, - channels=1, - bidirectional=False, - # Arguments for position-wise feedforward components - activation="gelu", - postact="glu", - initializer=None, - weight_norm=False, - hyper_act=None, - dropout=0.0, - tie_dropout=False, - bottleneck=None, - gate=None, - transposed=True, - verbose=False, - shift=False, - linear=False, - # SSM Kernel arguments - **kernel_args, - ): - """ - d_state: the dimension of the state, also denoted by N - l_max: the maximum kernel length, also denoted by L. Set l_max=None to always use a global kernel - channels: can be interpreted as a number of "heads"; the SSM is a map from a 1-dim to C-dim sequence. It's not recommended to change this unless desperate for things to tune; instead, increase d_model for larger models - bidirectional: if True, convolution kernel will be two-sided - - Position-wise feedforward components: - -------------------- - activation: activation in between SS and FF - postact: activation after FF - initializer: initializer on FF - weight_norm: weight normalization on FF - hyper_act: use a "hypernetwork" multiplication (experimental) - dropout: standard dropout argument. tie_dropout=True ties the dropout mask across the sequence length, emulating nn.Dropout1d - - Other arguments: - -------------------- - transposed: choose backbone axis ordering of (B, L, H) (if False) or (B, H, L) (if True) [B=batch size, L=sequence length, H=hidden dimension] - gate: add gated activation (GSS) - bottleneck: reduce SSM dimension (GSS) - shift: experimental option, shouldn't affect results - linear: Remove pointwise components so that the entire module is a linear SSM - - See the class .kernel.SSKernel for the kernel constructor which accepts kernel_args. 
Relevant options that are worth considering and tuning include "mode" + "measure", "dt_min", "dt_max", "lr" - - Other options are all experimental and should not need to be configured - """ - - super().__init__() - - self.d_model = d_model - self.H = d_model - self.N = d_state - self.L = l_max - self.bidirectional = bidirectional - self.channels = channels - self.transposed = transposed - self.shift = shift - self.linear = linear - - self.gate = gate - self.bottleneck = bottleneck - - if bottleneck is not None: - self.H = self.H // bottleneck - self.input_linear = LinearActivation( - self.d_model, - self.H, - transposed=self.transposed, - initializer=initializer, - activation=activation, - activate=True, - weight_norm=weight_norm, - ) - - if gate is not None: - self.input_gate = LinearActivation( - self.d_model, - self.d_model * gate, - transposed=self.transposed, - initializer=initializer, - activation=activation, - activate=True, - weight_norm=weight_norm, - ) - self.output_gate = LinearActivation( - self.d_model * gate, - self.d_model, - transposed=self.transposed, - initializer=initializer, - activation=None, - activate=False, - weight_norm=weight_norm, - ) - - # optional multiplicative modulation GLU-style - # https://arxiv.org/abs/2002.05202 - self.hyper = hyper_act is not None - if self.hyper: - channels *= 2 - self.hyper_activation = Activation(hyper_act) - - self.D = nn.Parameter(torch.randn(channels, self.H)) - - if self.bidirectional: - channels *= 2 - - # SSM Kernel - self.kernel = SSKernel( - self.H, - N=self.N, - L=self.L, - channels=channels, - verbose=verbose, - **kernel_args, - ) - - # Pointwise - if not self.linear: - self.activation = Activation(activation) - # dropout_fn = nn.Dropout2d if self.transposed else nn.Dropout # Broken in torch==1.11 - dropout_fn = DropoutNd if tie_dropout else nn.Dropout - self.dropout = dropout_fn(dropout) if dropout > 0.0 else nn.Identity() - # position-wise output transform to mix features - if not self.linear: - self.output_linear = LinearActivation( - self.H * self.channels, - self.d_model * (1 if self.gate is None else self.gate), - transposed=self.transposed, - initializer=initializer, - activation=postact, - activate=True, - weight_norm=weight_norm, - ) - - def forward( - self, u, state=None, rate=1.0, lengths=None, **kwargs - ): # absorbs return_output and transformer src mask - """ - u: (B H L) if self.transposed else (B L H) - state: (H N) never needed unless you know what you're doing - - Returns: same shape as u - """ - if not self.transposed: - u = u.transpose(-1, -2) - - L = u.size(-1) - # Mask out padding tokens - # TODO handle option for mask - instead of lengths, which assumes suffix padding - if isinstance(lengths, int): - if lengths != L: - lengths = torch.tensor(lengths, dtype=torch.long, device=u.device) - else: - lengths = None - if lengths is not None: - assert ( - isinstance(lengths, torch.Tensor) - and lengths.ndim == 1 - and lengths.size(0) in [1, u.size(0)] - ) - mask = torch.where( - torch.arange(L, device=lengths.device) < lengths[:, None, None], - 1.0, - 0.0, - ) - u = u * mask - - if self.gate is not None: - v = self.input_gate(u) - if self.bottleneck is not None: - u = self.input_linear(u) - - # Compute SS Kernel - L_kernel = L if self.L is None else min(L, round(self.L / rate)) - k, k_state = self.kernel( - L=L_kernel, rate=rate, state=state - ) # (C H L) (B C H L) - # Convolution - if self.bidirectional: - k0, k1 = rearrange(k, "(s c) h l -> s c h l", s=2) - k = F.pad(k0, (0, L)) + F.pad(k1.flip(-1), (L, 0)) 
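# --- Editor's sketch (not part of the original patch): the FFT multiply below amounts to a
# causal (truncated linear) convolution, y[t] = sum_{s <= t} k[s] * u[t - s]. A minimal 1-D
# check of that equivalence; lengths and names here are illustrative only (in the module
# above the pad length is L_kernel + L rather than 2 * L in general).
import torch
import torch.nn.functional as F

torch.manual_seed(0)
L = 16
u, k = torch.randn(L), torch.randn(L)

# FFT path: zero-pad to 2*L so the circular convolution has no wrap-around, keep the first L outputs.
y_fft = torch.fft.irfft(torch.fft.rfft(u, n=2 * L) * torch.fft.rfft(k, n=2 * L), n=2 * L)[:L]

# Direct path: conv1d cross-correlates, so flip the kernel and left-pad u by L - 1
# to retain only the causal outputs.
y_direct = F.conv1d(F.pad(u, (L - 1, 0)).view(1, 1, -1), k.flip(-1).view(1, 1, -1)).view(-1)

assert torch.allclose(y_fft, y_direct, atol=1e-4)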
- if self.shift: - # Try flip and pad to correct for potential off-by-one - k_f = torch.fft.rfft(F.pad(k.flip(-1), (L, 0)), n=2 * L) # (C H L) - u_f = torch.fft.rfft(F.pad(u.flip(-1), (L, 0)), n=2 * L) # (B H L) - y_f = contract( - "bhl,chl->bchl", u_f, k_f - ) # k_f.unsqueeze(-4) * u_f.unsqueeze(-3) # (B C H L) - y = torch.fft.irfft(y_f, n=L_kernel + L)[..., L:].flip(-1) # (B C H L) - else: - k_f = torch.fft.rfft(k, n=L_kernel + L) # (C H L) - u_f = torch.fft.rfft(u, n=L_kernel + L) # (B H L) - y_f = contract("bhl,chl->bchl", u_f, k_f) - y = torch.fft.irfft(y_f, n=L_kernel + L)[..., :L] # (B C H L) - - # Compute D term in state space equation - essentially a skip connection - y = y + contract("bhl,ch->bchl", u, self.D) - - # Compute state update - if state is not None: - assert ( - not self.bidirectional - ), "Bidirectional not supported with state forwarding" - y = y + k_state # - next_state = self.kernel.forward_state(u, state) - else: - next_state = None - - # Optional hyper-network multiplication - if self.hyper: - y, yh = rearrange(y, "b (s c) h l -> s b c h l", s=2) - y = self.hyper_activation(yh) * y - - # Reshape to flatten channels - y = rearrange(y, "... c h l -> ... (c h) l") - - if not self.linear: - y = self.dropout(self.activation(y)) - - if not self.transposed: - y = y.transpose(-1, -2) - - if not self.linear: - y = self.output_linear(y) - - if self.gate is not None: - y = self.output_gate(y * v) - return y, next_state - - def setup_step(self, **kwargs): - self.kernel._setup_step(**kwargs) - - def step(self, u, state): - """Step one time step as a recurrent model. Intended to be used during validation. - - u: (B H) - state: (B H N) - Returns: output (B H), state (B H N) - """ - assert not self.training - # u = u.squeeze(1) # (B H) - y, next_state = self.kernel.step(u, state) # (B C H) - y = y + u.unsqueeze(-2) * self.D - y = rearrange(y, "b c h -> b (c h)") - y = self.activation(y) - if self.transposed: - y = self.output_linear(y.unsqueeze(-1)).squeeze(-1) - else: - y = self.output_linear(y) - return y, next_state - - def default_state(self, *batch_shape, device=None): - # kernel is not a SequenceModule so it doesn't need to adhere to same interface - # the kernel will know the device of its own parameters - return self.kernel.default_state(*batch_shape) - - @property - def d_state(self): - return self.H * self.N - - @property - def d_output(self): - return self.d_model - - @property - def state_to_tensor(self): - return lambda state: rearrange("... h n -> ... 
(h n)", state) diff --git a/src/clm/module_library/sequence_model.py b/src/clm/module_library/sequence_model.py deleted file mode 100644 index 803746a1..00000000 --- a/src/clm/module_library/sequence_model.py +++ /dev/null @@ -1,204 +0,0 @@ -from functools import partial -from typing import Sequence, Mapping -import torch -import torch.nn as nn -from einops import rearrange - -from .sequence_residual_block import SequenceResidualBlock -from .sequence_module import SequenceModule -from .util_modules import Normalization, DropoutNd - - -def is_list(x): - return isinstance(x, Sequence) and not isinstance(x, str) - - -def is_dict(x): - return isinstance(x, Mapping) - - -def to_dict(x, recursive=True): - """Convert Sequence or Mapping object to dict - - lists get converted to {0: x[0], 1: x[1], ...} - """ - if is_list(x): - x = {i: v for i, v in enumerate(x)} - if is_dict(x): - if recursive: - return {k: to_dict(v, recursive=recursive) for k, v in x.items()} - else: - return dict(x) - else: - return x - - -def to_list(x, recursive=False): - """Convert an object to list. - - If Sequence (e.g. list, tuple, Listconfig): just return it - - Special case: If non-recursive and not a list, wrap in list - """ - if is_list(x): - if recursive: - return [to_list(_x) for _x in x] - else: - return list(x) - else: - if recursive: - return x - else: - return [x] - - -class SequenceModel(SequenceModule): - def __init__( - self, - d_model, # Resize input (useful for deep models with residuals) - n_layers=1, # Number of layers - transposed=False, # Transpose inputs so each layer receives (batch, dim, length) - dropout=0.0, # Dropout parameter applied on every residual and every layer - tie_dropout=False, # Tie dropout mask across sequence like nn.Dropout1d/nn.Dropout2d - prenorm=True, # Pre-norm vs. post-norm - n_repeat=1, # Each layer is repeated n times per stage before applying pooling - layer=None, # Layer config, must be specified - # residual=None, # Residual config - residual="R", # Residual config # changed the default value from None to "R" - # norm=None, # Normalization config (e.g. layer vs batch) - norm="layer", # Normalization config (e.g. 
layer vs batch) # changed the default value from None to "layer" - pool=None, # Config for pooling layer per stage - # track_norms=True, # Log norms of each layer output; changed the default value from True to False - track_norms=False, # Log norms of each layer output; changed the default value from True to False - dropinp=0.0, # Input dropout - ): - super().__init__() - # Save arguments needed for forward pass - self.d_model = d_model - self.transposed = transposed - self.track_norms = track_norms - - # Input dropout (not really used) - dropout_fn = ( - partial(DropoutNd, transposed=self.transposed) - if tie_dropout - else nn.Dropout - ) - self.drop = dropout_fn(dropinp) if dropinp > 0.0 else nn.Identity() - - layer = to_list(layer, recursive=False) - - # Some special arguments are passed into each layer - for _layer in layer: - # If layers don't specify dropout, add it - if _layer.get("dropout", None) is None: - _layer["dropout"] = dropout - # Ensure all layers are shaped the same way - _layer["transposed"] = transposed - - # Duplicate layers - layers = layer * n_layers * n_repeat - - # Instantiate layers - _layers = [] - d = d_model - for layer_idx, layer in enumerate(layers): - # Pool at the end of every n_repeat blocks - pool_cfg = pool if (layer_idx + 1) % n_repeat == 0 else None - block = SequenceResidualBlock( - d, - layer_idx + 1, - prenorm=prenorm, - dropout=dropout, - tie_dropout=tie_dropout, - transposed=transposed, - layer_config=layer, - residual=residual, - norm=norm, - pool=pool_cfg, - ) - _layers.append(block) - d = block.d_output - - self.d_output = d - self.layers = nn.ModuleList(_layers) - if prenorm: - if norm is None: - self.norm = None - elif isinstance(norm, str): - self.norm = Normalization( - self.d_output, transposed=self.transposed, _name_=norm - ) - else: - self.norm = Normalization( - self.d_output, transposed=self.transposed, **norm - ) - else: - self.norm = nn.Identity() - - def forward(self, inputs, *args, state=None, **kwargs): - """Inputs assumed to be (batch, sequence, dim)""" - if self.transposed: - inputs = rearrange(inputs, "b ... d -> b d ...") - inputs = self.drop(inputs) - - # Track norms - if self.track_norms: - output_norms = [torch.mean(inputs.detach() ** 2)] - - # Apply layers - outputs = inputs - prev_states = [None] * len(self.layers) if state is None else state - next_states = [] - for layer, prev_state in zip(self.layers, prev_states): - outputs, state = layer(outputs, *args, state=prev_state, **kwargs) - next_states.append(state) - if self.track_norms: - output_norms.append(torch.mean(outputs.detach() ** 2)) - if self.norm is not None: - outputs = self.norm(outputs) - - if self.transposed: - outputs = rearrange(outputs, "b d ... -> b ... 
d") - - if self.track_norms: - metrics = to_dict(output_norms, recursive=False) - self.metrics = {f"norm/{i}": v for i, v in metrics.items()} - - return outputs, next_states - - @property - def d_state(self): - d_states = [layer.d_state for layer in self.layers] - return sum([d for d in d_states if d is not None]) - - @property - def state_to_tensor(self): - # Slightly hacky way to implement this in a curried manner (so that the function can be extracted from an instance) - # Somewhat more sound may be to turn this into a @staticmethod and grab subclasses using hydra.utils.get_class - def fn(state): - x = [ - _layer.state_to_tensor(_state) - for (_layer, _state) in zip(self.layers, state) - ] - x = [_x for _x in x if _x is not None] - return torch.cat(x, dim=-1) - - return fn - - def default_state(self, *batch_shape, device=None): - return [ - layer.default_state(*batch_shape, device=device) for layer in self.layers - ] - - def step(self, x, state, **kwargs): - prev_states = [None] * len(self.layers) if state is None else state - next_states = [] - layer_idx = 0 - for layer, prev_state in zip(self.layers, prev_states): - x, state = layer.step(x, state=prev_state, **kwargs) - next_states.append(state) - layer_idx += 1 - - x = self.norm(x) - return x, next_states diff --git a/src/clm/module_library/sequence_module.py b/src/clm/module_library/sequence_module.py deleted file mode 100644 index 7daa2842..00000000 --- a/src/clm/module_library/sequence_module.py +++ /dev/null @@ -1,137 +0,0 @@ -from torch import nn -import functools - - -class SequenceModule(nn.Module): - """Abstract sequence model class. All models must adhere to this interface - - A SequenceModule is generally a model that transforms an input of shape - (n_batch, l_sequence, d_model) to (n_batch, l_sequence, d_output) - - REQUIRED methods and attributes - forward, d_model, d_output: controls standard forward pass, a sequence-to-sequence transformation - __init__ should also satisfy the following interface; see SequenceIdentity for an example - def __init__(self, d_model, transposed=False, **kwargs) - - OPTIONAL methods - default_state, step: allows stepping the model recurrently with a hidden state - state_to_tensor, d_state: allows decoding from hidden state - """ - - @property - def d_model(self): - """Model dimension (generally same as input dimension). - - This attribute is required for all SequenceModule instantiations. - It is used by the rest of the pipeline (e.g. model backbone, encoder) to track the internal shapes of the full model. - """ - if getattr(self, "_d_model", None) is None: - raise NotImplementedError("SequenceModule instantiation must set d_model") - return self._d_model - - @d_model.setter - def d_model(self, d): - self._d_model = d - - @property - def d_output(self): - """Output dimension of model. - - This attribute is required for all SequenceModule instantiations. - It is used by the rest of the pipeline (e.g. model backbone, decoder) to track the internal shapes of the full model. - """ - if getattr(self, "_d_output", None) is None: - raise NotImplementedError( - "SequenceModule instantiation must specify d_output for decoder" - ) - return self._d_output - - @d_output.setter - def d_output(self, d): - self._d_output = d - - def forward(self, x, state=None, **kwargs): - """Forward pass of sequence model, a sequence-to-sequence transformation with an optional state. 
- - Generally, this should map a tensor of shape (batch, length, self.d_model) to (batch, length, self.d_output) - - Additionally, it returns a "state" which can be any additional information - For example, RNN and SSM layers may return their hidden state, - while some types of transformer layers (e.g. Transformer-XL) may want to pass a state as well - """ - return x, None - - @property - def state_to_tensor(self): - """Returns a function mapping a state to a single tensor. - - This method should be implemented if one wants to use the hidden state instead of the output sequence for final prediction. - Currently only used with the StateDecoder. - """ - return lambda _: None - - @property - def d_state(self): - """Returns dimension of output of self.state_to_tensor""" - return None - - def default_state(self, *batch_shape, device=None): - """Create initial state for a batch of inputs.""" - - return None - - def step(self, x, state=None, **kwargs): - """Step the model recurrently for one step of the input sequence. - - For example, this should correspond to unrolling an RNN for one step. - If the forward pass has signature (B, L, H1) -> (B, L, H2), - this method should generally have signature (B, H1) -> (B, H2) with an optional recurrent state. - """ - raise NotImplementedError - - -def TransposedModule(module): - """Wrap a SequenceModule class to accept transposed parameter, handle state, absorb kwargs""" - # https://stackoverflow.com/a/65470430/1980685 - @functools.wraps(module, updated=()) - class TransposedModule(module): - def __init__(self, *args, transposed=False, **kwargs): - super().__init__(*args, **kwargs) - self.transposed = transposed - - def forward(self, x, state=None, **kwargs): - if self.transposed: - x = x.transpose(-1, -2) - x, next_state = super().forward(x, state) # Don't use kwarg because nn.LSTM - next_state = None if state is None else next_state - if self.transposed: - x = x.transpose(-1, -2) - return x, next_state - - # https://stackoverflow.com/questions/5352781/how-to-set-class-names-dynamically - # TransposedModule.__name__ = module.__name__ # functools wraps is better solution - return TransposedModule - - -@TransposedModule -class SequenceIdentity(SequenceModule): - """Simple SequenceModule for testing purposes""" - - def __init__(self, d_model, dropout=0.0, **kwargs): - """Default interface for SequenceModule - - d_model: input dimension (sometimes denoted H for hidden dimension) - transposed: if True, inputs have axis ordering (B, H, L) instead of (B, H, L) - """ - super().__init__() - self.d_model = d_model - self.d_output = d_model - - def forward(self, x, state=None): - return x, state - - def default_state(self, *batch_shape, device=None): - return None - - def step(self, x, state=None, **kwargs): - return x, state diff --git a/src/clm/module_library/sequence_residual_block.py b/src/clm/module_library/sequence_residual_block.py deleted file mode 100644 index ba5551a9..00000000 --- a/src/clm/module_library/sequence_residual_block.py +++ /dev/null @@ -1,148 +0,0 @@ -from torch import nn - -from functools import partial - -from .util_modules import Normalization, StochasticDepth, DropoutNd -from .sequence_module import SequenceModule -from .s4 import S4 -from .ff import FF -from .pool import DownAvgPool -from .residual import Residual - - -class SequenceResidualBlock(SequenceModule): - def __init__( - self, - d_input, - i_layer=None, # Only needs to be passed into certain residuals like Decay - prenorm=True, - dropout=0.0, - tie_dropout=False, - 
transposed=False, - layer_config=None, # Config for black box module - residual=None, # Config for residual function - norm=None, # Config for normalization layer - pool=None, - drop_path=0.0, - ): - super().__init__() - - self.i_layer = i_layer - self.d_input = d_input - # self.layer = instantiate(registry.layer, layer, d_input) - # layer_config = layer.copy() - # layer_cls = registry.get_layer(layer["_name_"]) - layer_config = layer_config.copy() - if layer_config["_name_"] == "s4": - layer_cls = S4 - elif layer_config["_name_"] == "ff": - layer_cls = FF - layer_config.pop("_name_") - self.layer = layer_cls(d_input, **layer_config) - - self.prenorm = prenorm - self.transposed = transposed - - # Residual - # d_residual is the output dimension after residual - if residual is None: - self.residual = None - self.d_residual = self.layer.d_output - else: - # self.residual = instantiate( - # residual_registry, residual, i_layer, d_input, self.layer.d_output - # ) - self.residual = Residual(i_layer, d_input, self.layer.d_output) - # instantiate( - # residual_registry, residual, i_layer, d_input, self.layer.d_output - # ) - self.d_residual = self.residual.d_output - - # Normalization - d_norm = d_input if self.prenorm else self.d_residual - # We don't use config to directly instantiate since Normalization has some special cases - if norm is None: - self.norm = None - elif isinstance(norm, str): - self.norm = Normalization(d_norm, transposed=self.transposed, _name_=norm) - else: - self.norm = Normalization(d_norm, transposed=self.transposed, **norm) - - # Pool - if pool is not None: - self.pool = DownAvgPool(self.d_residual, transposed=self.transposed) - - # Dropout - dropout_cls = ( - partial(DropoutNd, transposed=self.transposed) - if tie_dropout - else nn.Dropout - ) - self.drop = dropout_cls(dropout) if dropout > 0.0 else nn.Identity() - - # Stochastic depth - self.drop_path = ( - StochasticDepth(drop_path, mode="row") if drop_path > 0.0 else nn.Identity() - ) - - @property - def d_output(self): - return self.pool.d_output if self.pool is not None else self.d_residual - - @property - def d_state(self): - return self.layer.d_state - - @property - def state_to_tensor(self): - return self.layer.state_to_tensor - - def default_state(self, *args, **kwargs): - return self.layer.default_state(*args, **kwargs) - - def forward(self, x, state=None, **kwargs): - y = x - - # Pre-norm - if self.norm is not None and self.prenorm: - y = self.norm(y) - - # Black box layer - y, state = self.layer(y, state=state, **kwargs) - - # Residual - if self.residual is not None: - y = self.residual(x, self.drop_path(self.drop(y)), self.transposed) - # Post-norm - if self.norm is not None and not self.prenorm: - y = self.norm(y) - - # Pool - if self.pool is not None: - y, _ = self.pool(y) - - return y, state - - def step(self, x, state, **kwargs): - y = x - - # Pre-norm - if self.norm is not None and self.prenorm: - y = self.norm.step(y) - - # Black box layer - y, state = self.layer.step(y, state, **kwargs) - # Residual - if self.residual is not None: - y = self.residual( - x, y, transposed=self.transposed - ) # NOTE this would not work with concat residual function (catformer) - # Post-norm - if self.norm is not None and not self.prenorm: - y = self.norm.step(y) - - # Pool - if self.pool is not None: - y, _ = self.pool(y) - - return y, state diff --git a/src/clm/module_library/toeplitz.py b/src/clm/module_library/toeplitz.py deleted file mode 100644 index 5e382442..00000000 --- a/src/clm/module_library/toeplitz.py +++ 
/dev/null @@ -1,156 +0,0 @@ -import torch -import torch.nn.functional as F - - -def construct_toeplitz(v, f=0.0): - """Explicit construction of Krylov matrix [v A @ v A^2 @ v ... A^{n-1} @ v] - where A = Z_f. This uses vectorized indexing and cumprod so it's much - faster than using the Krylov function. - Parameters: - v: the starting vector of size n or (rank, n). - f: real number - Returns: - K: Krylov matrix of size (n, n) or (rank, n, n). - """ - n = v.shape[-1] - a = torch.arange(n, device=v.device) - b = -a - indices = a[:, None] + b[None] - K = v[..., indices] - K[..., indices < 0] *= f - return K - - -def triangular_toeplitz_multiply_(u, v, sum=None): - n = u.shape[-1] - u_expand = F.pad(u, (0, n)) - v_expand = F.pad(v, (0, n)) - u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1) - v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1) - uv_f = u_f * v_f - if sum is not None: - uv_f = uv_f.sum(dim=sum) - output = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n] - return output - - -def triangular_toeplitz_multiply_padded_(u, v): - """Same as triangular_toeplitz_multiply but inputs and output assume to be 0-padded already.""" - n = u.shape[-1] - assert n % 2 == 0 - u_f = torch.fft.rfft(u, n=n, dim=-1) - v_f = torch.fft.rfft(v, n=n, dim=-1) - uv_f = u_f * v_f - output = torch.fft.irfft(uv_f, n=n, dim=-1) - output[..., n:] = 0 - return output - - -class TriangularToeplitzMult(torch.autograd.Function): - @staticmethod - def forward(ctx, u, v): - ctx.save_for_backward(u, v) - return triangular_toeplitz_multiply_(u, v) - - @staticmethod - def backward(ctx, grad): - u, v = ctx.saved_tensors - d_u = triangular_toeplitz_multiply_(grad.flip(-1), v).flip(-1) - d_v = triangular_toeplitz_multiply_(grad.flip(-1), u).flip(-1) - return d_u, d_v - - -class TriangularToeplitzMultFast(torch.autograd.Function): - @staticmethod - def forward(ctx, u, v): - n = u.shape[-1] - u_expand = F.pad(u, (0, n)) - v_expand = F.pad(v, (0, n)) - u_f = torch.fft.rfft(u_expand, n=2 * n, dim=-1) - v_f = torch.fft.rfft(v_expand, n=2 * n, dim=-1) - - ctx.save_for_backward(u_f, v_f) - - uv_f = u_f * v_f - output = torch.fft.irfft(uv_f, n=2 * n, dim=-1)[..., :n] - return output - - @staticmethod - def backward(ctx, grad): - u_f, v_f = ctx.saved_tensors - n = grad.shape[-1] - g_expand = F.pad(grad.flip(-1), (0, n)) - g_f = torch.fft.rfft(g_expand, n=2 * n, dim=-1) - gu_f = g_f * u_f - gv_f = g_f * v_f - d_u = torch.fft.irfft(gv_f, n=2 * n, dim=-1)[..., :n] - d_v = torch.fft.irfft(gu_f, n=2 * n, dim=-1)[..., :n] - d_u = d_u.flip(-1) - d_v = d_v.flip(-1) - return d_u, d_v - - -class TriangularToeplitzMultPadded(torch.autograd.Function): - @staticmethod - def forward(ctx, u, v): - ctx.save_for_backward(u, v) - output = triangular_toeplitz_multiply_(u, v) - return output - - @staticmethod - def backward(ctx, grad): - u, v = ctx.saved_tensors - d_u = triangular_toeplitz_multiply_padded_(grad.flip(-1), v).flip(-1) - d_v = triangular_toeplitz_multiply_padded_(grad.flip(-1), u).flip(-1) - return d_u, d_v - - -class TriangularToeplitzMultPaddedFast(torch.autograd.Function): - """Trade off speed (20-25% faster) for more memory (20-25%)""" - - @staticmethod - def forward(ctx, u, v): - n = u.shape[-1] - u_f = torch.fft.rfft(u, n=n, dim=-1) - v_f = torch.fft.rfft(v, n=n, dim=-1) - - ctx.save_for_backward(u_f, v_f) - - uv_f = u_f * v_f - output = torch.fft.irfft(uv_f, n=n, dim=-1) - output[..., n // 2 :].zero_() - return output - - @staticmethod - def backward(ctx, grad): - u_f, v_f = ctx.saved_tensors - n = grad.shape[-1] - g_expand = 
F.pad(grad[..., : n // 2].flip(-1), (0, n // 2)) - g_f = torch.fft.rfft(g_expand, n=n, dim=-1) - gu_f = g_f * u_f - gv_f = g_f * v_f - d_u = torch.fft.irfft(gv_f, n=n, dim=-1) - d_v = torch.fft.irfft(gu_f, n=n, dim=-1) - d_u[..., n // 2 :].zero_() - d_v[..., n // 2 :].zero_() - d_u[..., : n // 2] = d_u[..., : n // 2].flip(-1) # TODO - d_v[..., : n // 2] = d_v[..., : n // 2].flip(-1) # TODO - return d_u, d_v - - -# triangular_toeplitz_multiply = triangular_toeplitz_multiply_ -triangular_toeplitz_multiply = TriangularToeplitzMult.apply -triangular_toeplitz_multiply_fast = TriangularToeplitzMultFast.apply -triangular_toeplitz_multiply_padded = TriangularToeplitzMultPadded.apply -triangular_toeplitz_multiply_padded_fast = TriangularToeplitzMultPaddedFast.apply - - -def causal_convolution(u, v, fast=True, pad=False): - if not pad and not fast: - return triangular_toeplitz_multiply(u, v) - if not pad and fast: - return triangular_toeplitz_multiply_fast(u, v) - if pad and not fast: - return triangular_toeplitz_multiply_padded(u, v) - if pad and fast: - return triangular_toeplitz_multiply_padded_fast(u, v) diff --git a/src/clm/module_library/util_modules.py b/src/clm/module_library/util_modules.py deleted file mode 100644 index 9f4666f3..00000000 --- a/src/clm/module_library/util_modules.py +++ /dev/null @@ -1,318 +0,0 @@ -import math -from functools import partial -import torch -from torch import nn -from einops import rearrange -from opt_einsum import contract - - -def get_initializer(name, activation=None): - if activation in [None, "id", "identity", "linear", "modrelu"]: - nonlinearity = "linear" - elif activation in ["relu", "tanh", "sigmoid"]: - nonlinearity = activation - elif activation in ["gelu", "swish", "silu"]: - nonlinearity = "relu" # Close to ReLU so approximate with ReLU's gain - else: - raise NotImplementedError( - f"get_initializer: activation {activation} not supported" - ) - - if name == "uniform": - initializer = partial(torch.nn.init.kaiming_uniform_, nonlinearity=nonlinearity) - elif name == "normal": - initializer = partial(torch.nn.init.kaiming_normal_, nonlinearity=nonlinearity) - elif name == "xavier": - initializer = torch.nn.init.xavier_normal_ - elif name == "zero": - initializer = partial(torch.nn.init.constant_, val=0) - elif name == "one": - initializer = partial(torch.nn.init.constant_, val=1) - else: - raise NotImplementedError( - f"get_initializer: initializer type {name} not supported" - ) - - return initializer - - -def Activation(activation=None, size=None, dim=-1): - if activation in [None, "id", "identity", "linear"]: - return nn.Identity() - elif activation == "tanh": - return nn.Tanh() - elif activation == "relu": - return nn.ReLU() - elif activation == "gelu": - return nn.GELU() - elif activation in ["swish", "silu"]: - return nn.SiLU() - elif activation == "glu": - return nn.GLU(dim=dim) - elif activation == "sigmoid": - return nn.Sigmoid() - elif activation == "softplus": - return nn.Softplus() - else: - raise NotImplementedError( - "hidden activation '{}' is not implemented".format(activation) - ) - - -class TransposedLinear(nn.Module): - """Linear module on the second-to-last dimension - Assumes shape (B, D, L), where L can be 1 or more axis - """ - - def __init__(self, d_input, d_output, bias=True): - super().__init__() - - self.weight = nn.Parameter(torch.empty(d_output, d_input)) - nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # nn.Linear default init - # nn.init.kaiming_uniform_(self.weight, nonlinearity='linear') # should be equivalent 
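# --- Editor's sketch (not part of the original patch): TransposedLinear applies an ordinary
# linear map over the channel axis of a channel-first (B, D, L) tensor via the contraction
# "b u ..., v u -> b v ...". A minimal check (bias omitted) that this matches nn.Linear
# applied after moving channels last; the sizes below are illustrative only.
import torch
from torch import nn

torch.manual_seed(0)
B, D_in, D_out, L = 2, 5, 7, 11
x = torch.randn(B, D_in, L)
lin = nn.Linear(D_in, D_out, bias=False)   # weight has shape (D_out, D_in), i.e. (v, u)

y_channel_first = torch.einsum("b u ..., v u -> b v ...", x, lin.weight)   # (B, D_out, L)
y_channel_last = lin(x.transpose(-1, -2)).transpose(-1, -2)                # (B, D_out, L)

assert torch.allclose(y_channel_first, y_channel_last, atol=1e-5)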
- - if bias: - self.bias = nn.Parameter(torch.empty(d_output)) - bound = 1 / math.sqrt(d_input) - nn.init.uniform_(self.bias, -bound, bound) - setattr(self.bias, "_optim", {"weight_decay": 0.0}) - else: - self.bias = 0.0 - - def forward(self, x): - num_axis = len(x.shape[2:]) # num_axis in L, for broadcasting bias - y = contract("b u ..., v u -> b v ...", x, self.weight) + self.bias.view( - -1, *[1] * num_axis - ) - return y - - -def LinearActivation( - d_input, - d_output, - bias=True, - zero_bias_init=False, - transposed=False, - initializer=None, - activation=None, - activate=False, # Apply activation as part of this module - weight_norm=False, - **kwargs, -): - """Returns a linear nn.Module with control over axes order, initialization, and activation""" - - # Construct core module - # linear_cls = partial(nn.Conv1d, kernel_size=1) if transposed else nn.Linear - linear_cls = TransposedLinear if transposed else nn.Linear - if activation == "glu": - d_output *= 2 - linear = linear_cls(d_input, d_output, bias=bias, **kwargs) - - # Initialize weight - if initializer is not None: - get_initializer(initializer, activation)(linear.weight) - - # Initialize bias - if bias and zero_bias_init: - nn.init.zeros_(linear.bias) - - # Weight norm - if weight_norm: - linear = nn.utils.weight_norm(linear) - - if activate and activation is not None: - activation = Activation(activation, d_output, dim=1 if transposed else -1) - linear = nn.Sequential(linear, activation) - return linear - - -class DropoutNd(nn.Module): - def __init__(self, p: float = 0.5, tie=True, transposed=True): - """ - tie: tie dropout mask across sequence lengths (Dropout1d/2d/3d) - """ - super().__init__() - if p < 0 or p >= 1: - raise ValueError( - "dropout probability has to be in [0, 1), " "but got {}".format(p) - ) - self.p = p - self.tie = tie - self.transposed = transposed - self.binomial = torch.distributions.binomial.Binomial(probs=1 - self.p) - - def forward(self, X): - """X: (batch, dim, lengths...)""" - if self.training: - if not self.transposed: - X = rearrange(X, "b d ... -> b ... d") - # binomial = torch.distributions.binomial.Binomial(probs=1-self.p) # This is incredibly slow - mask_shape = X.shape[:2] + (1,) * (X.ndim - 2) if self.tie else X.shape - # mask = self.binomial.sample(mask_shape) - mask = torch.rand(*mask_shape, device=X.device) < 1.0 - self.p - X = X * mask * (1.0 / (1 - self.p)) - if not self.transposed: - X = rearrange(X, "b ... 
d -> b d ...") - return X - return X - - -class Normalization(nn.Module): - def __init__( - self, - d, - transposed=False, # Length dimension is -1 or -2 - _name_="layer", - **kwargs, - ): - super().__init__() - self.transposed = transposed - self._name_ = _name_ - - if _name_ == "layer": - self.channel = True # Normalize over channel dimension - if self.transposed: - self.norm = TransposedLN(d, **kwargs) - else: - self.norm = nn.LayerNorm(d, **kwargs) - elif _name_ == "instance": - self.channel = False - norm_args = {"affine": False, "track_running_stats": False} - norm_args.update(kwargs) - self.norm = nn.InstanceNorm1d( - d, **norm_args - ) # (True, True) performs very poorly - elif _name_ == "batch": - self.channel = False - norm_args = {"affine": True, "track_running_stats": True} - norm_args.update(kwargs) - self.norm = nn.BatchNorm1d(d, **norm_args) - elif _name_ == "group": - self.channel = False - self.norm = nn.GroupNorm(1, d, *kwargs) - elif _name_ == "none": - self.channel = True - self.norm = nn.Identity() - else: - raise NotImplementedError - - def forward(self, x): - # Handle higher dimension logic - shape = x.shape - if self.transposed: - x = rearrange(x, "b d ... -> b d (...)") - else: - x = rearrange(x, "b ... d -> b (...)d ") - - # The cases of LayerNorm / no normalization are automatically handled in all cases - # Instance/Batch Norm work automatically with transposed axes - if self.channel or self.transposed: - x = self.norm(x) - else: - x = x.transpose(-1, -2) - x = self.norm(x) - x = x.transpose(-1, -2) - - x = x.view(shape) - return x - - def step(self, x, **kwargs): - assert self._name_ in ["layer", "none"] - if self.transposed: - x = x.unsqueeze(-1) - x = self.forward(x) - if self.transposed: - x = x.squeeze(-1) - return x - - -class TransposedLN(nn.Module): - """LayerNorm module over second dimension - Assumes shape (B, D, L), where L can be 1 or more axis - - This is slow and a dedicated CUDA/Triton implementation shuld provide substantial end-to-end speedup - """ - - def __init__(self, d, scalar=True): - super().__init__() - self.scalar = scalar - if self.scalar: - self.m = nn.Parameter(torch.zeros(1)) - self.s = nn.Parameter(torch.ones(1)) - setattr(self.m, "_optim", {"weight_decay": 0.0}) - setattr(self.s, "_optim", {"weight_decay": 0.0}) - else: - self.ln = nn.LayerNorm(d) - - def forward(self, x): - if self.scalar: - # calc. stats over D dim / channels - s, m = torch.std_mean(x, dim=1, unbiased=False, keepdim=True) - y = (self.s / s) * (x - m + self.m) - else: - # move channel to last axis, apply layer_norm, then move channel back to second axis - _x = self.ln(rearrange(x, "b d ... -> b ... d")) - y = rearrange(_x, "b ... d -> b d ...") - return y - - -def stochastic_depth(input: torch.tensor, p: float, mode: str, training: bool = True): - """ - Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth" - `_ used for randomly dropping residual - branches of residual architectures. - - Args: - input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one - being its batch i.e. a batch with ``N`` rows. - p (float): probability of the input to be zeroed. - mode (str): ``"batch"`` or ``"row"``. - ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes - randomly selected rows from the batch. - training: apply stochastic depth if is ``True``. Default: ``True`` - - Returns: - Tensor[N, ...]: The randomly zeroed tensor. 
- """ - if p < 0.0 or p > 1.0: - raise ValueError( - "drop probability has to be between 0 and 1, but got {}".format(p) - ) - if mode not in ["batch", "row"]: - raise ValueError( - "mode has to be either 'batch' or 'row', but got {}".format(mode) - ) - if not training or p == 0.0: - return input - - survival_rate = 1.0 - p - if mode == "row": - size = [input.shape[0]] + [1] * (input.ndim - 1) - else: - size = [1] * input.ndim - noise = torch.empty(size, dtype=input.dtype, device=input.device) - noise = noise.bernoulli_(survival_rate).div_(survival_rate) - return input * noise - - -class StochasticDepth(nn.Module): - """ - See :func:`stochastic_depth`. - """ - - def __init__(self, p: float, mode: str) -> None: - # TODO(karan): need to upgrade to torchvision==0.11.0 to use StochasticDepth directly - # from torchvision.ops import StochasticDepth - super().__init__() - self.p = p - self.mode = mode - - def forward(self, input): - return stochastic_depth(input, self.p, self.mode, self.training) - - def __repr__(self) -> str: - tmpstr = self.__class__.__name__ + "(" - tmpstr += "p=" + str(self.p) - tmpstr += ", mode=" + str(self.mode) - tmpstr += ")" - return tmpstr From c9c73e1ca9f6cabc2a45b44755aaadd8e6109664 Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Wed, 17 Dec 2025 21:47:06 -0500 Subject: [PATCH 10/21] remove unnecessary imports and variables --- requirements.txt | 2 +- setup.sh | 2 +- src/clm/models.py | 11 +++-------- 3 files changed, 5 insertions(+), 10 deletions(-) diff --git a/requirements.txt b/requirements.txt index 89725aae..24b428d7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -129,4 +129,4 @@ wrapt==1.16.0 yte==1.5.4 zipp==3.19.0 einops==0.6.0 -opt_einsum==3.3.0 \ No newline at end of file +opt_einsum==3.3.0 diff --git a/setup.sh b/setup.sh index 90063e3c..26870182 100644 --- a/setup.sh +++ b/setup.sh @@ -24,4 +24,4 @@ cd .. # Install CLM package pip install -e . --no-deps -echo "Environment setup complete!" \ No newline at end of file +echo "Environment setup complete!" 
diff --git a/src/clm/models.py b/src/clm/models.py index 865a0778..3a8e08d6 100644 --- a/src/clm/models.py +++ b/src/clm/models.py @@ -4,13 +4,11 @@ import torch.nn.functional as F from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence -from einops import rearrange - # from clm.src.models.sequence.h3 import H3 # from clm.src.models.sequence.h3_conv import H3Conv # from clm.src.models.sequence.hyena_components import HyenaOperator -from .module_library.sequence_model import SequenceModel +from s4dd.module_library.sequence_model import SequenceModel # class H3Model(nn.Module): @@ -631,9 +629,6 @@ def __init__( self.vocabulary_size, self.model_dim, padding_idx=padding_t ) - # Import SequenceModel from your module library - from .module_library.sequence_model import SequenceModel - self.model = SequenceModel( d_model=self.model_dim, n_layers=self.n_layers, @@ -701,8 +696,8 @@ def loss(self, batch): # Likely (seq_len, batch_size), transpose to (batch_size, seq_len) padded = padded.transpose(0, 1) - batch_size = padded.shape[0] - seq_len = padded.shape[1] + # batch_size = padded.shape[0] + # seq_len = padded.shape[1] # Don't use recurrent state during training - use full convolution mode self.recurrent_state = None From d484bc8bcf33e4715328dec8297c963670b2fc6e Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Wed, 17 Dec 2025 21:59:40 -0500 Subject: [PATCH 11/21] remove yaml files --- .../config-spectraverse-allv1-s4_cv.yaml | 196 ------------------ ...fig-spectraverse-allv1-transformer_cv.yaml | 196 ------------------ 2 files changed, 392 deletions(-) delete mode 100644 workflow/config/config-spectraverse-allv1-s4_cv.yaml delete mode 100644 workflow/config/config-spectraverse-allv1-transformer_cv.yaml diff --git a/workflow/config/config-spectraverse-allv1-s4_cv.yaml b/workflow/config/config-spectraverse-allv1-s4_cv.yaml deleted file mode 100644 index cfa201b5..00000000 --- a/workflow/config/config-spectraverse-allv1-s4_cv.yaml +++ /dev/null @@ -1,196 +0,0 @@ -# Molecular sequence representations of chemical species for training and sampling. -# Determines how the molecules are encoded internally. -# The only avaiable option for now is 'SMILES'. -representations: - - SMILES - -# The number of cross-validation folds. -# The dataset is split into train/test set for each fold, and models are trained/tested on each fold. -folds: 10 - -# Seeds used to initialize random number generators for the training runs. -# Each seed corresponds to a separate training run. -# Each fold trains 'train_seeds' number of models on the training set for that fold. -train_seeds: - - 0 - -# Seeds used when sampling molecules from the trained models. -# The number of 'sample_seeds' values specifies how many times the 'sample_molecules_RNN' step is executed, -# each time using the same trained model but with different random seed values. -sample_seeds: - - 0 - -# Specifies by how many times the input data is augmented (or enumerated) before training. -# Augmentation refers to the fact that a single molecule can have multiple SMILES representation. -# For example: -# - A value of 0 means no augmentation, leaving the input data unchanged. -# - A value of 100 means each molecule can have up to 100 different SMILES representations in the training set. -# Note: Both 0 and 1 indicate no augmentation, but with 1, the representations are updated to be different -# than those provided in the original dataset. -enum_factors: - - 0 - -# Limits the maximum number of input SMILES to be read from the original dataset. 
0 means there's no limit. -max_input_smiles: 0 - -# A dictionary defining the arguments to be passed to the preprocess command. -preprocess: - # Specifies the minimum number of heavy atoms that a valid molecule should have. - min_heavy_atoms: 3 - # Defines the set of elements required for a molecule to be considered valid. - # Any SMILES containing elements outside this set will be considered invalid and excluded from the training set. - valid_atoms: [Br, C, Cl, F, H, I, N, O, P, S, Se, Si, B, As] - # Specifies whether the charges in the training SMILES should be neutralized. - neutralise: false - # Specifies whether to remove training SMILES representing molecules with tokens found in less than 0.01% of samples or fewer than 10 molecules. - remove_rare: false - # Specifies whether to remove duplicate SMILES from the training set, identifying duplicates by inchikey. - keep_duplicates: false - -# Parameters that define the neural network model and training process. -model_params: - # Type of Recurrent Neural Network (RNN) to use. - # Available options are 'LSTM' and 'GRU' - rnn_type: S4 - embedding_size: 128 # Size of the embedding vectors that represent each token in the input sequence. - hidden_size: 256 # Size of the hidden state of the RNN. - n_layers: 2 # Number of stacked RNN layers in the model. - dropout: 0 # Dropout rate applied to the RNN layer for regularization. - batch_size: 64 # Number of samples processed before the models internal parameters are updated. - learning_rate: 0.001 # Used by the optimizer to update model parameters. - max_epochs: 999999 # Maximum number of training epochs (complete passes through the training dataset). - patience: 50000 # Number of steps with no improvement in the validation loss after which early stopping is triggered. - - # An RNN model conditioned on input descriptors (experimentally obtained properties of the input SMILES). - # Note that rnn_type and other RNN architecture parameters are still applicable in this case. - conditional: - # Is the conditional model enabled? - enabled: false - - # Note: Both emb and emb_l below cannot be true at the same time. - # Concatenate the descriptors directly to the token embeddings at each step in the sequence? - emb: false - # Concatenate the descriptors to the token embeddings, but by first passing them through a - # linear layer to obtain embeddings of dimensionality equal to that of the token embeddings? - emb_l: true - - # Note: Both dec and dec_l below cannot be true at the same time. - # Concatenate the descriptors directly to the output of the RNN layers - # (prior to the decoder layer)? - dec: false - # Concatenate the descriptors to the output of the RNN layers - # (prior to the decoder layer), but by first passing them through a - # linear layer to obtain embeddings of dimensionality equal to that of the token embeddings? - dec_l: true - - # Instantiate the hidden states based on learned transformations of the descriptors - # (with a single linear layer), as in Kotsias et al? - h: false - - # Frequency of logging training progress in terms of steps (batches). - log_every_steps: 100 - # Frequency of logging training progress in terms of epochs. - log_every_epochs: 1 - # Number of molecules to sample from the trained model after training. - sample_mols: 1000000 - -# When looking at sampled molecules across all folds, what metric(s) do we -# use for aggregating frequencies? -metrics: - # With what frequency (across all folds) was each valid molecule produced? 
- # - freq-sum - # With what average frequency (across all folds) was each valid molecule produced? - - freq-avg - # With what average frequency (across all folds) was each valid molecule produced, - # as a fraction of total sampling frequency (x 10e3 to avoid ~0 values) - # - fp10k - -# Minimum Tanimoto coefficient threshold to filter out molecules from training set. -# This allows for only similar SMILES to be considered from the preprocessed dataset -# for the creation of training/ testing folds, (with or without augmentation). -# 0 (default) means no filtering based on Tanimoto similarity. -min_tc: 0 - -# Number of top candidate molecules to consider when evaluating correctness. -# Here, correctness is defined as an exact mass match within a specified error range. -# Example: -# A value of 30 means that the 30 most frequently generated molecules with a mass -# matching the target molecule's mass within the allowed error margin are considered -# for further evaluation. -top_k: 30 - -# Error tolerance in parts per million for mass-matching to consider a molecule "correct". -# Used in rules that evaluate the correctness of generated or sampled molecules against -# known test molecules based on mass. -err_ppm: 10 - -# Specifies minimum frequency thresholds for inclusion. -# Each value represents the minimum number of times a molecule must be generated -# across all folds to be considered for further evaluation. -structural_prior_min_freq: - - 1 - -# seed used as a global random seed for steps not covered by 'train_seeds' or 'sample_seeds'. -random_seed: 5831 - -# A dictionary that defines various input and output file paths, incorporating wildcards. -paths: - # Modify these paths to match your system. - - # Base directory for outputs - output_dir: '/Genomics/argo/users/vg8892/git/CLM/workflow/data_spectraverse_allv1_s4_cv' - # The input dataset file. - dataset: "/Genomics/argo/users/vg8892/git/CLM/data/spectraverse_allv1.txt" - # File containing data from PubChem. - pubchem_tsv_file: "../tests/test_data/PubChem_truncated.tsv" - - # The following paths can be modified, as long as all wildcards are preserved in each case. - - # Processed dataset before augmentation/training. - preprocess_output: "{output_dir}/prior/raw/{dataset}.txt" - # Training file for each cross-validation fold. - train_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_{fold}.smi" - # Vocabulary file for the tokenized sequences. - vocab_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_{fold}.vocabulary" - # Trained RNN model checkpoint file. - model_file: "{output_dir}/{enum_factor}/prior/models/{dataset}_{repr}_{fold}_{train_seed}_model.pt" - # Sampled dataset for each fold. - input_file: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples.csv.gz" - # Unaugmented training dataset for each cross-validation fold. - train0_file: "{output_dir}/{enum_factor}/prior/inputs/train0_{dataset}_{repr}_{fold}.smi" - # Unaugmented test dataset for each cross-validation fold. - test0_file: "{output_dir}/{enum_factor}/prior/inputs/test0_{dataset}_{repr}_{fold}.smi" - # A file generated by add_carbon rule, inserting carbon symbols at random spots in training SMILES. - # carbon_file: "{output_dir}/{enum_factor}/prior/inputs/train0_{dataset}_{repr}_{fold}_carbon.csv.gz" - # Complete training dataset aggregated across all folds. 
- train_all_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_all.smi" - # Complete testing dataset aggregated across all folds. - test_all_file: "{output_dir}/{enum_factor}/prior/inputs/test_{dataset}_{repr}_all.smi" - # Complete aggregated SMILES from add_carbon rule across all folds. - # carbon_all_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_carbon_all.csv.gz" - # Top-n candidate SMILES based on matching by exact mass for each cross-validation fold. - cv_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_ranks_structure.csv.gz" - # Top-n candidate SMILES based on matching mass including Tanimoto coefficient for each cross-validation fold. - cv_tc_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_tc.csv.gz" - # Top-n candidate SMILES (correctness based on formula rather than structure) for each cross-validation fold. - formula_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_ranks_formula.csv.gz" - # Sampled SMILES aggregated across all folds. - process_tabulated_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_processed_min{min_freq}_{metric}.csv.gz" - # Loss curves from model training. - loss_file: "{output_dir}/{enum_factor}/prior/models/{dataset}_{repr}_{fold}_{train_seed}_loss.csv.gz" - # Novel SMILES generated by each trained model, along with inchikey, mass and formula. - tabulate_molecules_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" - # Aggregated sampled SMILES from all the trained models in a fold. - collect_tabulated_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_unique_masses.csv.gz" - # Top-n candidate SMILES based on matching mass across all folds. - overall_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_min{min_freq}_all_{metric}_CV_ranks_structure.csv.gz" - # Top-n candidate SMILES based on matching mass including tc per fold. - overall_tc_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_min{min_freq}_all_{metric}_CV_tc.csv.gz" - # Sampled molecules per trained model that appear in training set. - known_smiles_file: "{output_dir}/{enum_factor}/prior/samples/known_{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" - # Invalid SMILES sampled per trained model. - invalid_smiles_file: "{output_dir}/{enum_factor}/prior/samples/invalid_{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" - # Known (training set) sampled molecules within a fold. - collect_known_smiles: "{output_dir}/{enum_factor}/prior/samples/known_{dataset}_{repr}_{fold}_unique_masses.csv.gz" - # Invalid sampled SMILES within a fold. - collect_invalid_smiles: "{output_dir}/{enum_factor}/prior/samples/invalid_{dataset}_{repr}_{fold}_unique_masses.csv.gz" diff --git a/workflow/config/config-spectraverse-allv1-transformer_cv.yaml b/workflow/config/config-spectraverse-allv1-transformer_cv.yaml deleted file mode 100644 index d96c6c9d..00000000 --- a/workflow/config/config-spectraverse-allv1-transformer_cv.yaml +++ /dev/null @@ -1,196 +0,0 @@ -# Molecular sequence representations of chemical species for training and sampling. -# Determines how the molecules are encoded internally. -# The only avaiable option for now is 'SMILES'. -representations: - - SMILES - -# The number of cross-validation folds. 
-# The dataset is split into train/test set for each fold, and models are trained/tested on each fold. -folds: 10 - -# Seeds used to initialize random number generators for the training runs. -# Each seed corresponds to a separate training run. -# Each fold trains 'train_seeds' number of models on the training set for that fold. -train_seeds: - - 0 - -# Seeds used when sampling molecules from the trained models. -# The number of 'sample_seeds' values specifies how many times the 'sample_molecules_RNN' step is executed, -# each time using the same trained model but with different random seed values. -sample_seeds: - - 0 - -# Specifies by how many times the input data is augmented (or enumerated) before training. -# Augmentation refers to the fact that a single molecule can have multiple SMILES representation. -# For example: -# - A value of 0 means no augmentation, leaving the input data unchanged. -# - A value of 100 means each molecule can have up to 100 different SMILES representations in the training set. -# Note: Both 0 and 1 indicate no augmentation, but with 1, the representations are updated to be different -# than those provided in the original dataset. -enum_factors: - - 0 - -# Limits the maximum number of input SMILES to be read from the original dataset. 0 means there's no limit. -max_input_smiles: 0 - -# A dictionary defining the arguments to be passed to the preprocess command. -preprocess: - # Specifies the minimum number of heavy atoms that a valid molecule should have. - min_heavy_atoms: 3 - # Defines the set of elements required for a molecule to be considered valid. - # Any SMILES containing elements outside this set will be considered invalid and excluded from the training set. - valid_atoms: [Br, C, Cl, F, H, I, N, O, P, S, Se, Si, B, As] - # Specifies whether the charges in the training SMILES should be neutralized. - neutralise: false - # Specifies whether to remove training SMILES representing molecules with tokens found in less than 0.01% of samples or fewer than 10 molecules. - remove_rare: false - # Specifies whether to remove duplicate SMILES from the training set, identifying duplicates by inchikey. - keep_duplicates: false - -# Parameters that define the neural network model and training process. -model_params: - # Type of Recurrent Neural Network (RNN) to use. - # Available options are 'LSTM' and 'GRU' - rnn_type: Transformer - embedding_size: 128 # Size of the embedding vectors that represent each token in the input sequence. - hidden_size: 256 # Size of the hidden state of the RNN. - n_layers: 2 # Number of stacked RNN layers in the model. - dropout: 0 # Dropout rate applied to the RNN layer for regularization. - batch_size: 64 # Number of samples processed before the models internal parameters are updated. - learning_rate: 0.001 # Used by the optimizer to update model parameters. - max_epochs: 999999 # Maximum number of training epochs (complete passes through the training dataset). - patience: 50000 # Number of steps with no improvement in the validation loss after which early stopping is triggered. - - # An RNN model conditioned on input descriptors (experimentally obtained properties of the input SMILES). - # Note that rnn_type and other RNN architecture parameters are still applicable in this case. - conditional: - # Is the conditional model enabled? - enabled: false - - # Note: Both emb and emb_l below cannot be true at the same time. - # Concatenate the descriptors directly to the token embeddings at each step in the sequence? 
- emb: false - # Concatenate the descriptors to the token embeddings, but by first passing them through a - # linear layer to obtain embeddings of dimensionality equal to that of the token embeddings? - emb_l: true - - # Note: Both dec and dec_l below cannot be true at the same time. - # Concatenate the descriptors directly to the output of the RNN layers - # (prior to the decoder layer)? - dec: false - # Concatenate the descriptors to the output of the RNN layers - # (prior to the decoder layer), but by first passing them through a - # linear layer to obtain embeddings of dimensionality equal to that of the token embeddings? - dec_l: true - - # Instantiate the hidden states based on learned transformations of the descriptors - # (with a single linear layer), as in Kotsias et al? - h: false - - # Frequency of logging training progress in terms of steps (batches). - log_every_steps: 100 - # Frequency of logging training progress in terms of epochs. - log_every_epochs: 1 - # Number of molecules to sample from the trained model after training. - sample_mols: 1000000 - -# When looking at sampled molecules across all folds, what metric(s) do we -# use for aggregating frequencies? -metrics: - # With what frequency (across all folds) was each valid molecule produced? - # - freq-sum - # With what average frequency (across all folds) was each valid molecule produced? - - freq-avg - # With what average frequency (across all folds) was each valid molecule produced, - # as a fraction of total sampling frequency (x 10e3 to avoid ~0 values) - # - fp10k - -# Minimum Tanimoto coefficient threshold to filter out molecules from training set. -# This allows for only similar SMILES to be considered from the preprocessed dataset -# for the creation of training/ testing folds, (with or without augmentation). -# 0 (default) means no filtering based on Tanimoto similarity. -min_tc: 0 - -# Number of top candidate molecules to consider when evaluating correctness. -# Here, correctness is defined as an exact mass match within a specified error range. -# Example: -# A value of 30 means that the 30 most frequently generated molecules with a mass -# matching the target molecule's mass within the allowed error margin are considered -# for further evaluation. -top_k: 30 - -# Error tolerance in parts per million for mass-matching to consider a molecule "correct". -# Used in rules that evaluate the correctness of generated or sampled molecules against -# known test molecules based on mass. -err_ppm: 10 - -# Specifies minimum frequency thresholds for inclusion. -# Each value represents the minimum number of times a molecule must be generated -# across all folds to be considered for further evaluation. -structural_prior_min_freq: - - 1 - -# seed used as a global random seed for steps not covered by 'train_seeds' or 'sample_seeds'. -random_seed: 5831 - -# A dictionary that defines various input and output file paths, incorporating wildcards. -paths: - # Modify these paths to match your system. - - # Base directory for outputs - output_dir: '/Genomics/argo/users/vg8892/git/CLM/workflow/data_spectraverse_allv1_transformer_cv' - # The input dataset file. - dataset: "/Genomics/argo/users/vg8892/git/CLM/data/spectraverse_allv1.txt" - # File containing data from PubChem. - pubchem_tsv_file: "../tests/test_data/PubChem_truncated.tsv" - - # The following paths can be modified, as long as all wildcards are preserved in each case. - - # Processed dataset before augmentation/training. 
- preprocess_output: "{output_dir}/prior/raw/{dataset}.txt" - # Training file for each cross-validation fold. - train_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_{fold}.smi" - # Vocabulary file for the tokenized sequences. - vocab_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_{fold}.vocabulary" - # Trained RNN model checkpoint file. - model_file: "{output_dir}/{enum_factor}/prior/models/{dataset}_{repr}_{fold}_{train_seed}_model.pt" - # Sampled dataset for each fold. - input_file: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples.csv.gz" - # Unaugmented training dataset for each cross-validation fold. - train0_file: "{output_dir}/{enum_factor}/prior/inputs/train0_{dataset}_{repr}_{fold}.smi" - # Unaugmented test dataset for each cross-validation fold. - test0_file: "{output_dir}/{enum_factor}/prior/inputs/test0_{dataset}_{repr}_{fold}.smi" - # A file generated by add_carbon rule, inserting carbon symbols at random spots in training SMILES. - # carbon_file: "{output_dir}/{enum_factor}/prior/inputs/train0_{dataset}_{repr}_{fold}_carbon.csv.gz" - # Complete training dataset aggregated across all folds. - train_all_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_all.smi" - # Complete testing dataset aggregated across all folds. - test_all_file: "{output_dir}/{enum_factor}/prior/inputs/test_{dataset}_{repr}_all.smi" - # Complete aggregated SMILES from add_carbon rule across all folds. - # carbon_all_file: "{output_dir}/{enum_factor}/prior/inputs/train_{dataset}_{repr}_carbon_all.csv.gz" - # Top-n candidate SMILES based on matching by exact mass for each cross-validation fold. - cv_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_ranks_structure.csv.gz" - # Top-n candidate SMILES based on matching mass including Tanimoto coefficient for each cross-validation fold. - cv_tc_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_tc.csv.gz" - # Top-n candidate SMILES (correctness based on formula rather than structure) for each cross-validation fold. - formula_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_{fold}_CV_ranks_formula.csv.gz" - # Sampled SMILES aggregated across all folds. - process_tabulated_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_processed_min{min_freq}_{metric}.csv.gz" - # Loss curves from model training. - loss_file: "{output_dir}/{enum_factor}/prior/models/{dataset}_{repr}_{fold}_{train_seed}_loss.csv.gz" - # Novel SMILES generated by each trained model, along with inchikey, mass and formula. - tabulate_molecules_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" - # Aggregated sampled SMILES from all the trained models in a fold. - collect_tabulated_output: "{output_dir}/{enum_factor}/prior/samples/{dataset}_{repr}_{fold}_unique_masses.csv.gz" - # Top-n candidate SMILES based on matching mass across all folds. - overall_ranks_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_min{min_freq}_all_{metric}_CV_ranks_structure.csv.gz" - # Top-n candidate SMILES based on matching mass including tc per fold. - overall_tc_file: "{output_dir}/{enum_factor}/prior/structural_prior/{dataset}_{repr}_min{min_freq}_all_{metric}_CV_tc.csv.gz" - # Sampled molecules per trained model that appear in training set. 
- known_smiles_file: "{output_dir}/{enum_factor}/prior/samples/known_{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" - # Invalid SMILES sampled per trained model. - invalid_smiles_file: "{output_dir}/{enum_factor}/prior/samples/invalid_{dataset}_{repr}_{fold}_{train_seed}_{sample_seed}_samples_masses.csv.gz" - # Known (training set) sampled molecules within a fold. - collect_known_smiles: "{output_dir}/{enum_factor}/prior/samples/known_{dataset}_{repr}_{fold}_unique_masses.csv.gz" - # Invalid sampled SMILES within a fold. - collect_invalid_smiles: "{output_dir}/{enum_factor}/prior/samples/invalid_{dataset}_{repr}_{fold}_unique_masses.csv.gz" From d87e8ce0a1fdfa1fb12bed9ad1f855d204b7ea2b Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Thu, 1 Jan 2026 23:00:40 -0500 Subject: [PATCH 12/21] temp fix until s4 pr --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 24b428d7..e5502ffb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -130,3 +130,4 @@ yte==1.5.4 zipp==3.19.0 einops==0.6.0 opt_einsum==3.3.0 +s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging From d61b97f03191a9380b6a6f6636715244f5086274 Mon Sep 17 00:00:00 2001 From: "Michael A. Skinnider" Date: Tue, 6 Jan 2026 14:20:31 -0500 Subject: [PATCH 13/21] reformat with black --- src/clm/loggers.py | 8 ++++++-- src/clm/models.py | 16 ++++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/src/clm/loggers.py b/src/clm/loggers.py index cc0349bc..1b89de80 100644 --- a/src/clm/loggers.py +++ b/src/clm/loggers.py @@ -42,12 +42,16 @@ def __call__(self, val_loss, model, output_file, step_idx): # Check for NaN/Inf if math.isnan(val_loss) or math.isinf(val_loss): self.nan_counter += 1 - print(f"NaN/Inf loss detected at step {step_idx} ({self.nan_counter}/3)") + print( + f"NaN/Inf loss detected at step {step_idx} ({self.nan_counter}/3)" + ) if self.nan_counter >= 3: self.stop = True print("Stopping training after 3 consecutive NaN/Inf losses.") if self.best_loss is not None: - print(f"Best model (loss={self.best_loss:.4f}) already saved.") + print( + f"Best model (loss={self.best_loss:.4f}) already saved." 
+ ) return # do nothing if early stopping is disabled diff --git a/src/clm/models.py b/src/clm/models.py index c4a1ab3d..53c59636 100644 --- a/src/clm/models.py +++ b/src/clm/models.py @@ -592,7 +592,9 @@ def __init__( super(StructuredStateSpaceSequenceModel, self).__init__() # detect device - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) # vocabulary self.vocabulary = vocabulary @@ -667,7 +669,9 @@ def forward(self, x): def reset_state(self, batch_size, device=None): if device is None: device = self.device - self.recurrent_state = self.model.default_state(batch_size, device=device) + self.recurrent_state = self.model.default_state( + batch_size, device=device + ) def recurrent_step(self, x_t): if x_t.dim() == 1: @@ -748,7 +752,9 @@ def sample( pad_token = self.vocabulary.dictionary[""] # Create start token tensor - inputs = torch.empty(n_sequences).fill_(start_token).long().to(self.device) + inputs = ( + torch.empty(n_sequences).fill_(start_token).long().to(self.device) + ) # Setup loss function loss_fn = nn.NLLLoss(reduction="none", ignore_index=pad_token) @@ -803,7 +809,9 @@ def sample( ) if return_smiles: - outputs = [self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs] + outputs = [ + self.vocabulary.decode(seq.cpu().numpy()) for seq in seqs + ] else: outputs = sequences From ac37db320b2166275dfc86c3ff8d32ac13ba1c3c Mon Sep 17 00:00:00 2001 From: "Michael A. Skinnider" Date: Tue, 6 Jan 2026 15:22:58 -0500 Subject: [PATCH 14/21] remove setup.sh --- setup.sh | 27 --------------------------- 1 file changed, 27 deletions(-) delete mode 100644 setup.sh diff --git a/setup.sh b/setup.sh deleted file mode 100644 index 26870182..00000000 --- a/setup.sh +++ /dev/null @@ -1,27 +0,0 @@ -#!/bin/bash -# setup.sh - -set -e # Exit on any error - -# Initialize conda for bash -eval "$(conda shell.bash hook)" - -# Create and activate environment -conda create --name clm python=3.10 pip -y -conda activate clm - -# Install main requirements -conda env update --file environment.yml - -# Install s4dd from source -if [ ! -d "s4-for-de-novo-drug-design" ]; then - git clone https://github.com/molML/s4-for-de-novo-drug-design.git -fi -cd s4-for-de-novo-drug-design -pip install -e . -cd .. - -# Install CLM package -pip install -e . --no-deps - -echo "Environment setup complete!" From 2621505c5923a1fd77c4083fedac3d5041862cd4 Mon Sep 17 00:00:00 2001 From: skinnider Date: Wed, 7 Jan 2026 07:52:22 -0500 Subject: [PATCH 15/21] add a test for S4 --- .gitignore | 3 +++ tests/test_snakemake_steps.py | 27 +++++++++++++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/.gitignore b/.gitignore index 04e1d608..f5236153 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,6 @@ __pycache__ # version file (managed by setuptools-scm) src/clm/_version.py + +# Snakemake logs +.snakemake diff --git a/tests/test_snakemake_steps.py b/tests/test_snakemake_steps.py index 2f7869ac..e1bc6ef2 100644 --- a/tests/test_snakemake_steps.py +++ b/tests/test_snakemake_steps.py @@ -242,6 +242,33 @@ def test_02_train_models_conditional_RNN(tmp_path): # so we simply ensure that this step runs without errors. 
+def test_02_train_models_S4(tmp_path): + train_models_RNN.train_models_RNN( + representation="SMILES", + rnn_type="S4", + embedding_size=32, + hidden_size=256, + n_layers=3, + dropout=0, + batch_size=64, + learning_rate=0.001, + max_epochs=3, + patience=5000, + log_every_steps=100, + log_every_epochs=1, + sample_mols=100, + input_file=test_dir + / "0/prior/inputs/train_LOTUS_truncated_SMILES_0.smi", + vocab_file=test_dir + / "0/prior/inputs/train_LOTUS_truncated_SMILES_0.vocabulary", + model_file=tmp_path / "LOTUS_truncated_SMILES_0_0_model.pt", + loss_file=tmp_path / "LOTUS_truncated_SMILES_0_0_loss.csv", + smiles_file=None, + ) + # Model loss values can vary between platforms and architectures, + # so we simply ensure that this step runs without errors. + + def test_03_sample_molecules_RNN(tmp_path): output_file = ( tmp_path / "0/prior/samples/LOTUS_truncated_SMILES_0_0_0_samples.csv" From 41092852c43cb8454f50004f9a1f612c79aa8634 Mon Sep 17 00:00:00 2001 From: Michael Skinnider Date: Sat, 17 Jan 2026 18:55:28 -0500 Subject: [PATCH 16/21] try explicitly specifying s4dd --- environment.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/environment.yml b/environment.yml index ab59c2ea..3530ec42 100644 --- a/environment.yml +++ b/environment.yml @@ -13,3 +13,4 @@ dependencies: - pip - pip: - -r requirements.txt + - s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging From a7c8c58ab89ddafc3e06d7529028deefe6eb4a23 Mon Sep 17 00:00:00 2001 From: Michael Skinnider Date: Sat, 17 Jan 2026 19:49:50 -0500 Subject: [PATCH 17/21] revert changes to environment.yml, change pyproject.toml instead --- environment.yml | 1 - pyproject.toml | 3 +++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index 3530ec42..ab59c2ea 100644 --- a/environment.yml +++ b/environment.yml @@ -13,4 +13,3 @@ dependencies: - pip - pip: - -r requirements.txt - - s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging diff --git a/pyproject.toml b/pyproject.toml index 63eff6e5..86aa9acf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,11 +18,14 @@ classifiers = [ dependencies = [ "deepsmiles", + "einops", "fcd_torch", "numpy", + "opt_einsum", "pandas", "pulp<2.8.0", "rdkit", + "s4dd @ git+https://github.com/GuptaVishu2002/s4-for-de-novo-drug-design.git@fix-module-library-packaging", "scikit-learn", "scipy==1.11.1", "selfies", From dfc87841ec9d8a4cb0cdb9cb9912301379f79107 Mon Sep 17 00:00:00 2001 From: Michael Skinnider Date: Sun, 18 Jan 2026 19:00:43 -0500 Subject: [PATCH 18/21] fix loss calculation for S4 and Transformer --- src/clm/models.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/clm/models.py b/src/clm/models.py index 53c59636..0a1019a5 100644 --- a/src/clm/models.py +++ b/src/clm/models.py @@ -693,12 +693,11 @@ def loss(self, batch): padded = padded.to(self.device) # Handle different input formats - # RNN format is typically (seq_len, batch_size) - # S4/Transformer format is typically (batch_size, seq_len) + # RNN collate returns (seq_len, batch_size) format + # S4 model expects (batch_size, seq_len) format + # Always transpose since collate always uses (seq_len, batch_size) if padded.dim() == 2: - if padded.shape[0] > padded.shape[1]: - # Likely (seq_len, batch_size), transpose to (batch_size, seq_len) - padded = padded.transpose(0, 1) + padded = padded.transpose(0, 1) # batch_size = padded.shape[0] # seq_len = 
padded.shape[1] @@ -1263,6 +1262,11 @@ def loss(self, batch): padded = padded.to(self.device) + # RNN collate returns (seq_len, batch_size) format + # Transformer expects (batch_size, seq_len) format + if padded.dim() == 2: + padded = padded.transpose(0, 1) + # Get actual sequence length from batch actual_seq_len = padded.shape[1] From 9dc77bb786b2e1cf621f7d14e3c624282853e7b6 Mon Sep 17 00:00:00 2001 From: Michael Skinnider Date: Tue, 20 Jan 2026 06:50:19 -0500 Subject: [PATCH 19/21] add a little more runtime to sample from S4 --- workflow/Snakefile_data | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/workflow/Snakefile_data b/workflow/Snakefile_data index ecad1886..fd4c2126 100644 --- a/workflow/Snakefile_data +++ b/workflow/Snakefile_data @@ -206,7 +206,7 @@ rule sample_molecules_RNN: output_file = PATHS['input_file'] resources: mem_mb=12000, - runtime=15+MODEL_PARAMS["sample_mols"]//10000, + runtime=120+MODEL_PARAMS["sample_mols"]//10000, slurm_extra="--gres=gpu:1" shell: 'clm sample_molecules_RNN ' From 9f62a290c0a16007db4f05934d9742ae1e8306e0 Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Wed, 21 Jan 2026 14:24:46 -0500 Subject: [PATCH 20/21] Remove the reset_state() call from Transformer.sample() as Transformer does not maintain recurrent state. Also add torch.cuda.empty_cache() and torch.no_grad() for sampling to GPU memory management --- src/clm/models.py | 58 ++++++++++++++++++++++++++--------------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/clm/models.py b/src/clm/models.py index 0a1019a5..7271c013 100644 --- a/src/clm/models.py +++ b/src/clm/models.py @@ -1292,7 +1292,10 @@ def sample( descriptors=None, ): # Reset recurrent state before sampling - self.reset_state(n_sequences, device=self.device) + # self.reset_state(n_sequences, device=self.device) + + self.eval() + torch.cuda.empty_cache() # get start/stop tokens start_token = self.vocabulary.dictionary["SOS"] @@ -1315,32 +1318,33 @@ def sample( finished = torch.zeros(n_sequences).byte().to(self.device) log_probs = torch.zeros(n_sequences).to(self.device) sequences = [] - for step in range(self.max_len): - logits = self(inputs)[:, -1, :] - # Clamp logits to prevent inf/nan - logits = torch.clamp(logits, min=-1e4, max=1e4) - prob = F.softmax(logits, dim=-1) - - # Check for invalid values and skip if found - if torch.isnan(prob).any() or torch.isinf(prob).any(): - break + with torch.no_grad(): + for step in range(self.max_len): + logits = self(inputs)[:, -1, :] + # Clamp logits to prevent inf/nan + logits = torch.clamp(logits, min=-1e4, max=1e4) + prob = F.softmax(logits, dim=-1) - outputs = torch.multinomial(prob, num_samples=1) - # append to growing sequence - inputs = torch.cat((inputs, outputs), dim=1) - sequences.append(outputs) - # calculate NLL too - log_prob = F.log_softmax(logits, dim=1) - losses = loss(log_prob, outputs.squeeze(1)) - # zero losses if we are finished sampling - losses[finished.bool()] = 0 - log_probs += losses - # track whether sampling is done for all molecules - finished = torch.ge( - finished + (outputs.squeeze(1) == stop_token), 1 - ) - if torch.prod(finished) == 1: - break + # Check for invalid values and skip if found + if torch.isnan(prob).any() or torch.isinf(prob).any(): + break + + outputs = torch.multinomial(prob, num_samples=1) + # append to growing sequence + inputs = torch.cat((inputs, outputs), dim=1) + sequences.append(outputs) + # calculate NLL too + log_prob = F.log_softmax(logits, dim=1) + losses = loss(log_prob, outputs.squeeze(1)) + 
# zero losses if we are finished sampling + losses[finished.bool()] = 0 + log_probs += losses + # track whether sampling is done for all molecules + finished = torch.ge( + finished + (outputs.squeeze(1) == stop_token), 1 + ) + if torch.prod(finished) == 1: + break # concatenate sequences and decode seqs = ( @@ -1357,6 +1361,8 @@ def sample( else: outputs = sequences + torch.cuda.empty_cache() + # optionally return losses if return_losses: return outputs, log_probs.detach().cpu().numpy() From a83b67b8d053e288fae0374eab04098727a277be Mon Sep 17 00:00:00 2001 From: Vishu Gupta Date: Wed, 21 Jan 2026 14:35:03 -0500 Subject: [PATCH 21/21] Trim Trailing Whitespace --- src/clm/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/clm/models.py b/src/clm/models.py index 7271c013..ad365a8b 100644 --- a/src/clm/models.py +++ b/src/clm/models.py @@ -1361,7 +1361,7 @@ def sample( else: outputs = sequences - torch.cuda.empty_cache() + torch.cuda.empty_cache() # optionally return losses if return_losses:
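
Note on the NaN-based early stopping introduced in this series: the src/clm/loggers.py hunks above show only the black-reformatted NaN/Inf branch of the callback. Below is a minimal, illustrative sketch of the behaviour those hunks imply. The three-strike NaN/Inf check and the __call__(val_loss, model, output_file, step_idx) signature follow the diff; the class name, constructor, counter reset, and patience branch are assumptions added for readability and are not taken from the patches.

# Sketch only: reconstructs the NaN/Inf-aware early stopping from the hunks above.
# Names and the patience branch are assumed; only the NaN/Inf branch mirrors the diff.
import math


class EarlyStopping:
    def __init__(self, patience=50000):
        self.patience = patience
        self.best_loss = None
        self.counter = 0
        self.nan_counter = 0
        self.stop = False

    def __call__(self, val_loss, model, output_file, step_idx):
        # Abort after three consecutive NaN/Inf validation losses,
        # keeping whatever best checkpoint was already written.
        if math.isnan(val_loss) or math.isinf(val_loss):
            self.nan_counter += 1
            print(f"NaN/Inf loss detected at step {step_idx} ({self.nan_counter}/3)")
            if self.nan_counter >= 3:
                self.stop = True
                print("Stopping training after 3 consecutive NaN/Inf losses.")
                if self.best_loss is not None:
                    print(f"Best model (loss={self.best_loss:.4f}) already saved.")
            return

        self.nan_counter = 0  # reset on any finite loss (assumption, not shown in the hunk)

        # Conventional patience-based early stopping (assumed, not shown in the hunk).
        if self.best_loss is None or val_loss < self.best_loss:
            self.best_loss = val_loss
            self.counter = 0
            # torch.save(model.state_dict(), output_file)  # checkpoint write assumed here
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.stop = True

The training loop would presumably check the stop flag after each validation step, so that three consecutive non-finite losses end the run while the last finite-loss checkpoint remains on disk.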