129 changes: 77 additions & 52 deletions scaaml/models/gpam.py
@@ -34,19 +34,23 @@
import networkx as nx
except ImportError:
nx = None # type: ignore[assignment]
import tensorflow as tf
import keras
from tensorflow.keras import layers
from tensorflow import Tensor
import numpy as np
from keras import layers
from keras.src.backend import KerasTensor


@keras.saving.register_keras_serializable()
class Rescale(layers.Layer): # type: ignore[type-arg]
class Rescale(layers.Layer): # type: ignore[misc,no-any-unimported]
"""Rescale input to the interval [-1, 1].
"""

def __init__(self, trace_min: float, trace_delta: float,
**kwargs: Any) -> None:
def __init__(
self,
trace_min: float,
trace_delta: float,
**kwargs: Any,
) -> None:
"""Information for trace rescaling.

Args:
@@ -59,14 +63,18 @@ def __init__(self, trace_min: float, trace_delta: float,
self.trace_min: float = trace_min
self.trace_delta: float = trace_delta

def call(self, inputs: Tensor, **kwargs: Any) -> Tensor:
def call( # type: ignore[no-any-unimported]
self,
inputs: KerasTensor,
**kwargs: Any,
) -> KerasTensor:
"""Rescale to the interval [-1, 1]."""
del kwargs # unused
x = inputs
x = 2 * ((x - self.trace_min) / self.trace_delta) - 1
return x
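
For reference, a quick arithmetic check of the mapping above (illustrative values, not part of the diff):

    trace_min, trace_delta = -3.0, 6.0  # covers traces in [-3.0, 3.0]
    assert 2 * ((trace_min - trace_min) / trace_delta) - 1 == -1.0  # minimum maps to -1
    assert 2 * ((0.0 - trace_min) / trace_delta) - 1 == 0.0         # midpoint maps to 0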

def get_config(self) -> dict[str, Any]:
def get_config(self) -> Any:
"""Return the config to allow saving and loading of the model.
"""
config = super().get_config()
@@ -78,7 +86,7 @@ def get_config(self) -> dict[str, Any]:


@keras.saving.register_keras_serializable()
class ScaledNorm(layers.Layer): # type: ignore[type-arg]
class ScaledNorm(layers.Layer): # type: ignore[misc,no-any-unimported]
"""ScaledNorm layer.

Transformers without Tears: Improving the Normalization of Self-Attention
@@ -104,20 +112,23 @@ def __init__(self,
self._scale = self.add_weight(
name="norm_scale",
shape=(),
initializer=tf.constant_initializer(value=1.0),
initializer=keras.initializers.Constant(value=1.0),
trainable=True,
)

def call(self, inputs: Tensor) -> Tensor:
def call( # type: ignore[no-any-unimported]
self,
inputs: KerasTensor,
) -> KerasTensor:
"""Return the output of this layer.
"""
x = inputs
axes = list(range(len(x.shape)))[self._begin_axis:]
mean_square = tf.reduce_mean(tf.math.square(x), axes, keepdims=True)
x = x * tf.math.rsqrt(mean_square + self._epsilon)
mean_square = keras.ops.mean(keras.ops.square(x), axes, keepdims=True)
x = x * keras.ops.rsqrt(mean_square + self._epsilon)
return x * self._scale
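
In NumPy terms the layer computes y = scale * x / sqrt(mean(x**2) + eps) over the normalized axes; a minimal sketch (not part of the diff), assuming normalization over the last axis only:

    import numpy as np
    x = np.random.randn(2, 100, 8)
    eps, scale = 1e-5, 1.0
    y = scale * x / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + eps)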

def get_config(self) -> dict[str, Any]:
def get_config(self) -> Any:
"""Return the config to allow saving and loading of the model.
"""
config = super().get_config()
@@ -128,19 +139,20 @@ def get_config(self) -> dict[str, Any]:
return config


def clone_initializer(initializer: tf.keras.initializers.Initializer) -> Any:
def clone_initializer( # type: ignore[no-any-unimported]
initializer: keras.initializers.Initializer,) -> Any:
"""Clone an initializer (if an initializer is reused the generated
weights are the same).
"""
if isinstance(initializer, tf.keras.initializers.Initializer):
if isinstance(initializer, keras.initializers.Initializer):
return initializer.__class__.from_config(initializer.get_config())
return initializer # type: ignore[unreachable]
return initializer
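
A minimal usage sketch (illustrative, not part of the diff): the clone is a distinct instance with an identical config, so layers built from one template initializer no longer share a single object.

    init = keras.initializers.RandomNormal(stddev=0.02)
    fresh = clone_initializer(init)
    assert fresh is not init
    assert fresh.get_config() == init.get_config()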


def rope(
x: Tensor,
def rope( # type: ignore[no-any-unimported]
x: KerasTensor,
axis: Union[list[int], int],
) -> Tensor:
) -> KerasTensor:
"""RoPE positional encoding.

Implementation of the Rotary Position Embedding proposed in
@@ -153,7 +165,10 @@ def rope(
Returns:
The input tensor with RoPE encodings.
"""
shape = x.shape.as_list()
# TensorFlow and JAX treat the shape differently. For the case of
# TensorFlow we need a list otherwise there is a problem in the
# toeplitz_matrix_rope.
shape = list(x.shape)

if isinstance(axis, int):
axis = [axis]
@@ -162,44 +177,46 @@
spatial_shape = [shape[i] for i in axis]
total_len = 1
for i in spatial_shape:
total_len *= i # type: ignore[operator]
position = tf.reshape(
tf.cast(tf.range(total_len, delta=1.0), tf.float32), spatial_shape)
total_len *= i
position = keras.ops.reshape(
keras.ops.cast(keras.ops.arange(total_len), np.float32),
spatial_shape)
else:
raise ValueError(f"Unsupported shape: {shape}")

# we assume that the axis can not be negative (e.g., -1)
if any(dim < 0 for dim in axis):
raise ValueError(f"Unsupported axis: {axis}")
for i in range(axis[-1] + 1, len(shape) - 1, 1):
position = tf.expand_dims(position, axis=-1)
position = keras.ops.expand_dims(position, axis=-1)

half_size = shape[-1] // 2 # type: ignore[operator]
freq_seq = tf.cast(tf.range(half_size), tf.float32) / float(half_size)
half_size = shape[-1] // 2
freq_seq = keras.ops.cast(keras.ops.arange(half_size),
np.float32) / float(half_size)
inv_freq = 10000**-freq_seq
sinusoid = tf.einsum("...,d->...d", position, inv_freq)
sin = tf.cast(tf.sin(sinusoid), dtype=x.dtype)
cos = tf.cast(tf.cos(sinusoid), dtype=x.dtype)
x1, x2 = tf.split(x, 2, axis=-1)
return tf.concat( # type: ignore[no-any-return]
sinusoid = keras.ops.einsum("...,d->...d", position, inv_freq)
sin = keras.ops.cast(keras.ops.sin(sinusoid), dtype=x.dtype)
cos = keras.ops.cast(keras.ops.cos(sinusoid), dtype=x.dtype)
x1, x2 = keras.ops.split(x, 2, axis=-1)
return keras.ops.concatenate(
[x1 * cos - x2 * sin, x2 * cos + x1 * sin],
axis=-1,
)
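
A shape-level sketch of calling rope (illustrative values, not part of the diff; the last dimension must be even since the function splits it in half):

    x = keras.ops.ones((128, 64))  # (positions, features)
    y = rope(x, axis=[0])          # same shape, rotary-encoded along axis 0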


def toeplitz_matrix_rope(
def toeplitz_matrix_rope( # type: ignore[no-any-unimported]
n: int,
a: Tensor,
b: Tensor,
) -> Tensor:
a: KerasTensor,
b: KerasTensor,
) -> KerasTensor:
"""Obtain Toeplitz matrix using rope."""
a = rope(tf.tile(a[None, :], [n, 1]), axis=[0])
b = rope(tf.tile(b[None, :], [n, 1]), axis=[0])
return tf.einsum("mk,nk->mn", a, b) # type: ignore[no-any-return]
a = rope(keras.ops.tile(a[None, :], [n, 1]), axis=[0])
b = rope(keras.ops.tile(b[None, :], [n, 1]), axis=[0])
return keras.ops.einsum("mk,nk->mn", a, b)
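
And a shape sketch for the resulting relative-position bias (illustrative sizes, not part of the diff; the length of a and b must be even for rope):

    a = keras.ops.ones((32,))
    b = keras.ops.ones((32,))
    bias = toeplitz_matrix_rope(128, a, b)  # shape (128, 128)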


@keras.saving.register_keras_serializable()
class GAU(layers.Layer): # type: ignore[type-arg]
class GAU(layers.Layer): # type: ignore[misc,no-any-unimported]
"""Gated Attention Unit layer introduced in Transformer
Quality in Linear Time.

@@ -281,7 +298,7 @@ def __init__(
self.spatial_dropout_rate)

# attention activation function
self.attention_activation_layer = tf.keras.layers.Activation(
self.attention_activation_layer = keras.layers.Activation(
self.attention_activation)

def build(self, input_shape: tuple[int, ...]) -> None:
@@ -334,15 +351,21 @@ def call(self, x: Any, training: bool = False) -> Any:
uv = self.proj1(x)
uv = self.dropout2(uv, training=training)

u, v, base = tf.split(
uv, [self.expand_dim, self.expand_dim, self.shared_dim], axis=-1)
u, v, base = keras.ops.split(
uv,
[self.expand_dim, self.expand_dim + self.expand_dim],
axis=-1,
)
Comment on lines +354 to +358
Contributor

Severity: medium

While using keras.ops.split is functionally correct, it's less readable than the original tf.split because it uses split indices rather than chunk sizes. For better clarity and maintainability, consider using tensor slicing, which is more explicit about the intended chunk sizes and avoids potential confusion with the different split function semantics between TensorFlow and Keras/NumPy.

        u = uv[..., :self.expand_dim]
        v = uv[..., self.expand_dim:self.expand_dim + self.expand_dim]
        base = uv[..., self.expand_dim + self.expand_dim:]

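For reference, the semantic difference the comment describes, sketched with NumPy (whose convention keras.ops.split follows; not part of the diff):

    import numpy as np
    uv = np.zeros((2, 7, 10))
    # NumPy/Keras take split *indices* 4 and 8 ...
    u, v, base = np.split(uv, [4, 8], axis=-1)
    # ... equivalent to TensorFlow chunk *sizes*: tf.split(uv, [4, 4, 2], axis=-1)
    assert u.shape[-1] == 4 and v.shape[-1] == 4 and base.shape[-1] == 2
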
assert u.shape[-1] == self.expand_dim
assert v.shape[-1] == self.expand_dim
assert base.shape[-1] == self.shared_dim

# generate q, k by scaled offset
base = tf.einsum("bnr,hr->bnhr", base, self.gamma) + self.beta
q, k = tf.unstack(base, axis=-2)
base = keras.ops.einsum("bnr,hr->bnhr", base, self.gamma) + self.beta
q, k = keras.ops.unstack(base, axis=-2)

# compute key-query scores
qk = tf.einsum("bnd,bmd->bnm", q, k)
qk = keras.ops.einsum("bnd,bmd->bnm", q, k)
qk = qk / self.max_len

# add relative position bias for attention
@@ -355,12 +378,14 @@ def call(self, x: Any, training: bool = False) -> Any:
kernel = self.attention_dropout(kernel)

# apply values and project
x = u * tf.einsum("bnm,bme->bne", kernel, v)
x = u * keras.ops.einsum("bnm,bme->bne", kernel, v)

x = self.proj2(x)
return x + shortcut
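
A condensed view of the attention path in call above (shapes only, as comments; b = batch, n = sequence length, e = expand_dim, s = shared_dim):

    # uv = proj1(x)                   -> (b, n, 2*e + s)
    # u, v, base = split(uv)          -> (b, n, e), (b, n, e), (b, n, s)
    # q, k = scaled offsets of base   -> (b, n, s) each
    # qk = einsum("bnd,bmd->bnm", q, k) / max_len, plus the rope-based Toeplitz bias
    # out = proj2(u * einsum("bnm,bme->bne", kernel, v)) + shortcut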

def get_config(self) -> dict[str, Any]:
def get_config(self) -> Any:
"""Returns the model config as a dictionary.
"""
config = super().get_config()
config.update({
"dim": self.dim,
@@ -377,11 +402,11 @@ def get_config(self) -> dict[str, Any]:

@property
def weight_initializer(self) -> Any:
return clone_initializer(tf.random_normal_initializer(stddev=0.02))
return clone_initializer(keras.initializers.RandomNormal(stddev=0.02))

@property
def zeros_initializer(self) -> Any:
return clone_initializer(tf.initializers.zeros())
return clone_initializer(keras.initializers.Zeros())


@keras.saving.register_keras_serializable()
@@ -434,7 +459,7 @@ def _make_head( # type: ignore[no-any-unimported]

Args:

x (Tensor): Stem of the neural network.
x (KerasTensor): Stem of the neural network.

heads (dict[str, keras.layers.Layer]): A dictionary of previous heads
(those that are sooner in the topologically sorted outputs).
@@ -545,7 +570,7 @@ def get_topological_order(


def create_heads_outputs( # type: ignore[no-any-unimported]
x: Tensor,
x: KerasTensor,
outputs: dict[str, dict[str, int]],
output_relations: list[tuple[str, str]],
) -> dict[str, keras.layers.Layer]:
6 changes: 4 additions & 2 deletions tests/models/test_gpam.py
@@ -71,17 +71,19 @@ def test_train_save_load(tmp_path):

model.save(save_path)

min_accuracy: float = 0.17

score = model.evaluate(x_test, y_test)
print("[orig] Test loss:", score[0])
print("[orig] Test accuracy:", score[1])
assert score[1] > 0.2
assert score[1] > min_accuracy

loaded_model = keras.models.load_model(save_path)
loaded_model.summary()
score = loaded_model.evaluate(x_test, y_test)
print("[loaded] Test loss:", score[0])
print("[loaded] Test accuracy:", score[1])
assert score[1] > 0.2
assert score[1] > min_accuracy

# Make sure the loaded model is the same layer by layer.
def match(i, x):