From 4a0eef7e0751256b452d17b138f0a5154459d592 Mon Sep 17 00:00:00 2001
From: Raffi Khatchadourian
Date: Fri, 2 Aug 2024 10:51:39 -0400
Subject: [PATCH] Add transformations.

---
 custom/callback.py | 1 +
 custom/layers.py   | 2 ++
 model.py           | 5 +++++
 3 files changed, 8 insertions(+)

diff --git a/custom/callback.py b/custom/callback.py
index 5866fa9..1293778 100644
--- a/custom/callback.py
+++ b/custom/callback.py
@@ -21,6 +21,7 @@ def __init__(self, from_logits=False, reduction='none', debug=False, **kwargs):
         self.debug = debug
         pass
 
+    @tf.function
     def call(self, y_true, y_pred):
         y_true = tf.cast(y_true, tf.int32)
         mask = tf.math.logical_not(tf.math.equal(y_true, par.pad_token))
diff --git a/custom/layers.py b/custom/layers.py
index 1506671..3aa9343 100644
--- a/custom/layers.py
+++ b/custom/layers.py
@@ -239,6 +239,7 @@ def _get_left_embedding(self, len_q, len_k):
         return e
 
     @staticmethod
+    @tf.function
     def _qe_masking(qe):
         mask = tf.sequence_mask(
             tf.range(qe.shape[-1] -1, qe.shape[-1] - qe.shape[-2] -1, -1), qe.shape[-1])
@@ -248,6 +249,7 @@ def _qe_masking(qe):
 
         return mask * qe
 
+    @tf.function
     def _skewing(self, tensor: tf.Tensor):
         padded = tf.pad(tensor, [[0, 0], [0,0], [0, 0], [1, 0]])
         reshaped = tf.reshape(padded, shape=[-1, padded.shape[1], padded.shape[-1], padded.shape[-2]])
diff --git a/model.py b/model.py
index 26e3c01..8d7bf7b 100644
--- a/model.py
+++ b/model.py
@@ -8,6 +8,7 @@
 import random
 import utils
 from progress.bar import Bar
+from tensorflow import function
 
 tf.executing_eagerly()
 
@@ -38,6 +39,7 @@ def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
         if loader_path is not None:
             self.load_ckpt_file(loader_path)
 
+    @function
     def call(self, inputs, targets, training=None, eval=None, src_mask=None, trg_mask=None, lookup_mask=None):
         encoder, weight_encoder = self.Encoder(inputs, training=training, mask=src_mask)
         decoder, weights = self.Decoder(
@@ -81,6 +83,7 @@ def train_on_batch(self, x, y=None, sample_weight=None, class_weight=None, reset
         return [loss.numpy()]+result_metric
 
     # @tf.function
+    @function
     def __dist_train_step(self, inp, inp_tar, out_tar, enc_mask, tar_mask, lookup_mask, training):
         return self._distribution_strategy.experimental_run_v2(
             self.__train_step, args=(inp, inp_tar, out_tar, enc_mask, tar_mask, lookup_mask, training))
@@ -283,6 +286,7 @@ def __init__(self, embedding_dim=256, vocab_size=388+2, num_layer=6,
         if loader_path is not None:
             self.load_ckpt_file(loader_path)
 
+    @function
     def call(self, inputs, training=None, eval=None, lookup_mask=None):
         decoder, w = self.Decoder(inputs, training=training, mask=lookup_mask)
         fc = self.fc(decoder)
@@ -322,6 +326,7 @@ def train_on_batch(self, x, y=None, sample_weight=None, class_weight=None, reset
         return [loss.numpy()]+result_metric
 
     # @tf.function
+    @function
     def __dist_train_step(self, inp_tar, out_tar, lookup_mask, training):
         return self._distribution_strategy.experimental_run_v2(
             self.__train_step, args=(inp_tar, out_tar, lookup_mask, training))
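
Note (not part of the patch): all eight insertions apply the same transformation, wrapping a Python-level method in tf.function so that TensorFlow traces it into a graph on its first call and reuses the compiled graph on subsequent calls with the same input signature. Since `from tensorflow import function` binds the same object as tf.function, the bare @function decorator in model.py is interchangeable with the @tf.function spelling used in custom/callback.py and custom/layers.py. Below is a minimal, self-contained sketch of the pattern as applied to the masked loss in custom/callback.py; MaskedLoss and the pad_token argument are illustrative stand-ins for the repository's loss class and par.pad_token, not names taken from the patch:

    import tensorflow as tf

    class MaskedLoss(tf.keras.losses.Loss):
        """Illustrative stand-in for the patched loss in custom/callback.py."""

        def __init__(self, pad_token=0, **kwargs):
            super().__init__(reduction=tf.keras.losses.Reduction.NONE, **kwargs)
            self.pad_token = pad_token

        @tf.function  # traced into a graph on the first call, as in the patch
        def call(self, y_true, y_pred):
            y_true = tf.cast(y_true, tf.int32)
            # Zero out the loss at padded positions, mirroring the
            # par.pad_token mask in the patched code.
            mask = tf.math.logical_not(tf.math.equal(y_true, self.pad_token))
            loss = tf.keras.losses.sparse_categorical_crossentropy(
                y_true, y_pred, from_logits=True)
            return loss * tf.cast(mask, loss.dtype)

    loss_fn = MaskedLoss(pad_token=0)
    y_true = tf.constant([[1, 2, 0]])      # last position is padding
    y_pred = tf.random.normal([1, 3, 5])   # logits over a 5-token vocabulary
    print(loss_fn.call(y_true, y_pred))    # first call traces; later calls reuse the graph

One caveat worth flagging: _qe_masking and _skewing read tensor.shape, which is resolved at trace time, so decorating them with @tf.function assumes statically known shapes; if any dimension were dynamic, tf.shape(tensor) would be needed instead.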