diff --git a/keras/src/backend/jax/__init__.py b/keras/src/backend/jax/__init__.py index 89ac0fa71c8c..afae28a7614f 100644 --- a/keras/src/backend/jax/__init__.py +++ b/keras/src/backend/jax/__init__.py @@ -25,6 +25,8 @@ from keras.src.backend.jax.core import shape from keras.src.backend.jax.core import stop_gradient from keras.src.backend.jax.core import vectorized_map +from keras.src.backend.jax.nn import adaptive_avg_pool +from keras.src.backend.jax.nn import adaptive_max_pool from keras.src.backend.jax.rnn import cudnn_ok from keras.src.backend.jax.rnn import gru from keras.src.backend.jax.rnn import lstm diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py index 15cc90f73747..7597e4650ada 100644 --- a/keras/src/backend/jax/nn.py +++ b/keras/src/backend/jax/nn.py @@ -1464,3 +1464,368 @@ def _pair(x): # ---- reshape -> (N, C*kH*kW, L) ---- _, CKK, oH, oW = patches.shape return patches.reshape(N, CKK, oH * oW) + + +def get_static_window_sizes(input_dim, output_dim): + """Calculate small and big window sizes for adaptive pooling.""" + small_window = math.ceil(input_dim / output_dim) + big_window = small_window + 1 + return small_window, big_window + + +def compute_static_gather_indices(input_dim, output_size, big_window): + """Compute gather indices for Two-Pool Gather method.""" + window_starts = jnp.floor( + (jnp.arange(output_size) * input_dim) / output_size + ).astype(jnp.int32) + + window_ends = jnp.ceil( + (jnp.arange(1, output_size + 1) * input_dim) / output_size + ).astype(jnp.int32) + + window_sizes = window_ends - window_starts + is_big_window = window_sizes == big_window + + small_window = big_window - 1 + small_pool_len = input_dim - small_window + 1 + + small_indices = window_starts + big_indices = window_starts + small_pool_len + + gather_indices = jnp.where(is_big_window, big_indices, small_indices) + return gather_indices.astype(jnp.int32) + + +# ---------- 1D Adaptive Pooling ---------- +def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"): + """Adaptive Average Pooling 1D using Two-Pool Gather method.""" + if isinstance(output_size, int): + output_size = (output_size,) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC + + n, l, c = inputs.shape + out_l = output_size[0] + + small_l, big_l = get_static_window_sizes(l, out_l) + gather_l = compute_static_gather_indices(l, out_l, big_l) + + small_pool_l = lax.reduce_window( + inputs, 0.0, lax.add, (1, small_l, 1), (1, 1, 1), "valid" + ) + small_pool_l = small_pool_l / small_l + + big_pool_l = lax.reduce_window( + inputs, 0.0, lax.add, (1, big_l, 1), (1, 1, 1), "valid" + ) + big_pool_l = big_pool_l / big_l + + combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1) + pooled_l = jnp.take(combined_l, gather_l, axis=1) + + if data_format == "channels_first": + pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL + + return pooled_l + + +def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"): + """Adaptive Max Pooling 1D using Two-Pool Gather method.""" + if isinstance(output_size, int): + output_size = (output_size,) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 1)) # NCL -> NLC + + n, l, c = inputs.shape + out_l = output_size[0] + + small_l, big_l = get_static_window_sizes(l, out_l) + gather_l = compute_static_gather_indices(l, out_l, big_l) + + small_pool_l = lax.reduce_window( + inputs, -jnp.inf, lax.max, (1, small_l, 1), (1, 1, 1), "valid" + ) + big_pool_l = lax.reduce_window( + inputs, -jnp.inf, lax.max, (1, big_l, 1), (1, 1, 1), "valid" + ) + + combined_l = jnp.concatenate([small_pool_l, big_pool_l], axis=1) + pooled_l = jnp.take(combined_l, gather_l, axis=1) + + if data_format == "channels_first": + pooled_l = jnp.transpose(pooled_l, (0, 2, 1)) # NLC -> NCL + + return pooled_l + + +# ---------- 2D Adaptive Pooling ---------- +def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"): + """Adaptive Average Pooling 2D using Two-Pool Gather method.""" + if isinstance(output_size, int): + output_size = (output_size, output_size) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC + + n, h, w, c = inputs.shape + out_h, out_w = output_size + + small_h, big_h = get_static_window_sizes(h, out_h) + gather_h = compute_static_gather_indices(h, out_h, big_h) + + small_w, big_w = get_static_window_sizes(w, out_w) + gather_w = compute_static_gather_indices(w, out_w, big_w) + + small_pool_h = lax.reduce_window( + inputs, 0.0, lax.add, (1, small_h, 1, 1), (1, 1, 1, 1), "valid" + ) + small_pool_h = small_pool_h / small_h + + big_pool_h = lax.reduce_window( + inputs, 0.0, lax.add, (1, big_h, 1, 1), (1, 1, 1, 1), "valid" + ) + big_pool_h = big_pool_h / big_h + + combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1) + pooled_h = jnp.take(combined_h, gather_h, axis=1) + + small_pool_w = lax.reduce_window( + pooled_h, 0.0, lax.add, (1, 1, small_w, 1), (1, 1, 1, 1), "valid" + ) + small_pool_w = small_pool_w / small_w + + big_pool_w = lax.reduce_window( + pooled_h, 0.0, lax.add, (1, 1, big_w, 1), (1, 1, 1, 1), "valid" + ) + big_pool_w = big_pool_w / big_w + + combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2) + pooled_w = jnp.take(combined_w, gather_w, axis=2) + + if data_format == "channels_first": + pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW + + return pooled_w + + +def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"): + """Adaptive Max Pooling 2D using Two-Pool Gather method.""" + if isinstance(output_size, int): + output_size = (output_size, output_size) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 3, 1)) # NCHW -> NHWC + + n, h, w, c = inputs.shape + out_h, out_w = output_size + + small_h, big_h = get_static_window_sizes(h, out_h) + gather_h = compute_static_gather_indices(h, out_h, big_h) + + small_w, big_w = get_static_window_sizes(w, out_w) + gather_w = compute_static_gather_indices(w, out_w, big_w) + + small_pool_h = lax.reduce_window( + inputs, -jnp.inf, lax.max, (1, small_h, 1, 1), (1, 1, 1, 1), "valid" + ) + big_pool_h = lax.reduce_window( + inputs, -jnp.inf, lax.max, (1, big_h, 1, 1), (1, 1, 1, 1), "valid" + ) + + combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=1) + pooled_h = jnp.take(combined_h, gather_h, axis=1) + + small_pool_w = lax.reduce_window( + pooled_h, -jnp.inf, lax.max, (1, 1, small_w, 1), (1, 1, 1, 1), "valid" + ) + big_pool_w = lax.reduce_window( + pooled_h, -jnp.inf, lax.max, (1, 1, big_w, 1), (1, 1, 1, 1), "valid" + ) + + combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=2) + pooled_w = jnp.take(combined_w, gather_w, axis=2) + + if data_format == "channels_first": + pooled_w = jnp.transpose(pooled_w, (0, 3, 1, 2)) # NHWC -> NCHW + + return pooled_w + + +# ---------- 3D Adaptive Pooling ---------- +def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"): + """Adaptive Average Pooling 3D using Two-Pool Gather method.""" + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC + + n, d, h, w, c = inputs.shape + out_d, out_h, out_w = output_size + + small_d, big_d = get_static_window_sizes(d, out_d) + gather_d = compute_static_gather_indices(d, out_d, big_d) + + small_h, big_h = get_static_window_sizes(h, out_h) + gather_h = compute_static_gather_indices(h, out_h, big_h) + + small_w, big_w = get_static_window_sizes(w, out_w) + gather_w = compute_static_gather_indices(w, out_w, big_w) + + small_pool_d = lax.reduce_window( + inputs, 0.0, lax.add, (1, small_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid" + ) + small_pool_d = small_pool_d / small_d + + big_pool_d = lax.reduce_window( + inputs, 0.0, lax.add, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid" + ) + big_pool_d = big_pool_d / big_d + + combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1) + pooled_d = jnp.take(combined_d, gather_d, axis=1) + + small_pool_h = lax.reduce_window( + pooled_d, 0.0, lax.add, (1, 1, small_h, 1, 1), (1, 1, 1, 1, 1), "valid" + ) + small_pool_h = small_pool_h / small_h + + big_pool_h = lax.reduce_window( + pooled_d, 0.0, lax.add, (1, 1, big_h, 1, 1), (1, 1, 1, 1, 1), "valid" + ) + big_pool_h = big_pool_h / big_h + + combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2) + pooled_h = jnp.take(combined_h, gather_h, axis=2) + + small_pool_w = lax.reduce_window( + pooled_h, 0.0, lax.add, (1, 1, 1, small_w, 1), (1, 1, 1, 1, 1), "valid" + ) + small_pool_w = small_pool_w / small_w + + big_pool_w = lax.reduce_window( + pooled_h, 0.0, lax.add, (1, 1, 1, big_w, 1), (1, 1, 1, 1, 1), "valid" + ) + big_pool_w = big_pool_w / big_w + + combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3) + pooled_w = jnp.take(combined_w, gather_w, axis=3) + + if data_format == "channels_first": + pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW + + return pooled_w + + +def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"): + """Adaptive Max Pooling 3D using Two-Pool Gather method.""" + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + + if data_format == "channels_first": + inputs = jnp.transpose(inputs, (0, 2, 3, 4, 1)) # NCDHW -> NDHWC + + n, d, h, w, c = inputs.shape + out_d, out_h, out_w = output_size + + small_d, big_d = get_static_window_sizes(d, out_d) + gather_d = compute_static_gather_indices(d, out_d, big_d) + + small_h, big_h = get_static_window_sizes(h, out_h) + gather_h = compute_static_gather_indices(h, out_h, big_h) + + small_w, big_w = get_static_window_sizes(w, out_w) + gather_w = compute_static_gather_indices(w, out_w, big_w) + + small_pool_d = lax.reduce_window( + inputs, + -jnp.inf, + lax.max, + (1, small_d, 1, 1, 1), + (1, 1, 1, 1, 1), + "valid", + ) + big_pool_d = lax.reduce_window( + inputs, -jnp.inf, lax.max, (1, big_d, 1, 1, 1), (1, 1, 1, 1, 1), "valid" + ) + + combined_d = jnp.concatenate([small_pool_d, big_pool_d], axis=1) + pooled_d = jnp.take(combined_d, gather_d, axis=1) + + small_pool_h = lax.reduce_window( + pooled_d, + -jnp.inf, + lax.max, + (1, 1, small_h, 1, 1), + (1, 1, 1, 1, 1), + "valid", + ) + big_pool_h = lax.reduce_window( + pooled_d, + -jnp.inf, + lax.max, + (1, 1, big_h, 1, 1), + (1, 1, 1, 1, 1), + "valid", + ) + + combined_h = jnp.concatenate([small_pool_h, big_pool_h], axis=2) + pooled_h = jnp.take(combined_h, gather_h, axis=2) + + small_pool_w = lax.reduce_window( + pooled_h, + -jnp.inf, + lax.max, + (1, 1, 1, small_w, 1), + (1, 1, 1, 1, 1), + "valid", + ) + big_pool_w = lax.reduce_window( + pooled_h, + -jnp.inf, + lax.max, + (1, 1, 1, big_w, 1), + (1, 1, 1, 1, 1), + "valid", + ) + + combined_w = jnp.concatenate([small_pool_w, big_pool_w], axis=3) + pooled_w = jnp.take(combined_w, gather_w, axis=3) + + if data_format == "channels_first": + pooled_w = jnp.transpose(pooled_w, (0, 4, 1, 2, 3)) # NDHWC -> NCDHW + + return pooled_w + + +# ---------- Dispatcher ---------- +def adaptive_avg_pool(inputs, output_size, data_format="channels_first"): + """Dispatcher for adaptive average pooling (1D, 2D, or 3D).""" + ndims = inputs.ndim - 2 + if ndims == 1: + return adaptive_avg_pool1d(inputs, output_size, data_format) + elif ndims == 2: + return adaptive_avg_pool2d(inputs, output_size, data_format) + elif ndims == 3: + return adaptive_avg_pool3d(inputs, output_size, data_format) + else: + raise ValueError( + "adaptive_avg_pool supports 1D, 2D, or 3D inputs only." + ) + + +def adaptive_max_pool(inputs, output_size, data_format="channels_first"): + """Dispatcher for adaptive max pooling (1D, 2D, or 3D).""" + ndims = inputs.ndim - 2 + if ndims == 1: + return adaptive_max_pool1d(inputs, output_size, data_format) + elif ndims == 2: + return adaptive_max_pool2d(inputs, output_size, data_format) + elif ndims == 3: + return adaptive_max_pool3d(inputs, output_size, data_format) + else: + raise ValueError( + "adaptive_max_pool supports 1D, 2D, or 3D inputs only." + ) diff --git a/keras/src/backend/numpy/nn.py b/keras/src/backend/numpy/nn.py index 44f3fb882e12..a5f3e762da4e 100644 --- a/keras/src/backend/numpy/nn.py +++ b/keras/src/backend/numpy/nn.py @@ -1237,3 +1237,19 @@ def _pair(x): # ---- reshape -> (N, C*kH*kW, L) ---- return patches.reshape(N, C * k[0] * k[1], -1) + + +def adaptive_max_pool(inputs, output_size, data_format=None): + """Adaptive max pooling - Numpy backend not yet supported.""" + raise NotImplementedError( + "Adaptive pooling not implemented for Numpy. " + "Use JAX, Torch or Tensorflow backend." + ) + + +def adaptive_avg_pool(inputs, output_size, data_format=None): + """Adaptive average pooling - Numpy backend not yet supported.""" + raise NotImplementedError( + "Adaptive pooling not implemented for Numpy. " + "Use JAX, Torch or Tensorflow backend." + ) diff --git a/keras/src/backend/openvino/nn.py b/keras/src/backend/openvino/nn.py index 2c025825ed82..88b8b746a875 100644 --- a/keras/src/backend/openvino/nn.py +++ b/keras/src/backend/openvino/nn.py @@ -133,6 +133,14 @@ def max_pool( ) +def adaptive_max_pool(inputs, output_size, data_format=None): + """Adaptive max pooling - OpenVINO backend not yet supported.""" + raise NotImplementedError( + "Adaptive pooling not implemented for OpenVINO. " + "Use JAX or Torch backend." + ) + + def average_pool( inputs, pool_size, @@ -145,6 +153,14 @@ def average_pool( ) +def adaptive_avg_pool(inputs, output_size, data_format=None): + """Adaptive average pooling - OpenVINO backend not yet supported.""" + raise NotImplementedError( + "Adaptive pooling not implemented for OpenVINO. " + "Use JAX or Torch backend." + ) + + def _adjust_strides_dilation( x, num_spatial_dims, diff --git a/keras/src/backend/tensorflow/nn.py b/keras/src/backend/tensorflow/nn.py index 8a89e6a6b590..9310719af152 100644 --- a/keras/src/backend/tensorflow/nn.py +++ b/keras/src/backend/tensorflow/nn.py @@ -240,6 +240,281 @@ def max_pool( return outputs +def get_static_window_sizes(input_dim, output_dim): + """Calculate small and big window sizes for adaptive pooling.""" + if input_dim < output_dim: + small_window = 1 + else: + small_window = max(1, math.ceil(input_dim / output_dim)) + + big_window = small_window + 1 + + # Ensure windows don't exceed input dimension + small_window = min(small_window, input_dim) + big_window = min(big_window, input_dim) + + return small_window, big_window + + +def compute_static_gather_indices( + input_dim, output_size, small_window, big_window +): + """Compute gather indices for Two-Pool Gather method (corrected).""" + window_starts = tf.cast( + tf.floor( + tf.cast(tf.range(output_size), tf.float32) + * tf.cast(input_dim, tf.float32) + / tf.cast(output_size, tf.float32) + ), + tf.int32, + ) + window_ends = tf.cast( + tf.math.ceil( + tf.cast(tf.range(1, output_size + 1), tf.float32) + * tf.cast(input_dim, tf.float32) + / tf.cast(output_size, tf.float32) + ), + tf.int32, + ) + + window_ends = tf.minimum(window_ends, input_dim) + window_starts = tf.minimum(window_starts, input_dim - 1) + + window_sizes = window_ends - window_starts + is_big_window = tf.equal(window_sizes, big_window) + + small_pool_len = max(1, input_dim - small_window + 1) + + small_indices = window_starts + big_indices = window_starts + small_pool_len + + gather_indices = tf.where(is_big_window, big_indices, small_indices) + return tf.cast(gather_indices, tf.int32) + + +def adaptive_max_pool1d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size,) + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 1)) + + static_shape = inputs.shape.as_list() + l_static = static_shape[1] + out_l = output_size[0] + + if l_static is None: + raise ValueError( + "Input length must be statically known for adaptive pooling" + ) + + small_l, big_l = get_static_window_sizes(l_static, out_l) + gather_l = compute_static_gather_indices(l_static, out_l, small_l, big_l) + + small_pool_l = tf.nn.pool( + inputs, + window_shape=(small_l,), + pooling_type="MAX", + strides=(1,), + padding="VALID", + data_format="NWC", + ) + big_pool_l = tf.nn.pool( + inputs, + window_shape=(big_l,), + pooling_type="MAX", + strides=(1,), + padding="VALID", + data_format="NWC", + ) + + combined_l = tf.concat([small_pool_l, big_pool_l], axis=1) + pooled_l = tf.gather(combined_l, gather_l, axis=1) + + if data_format == "channels_first": + pooled_l = tf.transpose(pooled_l, (0, 2, 1)) + return pooled_l + + +def adaptive_max_pool2d(inputs, output_size, data_format="channels_first"): + """Adaptive Max Pooling 2D using Two-Pool Gather method.""" + if isinstance(output_size, int): + output_size = (output_size, output_size) + + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 3, 1)) + + static_shape = inputs.shape.as_list() + h_static = static_shape[1] + w_static = static_shape[2] + out_h, out_w = output_size + + if h_static is None or w_static is None: + raise ValueError( + "Input spatial dimensions must be " + "statically known for adaptive pooling" + ) + + small_h, big_h = get_static_window_sizes(h_static, out_h) + small_w, big_w = get_static_window_sizes(w_static, out_w) + + gather_h = compute_static_gather_indices(h_static, out_h, small_h, big_h) + gather_w = compute_static_gather_indices(w_static, out_w, small_w, big_w) + + small_pool_h = tf.nn.pool( + inputs, + window_shape=(small_h, 1), + pooling_type="MAX", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + big_pool_h = tf.nn.pool( + inputs, + window_shape=(big_h, 1), + pooling_type="MAX", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + + combined_h = tf.concat([small_pool_h, big_pool_h], axis=1) + pooled_h = tf.gather(combined_h, gather_h, axis=1) + + small_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, small_w), + pooling_type="MAX", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + big_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, big_w), + pooling_type="MAX", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + + combined_w = tf.concat([small_pool_w, big_pool_w], axis=2) + pooled_w = tf.gather(combined_w, gather_w, axis=2) + + if data_format == "channels_first": + pooled_w = tf.transpose(pooled_w, (0, 3, 1, 2)) + + return pooled_w + + +def adaptive_max_pool3d(inputs, output_size, data_format="channels_first"): + """Adaptive Max Pooling 3D using Two-Pool Gather method.""" + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 3, 4, 1)) + + static_shape = inputs.shape.as_list() + d_static = static_shape[1] + h_static = static_shape[2] + w_static = static_shape[3] + out_d, out_h, out_w = output_size + + if d_static is None or h_static is None or w_static is None: + raise ValueError( + "Input spatial dimensions must be " + "statically known for adaptive pooling" + ) + + small_d, big_d = get_static_window_sizes(d_static, out_d) + small_h, big_h = get_static_window_sizes(h_static, out_h) + small_w, big_w = get_static_window_sizes(w_static, out_w) + + gather_d = compute_static_gather_indices(d_static, out_d, small_d, big_d) + gather_h = compute_static_gather_indices(h_static, out_h, small_h, big_h) + gather_w = compute_static_gather_indices(w_static, out_w, small_w, big_w) + + small_pool_d = tf.nn.pool( + inputs, + window_shape=(small_d, 1, 1), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_d = tf.nn.pool( + inputs, + window_shape=(big_d, 1, 1), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_d = tf.concat([small_pool_d, big_pool_d], axis=1) + pooled_d = tf.gather(combined_d, gather_d, axis=1) + + small_pool_h = tf.nn.pool( + pooled_d, + window_shape=(1, small_h, 1), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_h = tf.nn.pool( + pooled_d, + window_shape=(1, big_h, 1), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_h = tf.concat([small_pool_h, big_pool_h], axis=2) + pooled_h = tf.gather(combined_h, gather_h, axis=2) + + small_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, 1, small_w), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, 1, big_w), + pooling_type="MAX", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_w = tf.concat([small_pool_w, big_pool_w], axis=3) + pooled_w = tf.gather(combined_w, gather_w, axis=3) + + if data_format == "channels_first": + pooled_w = tf.transpose(pooled_w, (0, 4, 1, 2, 3)) + + return pooled_w + + +def adaptive_max_pool(inputs, output_size, data_format="channels_first"): + """Dispatcher for adaptive max pooling (1D, 2D, or 3D).""" + ndims = len(inputs.shape) - 2 + if ndims == 1: + return adaptive_max_pool1d(inputs, output_size, data_format) + elif ndims == 2: + return adaptive_max_pool2d(inputs, output_size, data_format) + elif ndims == 3: + return adaptive_max_pool3d(inputs, output_size, data_format) + else: + raise ValueError( + "adaptive_max_pool supports 1D, 2D, or 3D inputs only." + ) + + def average_pool( inputs, pool_size, @@ -268,6 +543,226 @@ def average_pool( return outputs +def adaptive_avg_pool1d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size,) + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 1)) + + static_shape = inputs.shape.as_list() + l_static = static_shape[1] + out_l = output_size[0] + + if l_static is None: + raise ValueError( + "Input length must be statically known for adaptive pooling" + ) + + small_l, big_l = get_static_window_sizes(l_static, out_l) + gather_l = compute_static_gather_indices(l_static, out_l, small_l, big_l) + + small_pool_l = tf.nn.pool( + inputs, + window_shape=(small_l,), + pooling_type="AVG", + strides=(1,), + padding="VALID", + data_format="NWC", + ) + big_pool_l = tf.nn.pool( + inputs, + window_shape=(big_l,), + pooling_type="AVG", + strides=(1,), + padding="VALID", + data_format="NWC", + ) + + combined_l = tf.concat([small_pool_l, big_pool_l], axis=1) + pooled_l = tf.gather(combined_l, gather_l, axis=1) + + if data_format == "channels_first": + pooled_l = tf.transpose(pooled_l, (0, 2, 1)) + return pooled_l + + +def adaptive_avg_pool2d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size, output_size) + + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 3, 1)) + + static_shape = inputs.shape.as_list() + h_static = static_shape[1] + w_static = static_shape[2] + out_h, out_w = output_size + + if h_static is None or w_static is None: + raise ValueError( + "Input spatial dimensions must be " + "statically known for adaptive pooling" + ) + + small_h, big_h = get_static_window_sizes(h_static, out_h) + small_w, big_w = get_static_window_sizes(w_static, out_w) + + gather_h = compute_static_gather_indices(h_static, out_h, small_h, big_h) + gather_w = compute_static_gather_indices(w_static, out_w, small_w, big_w) + + small_pool_h = tf.nn.pool( + inputs, + window_shape=(small_h, 1), + pooling_type="AVG", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + big_pool_h = tf.nn.pool( + inputs, + window_shape=(big_h, 1), + pooling_type="AVG", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + + combined_h = tf.concat([small_pool_h, big_pool_h], axis=1) + pooled_h = tf.gather(combined_h, gather_h, axis=1) + + small_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, small_w), + pooling_type="AVG", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + big_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, big_w), + pooling_type="AVG", + strides=(1, 1), + padding="VALID", + data_format="NHWC", + ) + + combined_w = tf.concat([small_pool_w, big_pool_w], axis=2) + pooled_w = tf.gather(combined_w, gather_w, axis=2) + + if data_format == "channels_first": + pooled_w = tf.transpose(pooled_w, (0, 3, 1, 2)) + + return pooled_w + + +def adaptive_avg_pool3d(inputs, output_size, data_format="channels_first"): + if isinstance(output_size, int): + output_size = (output_size, output_size, output_size) + + if data_format == "channels_first": + inputs = tf.transpose(inputs, (0, 2, 3, 4, 1)) + + static_shape = inputs.shape.as_list() + d_static = static_shape[1] + h_static = static_shape[2] + w_static = static_shape[3] + out_d, out_h, out_w = output_size + + if d_static is None or h_static is None or w_static is None: + raise ValueError( + "Input spatial dimensions must be " + "statically known for adaptive pooling" + ) + + small_d, big_d = get_static_window_sizes(d_static, out_d) + small_h, big_h = get_static_window_sizes(h_static, out_h) + small_w, big_w = get_static_window_sizes(w_static, out_w) + + gather_d = compute_static_gather_indices(d_static, out_d, small_d, big_d) + gather_h = compute_static_gather_indices(h_static, out_h, small_h, big_h) + gather_w = compute_static_gather_indices(w_static, out_w, small_w, big_w) + + small_pool_d = tf.nn.pool( + inputs, + window_shape=(small_d, 1, 1), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_d = tf.nn.pool( + inputs, + window_shape=(big_d, 1, 1), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_d = tf.concat([small_pool_d, big_pool_d], axis=1) + pooled_d = tf.gather(combined_d, gather_d, axis=1) + + small_pool_h = tf.nn.pool( + pooled_d, + window_shape=(1, small_h, 1), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_h = tf.nn.pool( + pooled_d, + window_shape=(1, big_h, 1), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_h = tf.concat([small_pool_h, big_pool_h], axis=2) + pooled_h = tf.gather(combined_h, gather_h, axis=2) + + small_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, 1, small_w), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + big_pool_w = tf.nn.pool( + pooled_h, + window_shape=(1, 1, big_w), + pooling_type="AVG", + strides=(1, 1, 1), + padding="VALID", + data_format="NDHWC", + ) + + combined_w = tf.concat([small_pool_w, big_pool_w], axis=3) + pooled_w = tf.gather(combined_w, gather_w, axis=3) + + if data_format == "channels_first": + pooled_w = tf.transpose(pooled_w, (0, 4, 1, 2, 3)) + + return pooled_w + + +def adaptive_avg_pool(inputs, output_size, data_format="channels_first"): + ndims = len(inputs.shape) - 2 + if ndims == 1: + return adaptive_avg_pool1d(inputs, output_size, data_format) + elif ndims == 2: + return adaptive_avg_pool2d(inputs, output_size, data_format) + elif ndims == 3: + return adaptive_avg_pool3d(inputs, output_size, data_format) + else: + raise ValueError( + "adaptive_avg_pool supports 1D, 2D, or 3D inputs only." + ) + + def _convert_data_format(data_format, ndim): if data_format == "channels_last": if ndim == 3: diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py index 85b2a32d5560..3e1e87398336 100644 --- a/keras/src/backend/torch/nn.py +++ b/keras/src/backend/torch/nn.py @@ -384,6 +384,51 @@ def max_pool( return outputs +def adaptive_max_pool(inputs, output_size, data_format=None): + """Adaptive max pooling(1D/2D/3D) with channels_last support.""" + inputs = convert_to_tensor(inputs) + num_spatial_dims = inputs.ndim - 2 + + data_format = backend.standardize_data_format(data_format) + orig_format = data_format + if data_format == "channels_last": + inputs = _transpose_spatial_inputs(inputs) + + if isinstance(output_size, int): + torch_output_size = ( + output_size + if num_spatial_dims == 1 + else (output_size,) * num_spatial_dims + ) + else: + torch_output_size = standardize_tuple( + output_size, num_spatial_dims, "output_size" + ) + + if get_device() == "meta": + inputs = torch.empty( + size=inputs.shape, dtype=inputs.dtype, device="cpu" + ) + + if num_spatial_dims == 1: + res = tnn.adaptive_max_pool1d(inputs, output_size=torch_output_size) + elif num_spatial_dims == 2: + res = tnn.adaptive_max_pool2d(inputs, output_size=torch_output_size) + elif num_spatial_dims == 3: + res = tnn.adaptive_max_pool3d(inputs, output_size=torch_output_size) + else: + raise ValueError( + "Inputs to adaptive max pooling must have ndim=3, 4 or 5, " + f"Received input shape: {inputs.shape}." + ) + + outputs = res[0] if isinstance(res, tuple) else res + + if orig_format == "channels_last": + outputs = _transpose_spatial_outputs(outputs) + return outputs + + def average_pool( inputs, pool_size, @@ -458,6 +503,49 @@ def average_pool( return outputs +def adaptive_avg_pool(inputs, output_size, data_format=None): + """Adaptive average pooling(1D/2D/3D) with channels_last support.""" + inputs = convert_to_tensor(inputs) + num_spatial_dims = inputs.ndim - 2 + + data_format = backend.standardize_data_format(data_format) + orig_format = data_format + if data_format == "channels_last": + inputs = _transpose_spatial_inputs(inputs) + + if isinstance(output_size, int): + torch_output_size = ( + output_size + if num_spatial_dims == 1 + else (output_size,) * num_spatial_dims + ) + else: + torch_output_size = standardize_tuple( + output_size, num_spatial_dims, "output_size" + ) + + if get_device() == "meta": + inputs = torch.empty( + size=inputs.shape, dtype=inputs.dtype, device="cpu" + ) + + if num_spatial_dims == 1: + outputs = tnn.adaptive_avg_pool1d(inputs, output_size=torch_output_size) + elif num_spatial_dims == 2: + outputs = tnn.adaptive_avg_pool2d(inputs, output_size=torch_output_size) + elif num_spatial_dims == 3: + outputs = tnn.adaptive_avg_pool3d(inputs, output_size=torch_output_size) + else: + raise ValueError( + "Inputs to adaptive average pooling must have ndim=3, 4 or 5, " + f"Received input shape: {inputs.shape}." + ) + + if orig_format == "channels_last": + outputs = _transpose_spatial_outputs(outputs) + return outputs + + def conv( inputs, kernel, diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index febdcef15a98..e2d1ec0a6479 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -63,6 +63,18 @@ SpectralNormalization, ) from keras.src.layers.normalization.unit_normalization import UnitNormalization +from keras.src.layers.pooling.adaptive_average_pooling1d import ( + AdaptiveAveragePooling1D, +) +from keras.src.layers.pooling.adaptive_average_pooling2d import ( + AdaptiveAveragePooling2D, +) +from keras.src.layers.pooling.adaptive_average_pooling3d import ( + AdaptiveAveragePooling3D, +) +from keras.src.layers.pooling.adaptive_max_pooling1d import AdaptiveMaxPooling1D +from keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D +from keras.src.layers.pooling.adaptive_max_pooling3d import AdaptiveMaxPooling3D from keras.src.layers.pooling.average_pooling1d import AveragePooling1D from keras.src.layers.pooling.average_pooling2d import AveragePooling2D from keras.src.layers.pooling.average_pooling3d import AveragePooling3D diff --git a/keras/src/layers/pooling/__init__.py b/keras/src/layers/pooling/__init__.py index e69de29bb2d1..ed06581b27d6 100644 --- a/keras/src/layers/pooling/__init__.py +++ b/keras/src/layers/pooling/__init__.py @@ -0,0 +1,12 @@ +from keras.src.layers.pooling.adaptive_average_pooling1d import ( + AdaptiveAveragePooling1D, +) +from keras.src.layers.pooling.adaptive_average_pooling2d import ( + AdaptiveAveragePooling2D, +) +from keras.src.layers.pooling.adaptive_average_pooling3d import ( + AdaptiveAveragePooling3D, +) +from keras.src.layers.pooling.adaptive_max_pooling1d import AdaptiveMaxPooling1D +from keras.src.layers.pooling.adaptive_max_pooling2d import AdaptiveMaxPooling2D +from keras.src.layers.pooling.adaptive_max_pooling3d import AdaptiveMaxPooling3D diff --git a/keras/src/layers/pooling/adaptive_average_pooling1d.py b/keras/src/layers/pooling/adaptive_average_pooling1d.py new file mode 100644 index 000000000000..a6d6deeb41a0 --- /dev/null +++ b/keras/src/layers/pooling/adaptive_average_pooling1d.py @@ -0,0 +1,84 @@ +"""Adaptive Average Pooling 1D layer.""" + +from keras import config +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.AdaptiveAveragePooling1D") +class AdaptiveAveragePooling1D(Layer): + """Adaptive average pooling operation for 1D temporal or spatial data. + + This layer applies an adaptive average pooling operation, which pools the + input such that the output has a target length specified by `output_size`, + regardless of the input length. The kernel size and stride are automatically + computed to achieve the target output size. + + Args: + output_size: Integer specifying the target output length. + data_format: string, either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, length, channels)`. + `"channels_first"` corresponds to inputs with shape + `(batch, channels, length)`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, `"channels_last"` is used. + + Input shape: + - If `data_format="channels_last"`: 3D tensor + `(batch_size, length, channels)` + - If `data_format="channels_first"`: 3D tensor + `(batch_size, channels, length)` + + Output shape: + - If `data_format="channels_last"`: + `(batch_size, output_length, channels)` + - If `data_format="channels_first"`: + `(batch_size, channels, output_length)` + + Examples: + + >>> import numpy as np + >>> input_seq = np.random.rand(1, 64, 3) + >>> layer = AdaptiveAveragePooling1D(output_size=32) + >>> output_seq = layer(input_seq) + >>> output_seq.shape + (1, 32, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + super().__init__(**kwargs) + if not isinstance(output_size, int): + raise TypeError( + f"`output_size` must be an integer. " + f"Received: {output_size} of type {type(output_size)}" + ) + + self.output_size = output_size + self.data_format = data_format or config.image_data_format() + + if self.data_format not in {"channels_first", "channels_last"}: + raise ValueError( + f"Invalid data_format: {self.data_format}. " + "Must be either 'channels_first' or 'channels_last'." + ) + + def call(self, inputs): + return ops.adaptive_avg_pool( + inputs, output_size=self.output_size, data_format=self.data_format + ) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_last": + return (input_shape[0], self.output_size, input_shape[2]) + else: # channels_first + return (input_shape[0], input_shape[1], self.output_size) + + def get_config(self): + config_dict = { + "output_size": self.output_size, + "data_format": self.data_format, + } + base_config = super().get_config() + return {**base_config, **config_dict} diff --git a/keras/src/layers/pooling/adaptive_average_pooling2d.py b/keras/src/layers/pooling/adaptive_average_pooling2d.py new file mode 100644 index 000000000000..a2714b33fe5b --- /dev/null +++ b/keras/src/layers/pooling/adaptive_average_pooling2d.py @@ -0,0 +1,112 @@ +"""Adaptive Average Pooling 2D layer.""" + +from keras import config +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.AdaptiveAveragePooling2D") +class AdaptiveAveragePooling2D(Layer): + """Adaptive average pooling operation for 2D spatial data. + + This layer applies an adaptive average pooling operation, which pools the + input such that the output has a target shape specified by `output_size`, + regardless of the input shape. The kernel size and stride are automatically + computed to achieve the target output size. + + Args: + output_size: Integer or tuple of 2 integers, specifying the target + output size `(height, width)`. If a single integer is provided, + the same value is used for both dimensions. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. Defaults to the value found in + your Keras config file at `~/.keras/keras.json`. If never set, then + "channels_last" will be used. + + Input shape: + - If `data_format="channels_last"`: + 4D tensor with shape `(batch_size, height, width, channels)`. + - If `data_format="channels_first"`: + 4D tensor with shape `(batch_size, channels, height, width)`. + + Output shape: + - If `data_format="channels_last"`: + 4D tensor with shape + `(batch_size, output_height, output_width, channels)`. + - If `data_format="channels_first"`: + 4D tensor with shape + `(batch_size, channels, output_height, output_width)`. + + Examples: + + >>> input_img = np.random.rand(1, 64, 64, 3) + >>> layer = keras.layers.AdaptiveAveragePooling2D(output_size=(32, 32)) + >>> output_img = layer(input_img) + >>> output_img.shape + (1, 32, 32, 3) + + >>> # Single integer for square output + >>> layer = keras.layers.AdaptiveAveragePooling2D(output_size=7) + >>> output_img = layer(input_img) + >>> output_img.shape + (1, 7, 7, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + super().__init__(**kwargs) + if isinstance(output_size, int): + self.output_size = (output_size, output_size) + elif isinstance(output_size, (list, tuple)): + if len(output_size) != 2: + raise ValueError( + f"`output_size` must be an integer or tuple of 2 integers. " + f"Received: output_size={output_size}" + ) + self.output_size = tuple(output_size) + else: + raise TypeError( + f"`output_size` must be an integer or tuple of 2 integers. " + f"Received: output_size={output_size} of type " + f"{type(output_size)}" + ) + + self.data_format = data_format or config.image_data_format() + + if self.data_format not in {"channels_first", "channels_last"}: + raise ValueError( + f"Invalid data_format: {self.data_format}. " + "Must be either 'channels_first' or 'channels_last'." + ) + + def call(self, inputs): + return ops.adaptive_avg_pool( + inputs, output_size=self.output_size, data_format=self.data_format + ) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_last": + return ( + input_shape[0], + self.output_size[0], + self.output_size[1], + input_shape[3], + ) + else: # channels_first + return ( + input_shape[0], + input_shape[1], + self.output_size[0], + self.output_size[1], + ) + + def get_config(self): + config_dict = { + "output_size": self.output_size, + "data_format": self.data_format, + } + base_config = super().get_config() + return {**base_config, **config_dict} diff --git a/keras/src/layers/pooling/adaptive_average_pooling3d.py b/keras/src/layers/pooling/adaptive_average_pooling3d.py new file mode 100644 index 000000000000..b2f582301859 --- /dev/null +++ b/keras/src/layers/pooling/adaptive_average_pooling3d.py @@ -0,0 +1,118 @@ +"""Adaptive Average Pooling 3D layer.""" + +from keras import config +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.AdaptiveAveragePooling3D") +class AdaptiveAveragePooling3D(Layer): + """Adaptive average pooling operation for 3D spatial data. + + This layer applies an adaptive average pooling operation, which pools the + input such that the output has a target shape specified by `output_size`, + regardless of the input shape. The kernel size and stride are automatically + computed to achieve the target output size. + + Args: + output_size: Integer or tuple of 3 integers, specifying the target + output size `(depth, height, width)`. + If a single integer is provided, the same value is used for all + three dimensions. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. + `"channels_last"` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while + `"channels_first"` corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, then "channels_last" is used. + + Input shape: + - If `data_format="channels_last"`: + 5D tensor with shape `(batch_size, depth, height, width, channels)`. + - If `data_format="channels_first"`: + 5D tensor with shape `(batch_size, channels, depth, height, width)`. + + Output shape: + - If `data_format="channels_last"`: + 5D tensor with shape + `(batch_size, output_depth, output_height, output_width, channels)`. + - If `data_format="channels_first"`: + 5D tensor with shape + `(batch_size, channels, output_depth, output_height, output_width)`. + + Examples: + + >>> input_vol = np.random.rand(1, 16, 64, 64, 3) + >>> layer = keras.layers.AdaptiveAveragePooling3D(output_size=(8, 32, 32)) + >>> output_vol = layer(input_vol) + >>> output_vol.shape + (1, 8, 32, 32, 3) + + >>> # Single integer for cubic output + >>> layer = keras.layers.AdaptiveAveragePooling3D(output_size=4) + >>> output_vol = layer(input_vol) + >>> output_vol.shape + (1, 4, 4, 4, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + super().__init__(**kwargs) + + if isinstance(output_size, int): + self.output_size = (output_size, output_size, output_size) + elif isinstance(output_size, (list, tuple)): + if len(output_size) != 3: + raise ValueError( + "`output_size` must be an integer or tuple of 3 integers. " + f"Received output_size={output_size}" + ) + self.output_size = tuple(output_size) + else: + raise TypeError( + "`output_size` must be an integer or tuple of 3 integers. " + "Received output_size={} of type {}".format( + output_size, type(output_size) + ) + ) + + self.data_format = data_format or config.image_data_format() + + if self.data_format not in {"channels_first", "channels_last"}: + raise ValueError( + f"Invalid data_format: {self.data_format}. " + "Must be either 'channels_first' or 'channels_last'." + ) + + def call(self, inputs): + return ops.adaptive_avg_pool( + inputs, output_size=self.output_size, data_format=self.data_format + ) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_last": + return ( + input_shape[0], + self.output_size[0], + self.output_size[1], + self.output_size[2], + input_shape[4], + ) + else: # channels_first + return ( + input_shape[0], + input_shape[1], + self.output_size[0], + self.output_size[1], + self.output_size[2], + ) + + def get_config(self): + config_dict = { + "output_size": self.output_size, + "data_format": self.data_format, + } + base_config = super().get_config() + return {**base_config, **config_dict} diff --git a/keras/src/layers/pooling/adaptive_max_pooling1d.py b/keras/src/layers/pooling/adaptive_max_pooling1d.py new file mode 100644 index 000000000000..31d67ab27895 --- /dev/null +++ b/keras/src/layers/pooling/adaptive_max_pooling1d.py @@ -0,0 +1,84 @@ +"""Adaptive Max Pooling 1D layer.""" + +from keras import config +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.AdaptiveMaxPooling1D") +class AdaptiveMaxPooling1D(Layer): + """Adaptive max pooling operation for 1D temporal or spatial data. + + This layer applies an adaptive max pooling operation, which pools the + input such that the output has a target length specified by `output_size`, + regardless of the input length. The kernel size and stride are automatically + computed to achieve the target output size. + + Args: + output_size: Integer specifying the target output length. + data_format: string, either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, length, channels)`. + `"channels_first"` corresponds to inputs with shape + `(batch, channels, length)`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, `"channels_last"` is used. + + Input shape: + - If `data_format="channels_last"`: + 3D tensor `(batch_size, length, channels)`. + - If `data_format="channels_first"`: + 3D tensor `(batch_size, channels, length)`. + + Output shape: + - If `data_format="channels_last"`: + 3D tensor `(batch_size, output_length, channels)`. + - If `data_format="channels_first"`: + 3D tensor `(batch_size, channels, output_length)`. + + Examples: + + >>> import numpy as np + >>> input_seq = np.random.rand(1, 64, 3) + >>> layer = AdaptiveMaxPooling1D(output_size=32) + >>> output_seq = layer(input_seq) + >>> output_seq.shape + (1, 32, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + super().__init__(**kwargs) + + if not isinstance(output_size, int): + raise TypeError( + "`output_size` must be an integer. Received output_size={} " + "of type {}".format(output_size, type(output_size)) + ) + self.output_size = output_size + self.data_format = data_format or config.image_data_format() + + if self.data_format not in {"channels_first", "channels_last"}: + raise ValueError( + "Invalid data_format: {}. Must be either 'channels_first' " + "or 'channels_last'.".format(self.data_format) + ) + + def call(self, inputs): + return ops.adaptive_max_pool( + inputs, output_size=self.output_size, data_format=self.data_format + ) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_last": + return (input_shape[0], self.output_size, input_shape[2]) + else: # channels_first + return (input_shape[0], input_shape[1], self.output_size) + + def get_config(self): + config_dict = { + "output_size": self.output_size, + "data_format": self.data_format, + } + base_config = super().get_config() + return {**base_config, **config_dict} diff --git a/keras/src/layers/pooling/adaptive_max_pooling2d.py b/keras/src/layers/pooling/adaptive_max_pooling2d.py new file mode 100644 index 000000000000..50f498650d18 --- /dev/null +++ b/keras/src/layers/pooling/adaptive_max_pooling2d.py @@ -0,0 +1,112 @@ +"""Adaptive Max Pooling 2D layer.""" + +from keras import config +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.AdaptiveMaxPooling2D") +class AdaptiveMaxPooling2D(Layer): + """Adaptive max pooling operation for 2D spatial data. + + This layer applies an adaptive max pooling operation, which pools the + input such that the output has a target shape specified by `output_size`, + regardless of the input shape. The kernel size and stride are automatically + computed to achieve the target output size. + + Args: + output_size: Integer or tuple of 2 integers, specifying the target + output size `(height, width)`. If a single integer is provided, + the same value is used for both dimensions. + data_format: string, either `"channels_last"` or `"channels_first"`. + The ordering of the dimensions in the inputs. `"channels_last"` + corresponds to inputs with shape `(batch, height, width, channels)` + while `"channels_first"` corresponds to inputs with shape + `(batch, channels, height, width)`. Defaults to the value found in + your Keras config file at `~/.keras/keras.json`. If never set, then + "channels_last" will be used. + + Input shape: + - If `data_format="channels_last"`: + 4D tensor with shape `(batch_size, height, width, channels)`. + - If `data_format="channels_first"`: + 4D tensor with shape `(batch_size, channels, height, width)`. + + Output shape: + - If `data_format="channels_last"`: + 4D tensor with shape + `(batch_size, output_height, output_width, channels)`. + - If `data_format="channels_first"`: + 4D tensor with shape + `(batch_size, channels, output_height, output_width)`. + + Examples: + + >>> input_img = np.random.rand(1, 64, 64, 3) + >>> layer = keras.layers.AdaptiveMaxPooling2D(output_size=(32, 32)) + >>> output_img = layer(input_img) + >>> output_img.shape + (1, 32, 32, 3) + + >>> # Single integer for square output + >>> layer = keras.layers.AdaptiveMaxPooling2D(output_size=7) + >>> output_img = layer(input_img) + >>> output_img.shape + (1, 7, 7, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + super().__init__(**kwargs) + if isinstance(output_size, int): + self.output_size = (output_size, output_size) + elif isinstance(output_size, (list, tuple)): + if len(output_size) != 2: + raise ValueError( + f"`output_size` must be an integer or tuple of 2 integers. " + f"Received: output_size={output_size}" + ) + self.output_size = tuple(output_size) + else: + raise TypeError( + f"`output_size` must be an integer or tuple of 2 integers. " + f"Received: output_size={output_size} of type " + f"{type(output_size)}" + ) + + self.data_format = data_format or config.image_data_format() + + if self.data_format not in {"channels_first", "channels_last"}: + raise ValueError( + f"Invalid data_format: {self.data_format}. " + "Must be either 'channels_first' or 'channels_last'." + ) + + def call(self, inputs): + return ops.adaptive_max_pool( + inputs, output_size=self.output_size, data_format=self.data_format + ) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_last": + return ( + input_shape[0], + self.output_size[0], + self.output_size[1], + input_shape[3], + ) + else: # channels_first + return ( + input_shape[0], + input_shape[1], + self.output_size[0], + self.output_size[1], + ) + + def get_config(self): + config_dict = { + "output_size": self.output_size, + "data_format": self.data_format, + } + base_config = super().get_config() + return {**base_config, **config_dict} diff --git a/keras/src/layers/pooling/adaptive_max_pooling3d.py b/keras/src/layers/pooling/adaptive_max_pooling3d.py new file mode 100644 index 000000000000..a8074e5e426f --- /dev/null +++ b/keras/src/layers/pooling/adaptive_max_pooling3d.py @@ -0,0 +1,115 @@ +"""Adaptive Max Pooling 3D layer.""" + +from keras import config +from keras.src import ops +from keras.src.api_export import keras_export +from keras.src.layers.layer import Layer + + +@keras_export("keras.layers.AdaptiveMaxPooling3D") +class AdaptiveMaxPooling3D(Layer): + """Adaptive max pooling operation for 3D spatial data. + + This layer applies an adaptive max pooling operation, which pools the + input such that the output has a target shape specified by `output_size`, + regardless of the input shape. The kernel size and stride are automatically + computed to achieve the target output size. + + Args: + output_size: Integer or tuple of 3 integers specifying the target + output size `(depth, height, width)`. If a single integer is + provided, the same value is used for all three dimensions. + data_format: string, either `"channels_last"` or `"channels_first"`. + `"channels_last"` corresponds to inputs with shape + `(batch, depth, height, width, channels)`. + `"channels_first"` corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, `"channels_last"` is used. + + Input shape: + - If `data_format="channels_last"`: + 5D tensor with shape `(batch_size, depth, height, width, channels)`. + - If `data_format="channels_first"`: + 5D tensor with shape `(batch_size, channels, depth, height, width)`. + + Output shape: + - If `data_format="channels_last"`: + 5D tensor `(batch_size, output_depth, output_height, + output_width, channels)`. + - If `data_format="channels_first"`: + 5D tensor `(batch_size, channels, output_depth, + output_height, output_width)`. + + Examples: + + >>> import numpy as np + >>> input_vol = np.random.rand(1, 16, 64, 64, 3) + >>> layer = AdaptiveMaxPooling3D(output_size=(8, 32, 32)) + >>> output_vol = layer(input_vol) + >>> output_vol.shape + (1, 8, 32, 32, 3) + + >>> # Single integer for cubic output + >>> layer = AdaptiveMaxPooling3D(output_size=4) + >>> output_vol = layer(input_vol) + >>> output_vol.shape + (1, 4, 4, 4, 3) + """ + + def __init__(self, output_size, data_format=None, **kwargs): + super().__init__(**kwargs) + + if isinstance(output_size, int): + self.output_size = (output_size, output_size, output_size) + elif isinstance(output_size, (list, tuple)): + if len(output_size) != 3: + raise ValueError( + "`output_size` must be an integer or tuple of 3 integers. " + "Received: {}".format(output_size) + ) + self.output_size = tuple(output_size) + else: + raise TypeError( + "`output_size` must be an integer or tuple of 3 integers. " + "Received: {} of type {}".format(output_size, type(output_size)) + ) + + self.data_format = data_format or config.image_data_format() + + if self.data_format not in {"channels_first", "channels_last"}: + raise ValueError( + "Invalid data_format: {}. Must be either 'channels_first' or " + "'channels_last'.".format(self.data_format) + ) + + def call(self, inputs): + return ops.adaptive_max_pool( + inputs, output_size=self.output_size, data_format=self.data_format + ) + + def compute_output_shape(self, input_shape): + if self.data_format == "channels_last": + return ( + input_shape[0], + self.output_size[0], + self.output_size[1], + self.output_size[2], + input_shape[4], + ) + else: # channels_first + return ( + input_shape[0], + input_shape[1], + self.output_size[0], + self.output_size[1], + self.output_size[2], + ) + + def get_config(self): + config_dict = { + "output_size": self.output_size, + "data_format": self.data_format, + } + base_config = super().get_config() + return {**base_config, **config_dict} diff --git a/keras/src/layers/pooling/adaptive_pooling1d_test.py b/keras/src/layers/pooling/adaptive_pooling1d_test.py new file mode 100644 index 000000000000..61bda31cefea --- /dev/null +++ b/keras/src/layers/pooling/adaptive_pooling1d_test.py @@ -0,0 +1,93 @@ +"""Tests for Adaptive Average and Max Pooling 1D layer.""" + +import numpy as np +import pytest + +from keras.src import backend as K +from keras.src import layers +from keras.src import ops +from keras.src import testing + +SKIP_BACKENDS = ["openvino", "numpy"] + +pytestmark = pytest.mark.skipif( + K.backend() in SKIP_BACKENDS, + reason=( + "Adaptive pooling tests not supported for backend: {}".format( + K.backend() + ) + ), +) + +try: + import torch + + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + + +class AdaptivePooling1DLayerTest(testing.TestCase): + """Basic tests for AdaptiveAveragePooling1D and AdaptiveMaxPooling1D.""" + + def _run_layer_test(self, layer_class, x_np, output_size, data_format): + layer = layer_class(output_size=output_size, data_format=data_format) + y = layer(x_np) + expected_shape = layer.compute_output_shape(x_np.shape) + self.assertEqual(y.shape, expected_shape) + + def test_average_pooling_basic_shapes(self): + shape = (2, 3, 8) # N,C,L + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling1D, + x, + output_size=4, + data_format="channels_first", + ) + + def test_max_pooling_basic_shapes(self): + shape = (2, 3, 8) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling1D, + x, + output_size=4, + data_format="channels_first", + ) + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not installed") +@pytest.mark.parametrize("output_size", [1, 2, 3, 4]) +def test_adaptive_avg_pool1d_matches_torch(output_size): + x_np = np.random.randn(2, 3, 8).astype(np.float32) + x_torch = torch.tensor(x_np) + y_torch = torch.nn.functional.adaptive_avg_pool1d(x_torch, output_size) + + x_keras = ops.convert_to_tensor(x_np) + y_keras = ops.adaptive_avg_pool( + x_keras, output_size=output_size, data_format="channels_first" + ) + y_keras_np = np.asarray(y_keras) + + np.testing.assert_allclose( + y_keras_np, y_torch.numpy(), rtol=1e-5, atol=1e-5 + ) + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not installed") +@pytest.mark.parametrize("output_size", [1, 2, 3, 4]) +def test_adaptive_max_pool1d_matches_torch(output_size): + x_np = np.random.randn(2, 3, 8).astype(np.float32) + x_torch = torch.tensor(x_np) + y_torch = torch.nn.functional.adaptive_max_pool1d(x_torch, output_size) + + x_keras = ops.convert_to_tensor(x_np) + y_keras = ops.adaptive_max_pool( + x_keras, output_size=output_size, data_format="channels_first" + ) + y_keras_np = np.asarray(y_keras) + + np.testing.assert_allclose( + y_keras_np, y_torch.numpy(), rtol=1e-5, atol=1e-5 + ) diff --git a/keras/src/layers/pooling/adaptive_pooling2d_test.py b/keras/src/layers/pooling/adaptive_pooling2d_test.py new file mode 100644 index 000000000000..cd6de8eec5de --- /dev/null +++ b/keras/src/layers/pooling/adaptive_pooling2d_test.py @@ -0,0 +1,93 @@ +"""Tests for Adaptive Average and Max Pooling 2D layer.""" + +import numpy as np +import pytest + +from keras.src import backend as K +from keras.src import layers +from keras.src import ops +from keras.src import testing + +SKIP_BACKENDS = ["openvino", "numpy"] + +pytestmark = pytest.mark.skipif( + K.backend() in SKIP_BACKENDS, + reason=( + "Adaptive pooling tests not supported for backend: {}".format( + K.backend() + ) + ), +) + +try: + import torch + + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + + +class AdaptivePooling2DLayerTest(testing.TestCase): + """Basic tests for AdaptiveAveragePooling2D and AdaptiveMaxPooling2D.""" + + def _run_layer_test(self, layer_class, x_np, output_size, data_format): + layer = layer_class(output_size=output_size, data_format=data_format) + y = layer(x_np) + expected_shape = layer.compute_output_shape(x_np.shape) + self.assertEqual(y.shape, expected_shape) + + def test_average_pooling_basic_shapes(self): + shape = (2, 3, 8, 8) # N,C,H,W + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling2D, + x, + output_size=4, + data_format="channels_first", + ) + + def test_max_pooling_basic_shapes(self): + shape = (2, 3, 8, 8) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling2D, + x, + output_size=4, + data_format="channels_first", + ) + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not installed") +@pytest.mark.parametrize("output_size", [1, 2, 3, 4]) +def test_adaptive_avg_pool2d_matches_torch(output_size): + x_np = np.random.randn(2, 3, 8, 8).astype(np.float32) + x_torch = torch.tensor(x_np) + y_torch = torch.nn.functional.adaptive_avg_pool2d(x_torch, output_size) + + x_keras = ops.convert_to_tensor(x_np) + y_keras = ops.adaptive_avg_pool( + x_keras, output_size=output_size, data_format="channels_first" + ) + y_keras_np = np.asarray(y_keras) + + np.testing.assert_allclose( + y_keras_np, y_torch.numpy(), rtol=1e-5, atol=1e-5 + ) + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not installed") +@pytest.mark.parametrize("output_size", [1, 2, 3, 4]) +def test_adaptive_max_pool2d_matches_torch(output_size): + x_np = np.random.randn(2, 3, 8, 8).astype(np.float32) + x_torch = torch.tensor(x_np) + y_torch = torch.nn.functional.adaptive_max_pool2d(x_torch, output_size) + + x_keras = ops.convert_to_tensor(x_np) + y_keras = ops.adaptive_max_pool( + x_keras, output_size=output_size, data_format="channels_first" + ) + y_keras_np = np.asarray(y_keras) + + np.testing.assert_allclose( + y_keras_np, y_torch.numpy(), rtol=1e-5, atol=1e-5 + ) diff --git a/keras/src/layers/pooling/adaptive_pooling3d_test.py b/keras/src/layers/pooling/adaptive_pooling3d_test.py new file mode 100644 index 000000000000..188880964229 --- /dev/null +++ b/keras/src/layers/pooling/adaptive_pooling3d_test.py @@ -0,0 +1,93 @@ +"""Tests for Adaptive Average and Max Pooling 3D layer.""" + +import numpy as np +import pytest + +from keras.src import backend as K +from keras.src import layers +from keras.src import ops +from keras.src import testing + +SKIP_BACKENDS = ["openvino", "numpy"] + +pytestmark = pytest.mark.skipif( + K.backend() in SKIP_BACKENDS, + reason=( + "Adaptive pooling tests not supported for backend: {}".format( + K.backend() + ) + ), +) + +try: + import torch + + TORCH_AVAILABLE = True +except ImportError: + TORCH_AVAILABLE = False + + +class AdaptivePooling3DLayerTest(testing.TestCase): + """Basic tests for AdaptiveAveragePooling3D and AdaptiveMaxPooling3D.""" + + def _run_layer_test(self, layer_class, x_np, output_size, data_format): + layer = layer_class(output_size=output_size, data_format=data_format) + y = layer(x_np) + expected_shape = layer.compute_output_shape(x_np.shape) + self.assertEqual(y.shape, expected_shape) + + def test_average_pooling_basic_shapes(self): + shape = (2, 3, 8, 8, 8) # N,C,D,H,W + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveAveragePooling3D, + x, + output_size=4, + data_format="channels_first", + ) + + def test_max_pooling_basic_shapes(self): + shape = (2, 3, 8, 8, 8) + x = np.random.randn(*shape).astype("float32") + self._run_layer_test( + layers.AdaptiveMaxPooling3D, + x, + output_size=4, + data_format="channels_first", + ) + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not installed") +@pytest.mark.parametrize("output_size", [1, 2, 3, 4]) +def test_adaptive_avg_pool3d_matches_torch(output_size): + x_np = np.random.randn(2, 3, 8, 8, 8).astype(np.float32) + x_torch = torch.tensor(x_np) + y_torch = torch.nn.functional.adaptive_avg_pool3d(x_torch, output_size) + + x_keras = ops.convert_to_tensor(x_np) + y_keras = ops.adaptive_avg_pool( + x_keras, output_size=output_size, data_format="channels_first" + ) + y_keras_np = np.asarray(y_keras) + + np.testing.assert_allclose( + y_keras_np, y_torch.numpy(), rtol=1e-5, atol=1e-5 + ) + + +@pytest.mark.skipif(not TORCH_AVAILABLE, reason="PyTorch not installed") +@pytest.mark.parametrize("output_size", [1, 2, 3, 4]) +def test_adaptive_max_pool3d_matches_torch(output_size): + x_np = np.random.randn(2, 3, 8, 8, 8).astype(np.float32) + x_torch = torch.tensor(x_np) + y_torch = torch.nn.functional.adaptive_max_pool3d(x_torch, output_size) + + x_keras = ops.convert_to_tensor(x_np) + y_keras = ops.adaptive_max_pool( + x_keras, output_size=output_size, data_format="channels_first" + ) + y_keras_np = np.asarray(y_keras) + + np.testing.assert_allclose( + y_keras_np, y_torch.numpy(), rtol=1e-5, atol=1e-5 + ) diff --git a/keras/src/ops/nn.py b/keras/src/ops/nn.py index 23792400ae4e..a398ce7d8c69 100644 --- a/keras/src/ops/nn.py +++ b/keras/src/ops/nn.py @@ -2,6 +2,7 @@ import warnings +from keras import config from keras.src import backend from keras.src.api_export import keras_export from keras.src.backend import KerasTensor @@ -1162,6 +1163,58 @@ def max_pool( return backend.nn.max_pool(inputs, pool_size, strides, padding, data_format) +@keras_export("keras.ops.adaptive_max_pool") +def adaptive_max_pool( + inputs, + output_size, + data_format=None, +): + """Adaptive max pooling operation. + + Applies an adaptive max pooling operation that automatically computes the + kernel size and stride to pool the input to the specified `output_size`. + This operation is useful when you want a fixed output size regardless of + input size, commonly used in models like ResNet for global feature + extraction. + Args: + inputs: Tensor of rank 4. Input tensor of shape: + - If `data_format="channels_last"`: + `(batch_size, height, width, channels)`. + - If `data_format="channels_first"`: + `(batch_size, channels, height, width)`. + output_size: Integer or tuple/list of 2 integers, specifying the target + output spatial dimensions `(output_height, output_width)`. If a + single + integer is provided, the same value is used for both dimensions. + data_format: string, either `"channels_last"` or `"channels_first"`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, defaults to `"channels_last"`. + + Returns: + A tensor of rank 4 representing the adaptive max pooled result. + + Example: + + >>> x = np.random.rand(2, 64, 64, 3) + >>> y = keras.ops.adaptive_max_pool(x, output_size=(32, 32)) + >>> y.shape + (2, 32, 32, 3) + + >>> # Works with any input size + >>> x = np.random.rand(2, 100, 80, 3) + >>> y = keras.ops.adaptive_max_pool(x, output_size=7) + >>> y.shape + (2, 7, 7, 3) + """ + if data_format is None: + data_format = config.image_data_format() + return backend.nn.adaptive_max_pool( + inputs, + output_size=output_size, + data_format=data_format, + ) + + class AveragePool(Operation): def __init__( self, @@ -1257,6 +1310,60 @@ def average_pool( ) +@keras_export("keras.ops.adaptive_avg_pool") +def adaptive_avg_pool( + inputs, + output_size, + data_format=None, +): + """Adaptive average pooling operation. + + Applies an adaptive average pooling operation that automatically + computes the + kernel size and stride to pool the input to the specified `output_size`. + This operation is useful when you want a fixed output size regardless of + input size, commonly used in models like ResNet for global feature + extraction. + + Args: + inputs: Tensor of rank 4. Input tensor of shape: + - If `data_format="channels_last"`: + `(batch_size, height, width, channels)`. + - If `data_format="channels_first"`: + `(batch_size, channels, height, width)`. + output_size: Integer or tuple/list of 2 integers, specifying the target + output spatial dimensions `(output_height, output_width)`. If a + single + integer is provided, the same value is used for both dimensions. + data_format: string, either `"channels_last"` or `"channels_first"`. + Defaults to the value found in your Keras config file at + `~/.keras/keras.json`. If never set, defaults to `"channels_last"`. + + Returns: + A tensor of rank 4 representing the adaptive average pooled result. + + Example: + + >>> x = np.random.rand(2, 64, 64, 3) + >>> y = keras.ops.adaptive_avg_pool(x, output_size=(32, 32)) + >>> y.shape + (2, 32, 32, 3) + + >>> # Works with any input size + >>> x = np.random.rand(2, 100, 80, 3) + >>> y = keras.ops.adaptive_avg_pool(x, output_size=7) + >>> y.shape + (2, 7, 7, 3) + """ + if data_format is None: + data_format = config.image_data_format() + return backend.nn.adaptive_avg_pool( + inputs, + output_size=output_size, + data_format=data_format, + ) + + class Conv(Operation): def __init__( self,