Conversation

@codeflash-ai codeflash-ai bot commented Dec 4, 2025

📄 94% (0.94x) speedup for compute_conv_output_shape in keras/src/ops/operation_utils.py

⏱️ Runtime : 946 microseconds → 487 microseconds (best of 250 runs)

📝 Explanation and details

The optimized code achieves a 94% speedup by eliminating expensive NumPy array operations and replacing them with efficient Python list comprehensions and native arithmetic.

**Key optimizations:**

1. **Eliminated unnecessary NumPy array conversions**: The original code converted `spatial_shape`, `kernel_shape[:-2]`, and `dilation_rate` to NumPy arrays, even when only basic indexing and arithmetic were needed. The optimized version keeps these as native Python tuples/lists.

2. **Replaced vectorized NumPy operations with list comprehensions**: The most expensive operations were the NumPy vectorized calculations for `output_spatial_shape`. These are now computed element-wise using list comprehensions with explicit indexing, avoiding NumPy's overhead for small arrays (typically 1-3 dimensions).

3. **Streamlined None dimension handling**: Instead of mutating a NumPy array in a loop to handle None dimensions, the optimized version uses a single list comprehension to identify None positions and a tuple comprehension to create the calculation-ready spatial shape.

4. **Eliminated redundant array operations**: Removed the final `[int(i) for i in output_spatial_shape]` conversion since the list comprehensions already produce integers directly.

**Why this works**: For small arrays (1-3D convolutions are most common), NumPy's vectorization overhead outweighs its benefits. The function references show this is called from convolutional layer constructors during model building, where the 94% speedup significantly improves model initialization time. The optimization is particularly effective for the common test cases with valid/same padding, showing 70-100% improvements across different input configurations.
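
To make the description concrete, here is a minimal sketch of the list-comprehension approach, written as a simplified standalone helper. The function name and structure below are illustrative assumptions, not the actual Keras implementation, and input validation (length checks, rejection of non-positive output sizes) is omitted:

```python
import math


def conv_output_spatial_shape(spatial_shape, kernel_spatial, strides, dilation_rate, padding="valid"):
    """Hypothetical helper mirroring the per-dimension arithmetic; not the Keras source."""
    ndim = len(spatial_shape)
    if isinstance(strides, int):
        strides = (strides,) * ndim
    if isinstance(dilation_rate, int):
        dilation_rate = (dilation_rate,) * ndim

    # Record which spatial dims are unknown, substitute a placeholder so the
    # arithmetic stays valid, then restore None at the end.
    none_dims = [i for i in range(ndim) if spatial_shape[i] is None]
    spatial = tuple(-1 if d is None else d for d in spatial_shape)

    if padding in ("same", "causal"):
        out = [math.ceil(spatial[i] / strides[i]) for i in range(ndim)]
    elif padding == "valid":
        out = [
            (spatial[i] - dilation_rate[i] * (kernel_spatial[i] - 1) - 1) // strides[i] + 1
            for i in range(ndim)
        ]
    else:
        raise ValueError(f"Invalid padding: {padding}")

    return tuple(None if i in none_dims else out[i] for i in range(ndim))


# Example: a 10x10 input, 3x3 kernel, stride 2, valid padding -> (4, 4)
print(conv_output_spatial_shape((10, 10), (3, 3), 2, 1, "valid"))
```

Because `ndim` is at most a handful of dimensions, the plain Python loop body above avoids the per-call cost of allocating NumPy arrays and dispatching vectorized ufuncs.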

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 50 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Generated Regression Tests and Runtime
import numpy as np
# imports
import pytest  # used for our unit tests
from keras.src.ops.operation_utils import compute_conv_output_shape

# unit tests

# --- Basic Test Cases ---

def test_basic_1d_conv_valid():
    # 1D conv, batch=2, length=10, channels=3, filters=4, kernel=3, stride=1, valid padding
    codeflash_output = compute_conv_output_shape((2, 10, 3), 4, (3,), strides=1, padding="valid"); out = codeflash_output # 19.5μs -> 10.0μs (94.1% faster)

def test_basic_1d_conv_same():
    # 1D conv, batch=2, length=10, channels=3, filters=4, kernel=3, stride=1, same padding
    codeflash_output = compute_conv_output_shape((2, 10, 3), 4, (3,), strides=1, padding="same"); out = codeflash_output # 16.2μs -> 9.32μs (73.3% faster)

def test_basic_2d_conv_valid_channels_last():
    # 2D conv, batch=1, height=28, width=28, channels=1, filters=32, kernel=(3,3), stride=1
    codeflash_output = compute_conv_output_shape((1, 28, 28, 1), 32, (3, 3), strides=1, padding="valid"); out = codeflash_output # 19.8μs -> 11.2μs (75.7% faster)

def test_basic_2d_conv_same_channels_last():
    # 2D conv, batch=1, height=28, width=28, channels=1, filters=32, kernel=(3,3), stride=1, same padding
    codeflash_output = compute_conv_output_shape((1, 28, 28, 1), 32, (3, 3), strides=1, padding="same"); out = codeflash_output # 16.1μs -> 10.1μs (59.4% faster)

def test_basic_2d_conv_channels_first():
    # 2D conv, batch=1, channels=1, height=28, width=28, filters=32, kernel=(3,3), stride=1
    codeflash_output = compute_conv_output_shape((1, 1, 28, 28), 32, (3, 3), strides=1, padding="valid", data_format="channels_first"); out = codeflash_output # 19.5μs -> 11.0μs (77.3% faster)

def test_basic_3d_conv_valid():
    # 3D conv, batch=2, depth=8, height=8, width=8, channels=2, filters=5, kernel=(3,3,3), stride=1
    codeflash_output = compute_conv_output_shape((2, 8, 8, 8, 2), 5, (3, 3, 3), strides=1, padding="valid"); out = codeflash_output # 22.0μs -> 12.5μs (75.4% faster)

def test_basic_stride_and_dilation():
    # 2D conv, batch=1, height=10, width=10, channels=1, filters=1, kernel=(3,3), stride=2, dilation=1
    codeflash_output = compute_conv_output_shape((1, 10, 10, 1), 1, (3, 3), strides=2, padding="valid", dilation_rate=1); out = codeflash_output # 20.3μs -> 11.0μs (84.0% faster)

def test_basic_tuple_stride_and_dilation():
    # 2D conv, batch=1, height=10, width=10, channels=1, filters=1, kernel=(3,3), stride=(2,3), dilation=(1,2)
    codeflash_output = compute_conv_output_shape((1, 10, 10, 1), 1, (3, 3), strides=(2, 3), padding="valid", dilation_rate=(1, 2)); out = codeflash_output # 19.5μs -> 11.0μs (77.8% faster)

# --- Edge Test Cases ---

def test_edge_none_input_shape():
    # Input shape has None for spatial dimension; should propagate None in output
    codeflash_output = compute_conv_output_shape((1, None, 3), 4, (3,), strides=1, padding="valid"); out = codeflash_output # 22.5μs -> 5.78μs (288% faster)

def test_edge_invalid_padding():
    # Padding string not valid; should raise ValueError
    with pytest.raises(ValueError):
        compute_conv_output_shape((1, 10, 3), 4, (3,), strides=1, padding="foobar") # 11.2μs -> 4.77μs (136% faster)

def test_edge_invalid_dilation_length():
    # Dilation tuple length doesn't match spatial dims; should raise ValueError
    with pytest.raises(ValueError):
        compute_conv_output_shape((1, 10, 10, 3), 4, (3, 3), dilation_rate=(1, 2, 3)) # 3.94μs -> 4.42μs (10.7% slower)

def test_edge_invalid_kernel_shape_length():
    # Kernel shape length doesn't match input shape; should raise ValueError
    with pytest.raises(ValueError):
        compute_conv_output_shape((1, 10, 10, 3), 4, (3,), strides=1) # 2.91μs -> 3.16μs (7.98% slower)

def test_edge_negative_output_size():
    # Kernel and dilation so large that output size would be negative; should raise ValueError
    with pytest.raises(ValueError):
        compute_conv_output_shape((1, 5, 5, 3), 4, (6, 6), strides=1, padding="valid", dilation_rate=2) # 81.0μs -> 18.4μs (341% faster)

def test_edge_none_batch_size():
    # Batch size is None; should propagate None in output
    codeflash_output = compute_conv_output_shape((None, 10, 3), 4, (3,), strides=1, padding="valid"); out = codeflash_output # 23.6μs -> 10.7μs (121% faster)

def test_edge_dilation_scalar_and_tuple_equivalence():
    # Dilation as scalar and tuple should yield same result
    codeflash_output = compute_conv_output_shape((1, 10, 3), 4, (3,), strides=1, padding="valid", dilation_rate=2); out1 = codeflash_output # 19.7μs -> 10.1μs (94.3% faster)
    codeflash_output = compute_conv_output_shape((1, 10, 3), 4, (3,), strides=1, padding="valid", dilation_rate=(2,)); out2 = codeflash_output # 8.09μs -> 4.19μs (93.0% faster)

def test_edge_stride_scalar_and_tuple_equivalence():
    # Stride as scalar and tuple should yield same result
    codeflash_output = compute_conv_output_shape((1, 10, 3), 4, (3,), strides=2, padding="valid"); out1 = codeflash_output # 18.1μs -> 8.97μs (102% faster)
    codeflash_output = compute_conv_output_shape((1, 10, 3), 4, (3,), strides=(2,), padding="valid"); out2 = codeflash_output # 7.82μs -> 4.08μs (91.7% faster)

def test_edge_causal_padding_equivalence_to_same():
    # Causal padding is treated as "same" for output shape
    codeflash_output = compute_conv_output_shape((1, 10, 3), 4, (3,), strides=1, padding="same"); out1 = codeflash_output # 15.5μs -> 9.07μs (70.8% faster)
    codeflash_output = compute_conv_output_shape((1, 10, 3), 4, (3,), strides=1, padding="causal"); out2 = codeflash_output # 6.19μs -> 3.72μs (66.3% faster)

# --- Large Scale Test Cases ---

def test_large_2d_conv_large_input():
    # Large 2D input, batch=8, height=512, width=512, channels=3, filters=64, kernel=(5,5), stride=2
    codeflash_output = compute_conv_output_shape((8, 512, 512, 3), 64, (5, 5), strides=2, padding="valid"); out = codeflash_output # 20.0μs -> 10.9μs (83.6% faster)

def test_large_3d_conv_large_input():
    # Large 3D input, batch=4, depth=100, height=100, width=100, channels=8, filters=16, kernel=(3,3,3), stride=1
    codeflash_output = compute_conv_output_shape((4, 100, 100, 100, 8), 16, (3, 3, 3), strides=1, padding="same"); out = codeflash_output # 17.2μs -> 11.1μs (54.9% faster)

def test_large_stride_and_dilation():
    # Large input, large stride and dilation
    codeflash_output = compute_conv_output_shape((2, 1000, 1000, 3), 32, (7, 7), strides=10, dilation_rate=3, padding="valid"); out = codeflash_output # 19.1μs -> 10.6μs (80.3% faster)

def test_large_channels_first():
    # Large input, channels_first format
    codeflash_output = compute_conv_output_shape((16, 3, 256, 256), 128, (5, 5), strides=2, padding="same", data_format="channels_first"); out = codeflash_output # 15.9μs -> 10.1μs (58.1% faster)

def test_large_batch_none():
    # Large input with None batch size
    codeflash_output = compute_conv_output_shape((None, 512, 512, 3), 64, (3, 3), strides=1, padding="same"); out = codeflash_output # 16.2μs -> 10.0μs (61.3% faster)

def test_large_spatial_none():
    # Large input with None spatial dimension
    codeflash_output = compute_conv_output_shape((8, None, 512, 3), 32, (3, 3), strides=1, padding="same"); out = codeflash_output # 20.3μs -> 9.52μs (114% faster)

def test_large_maximum_dimensions():
    # Maximum tested dimensions (5 spatial dims, 7D input)
    input_shape = (1, 10, 20, 30, 40, 50, 3)  # batch, 5 spatial, channels
    kernel_size = (3, 3, 3, 3, 3)
    codeflash_output = compute_conv_output_shape(input_shape, 8, kernel_size, strides=2, padding="valid"); out = codeflash_output # 22.2μs -> 13.5μs (64.3% faster)
    # Output spatial dims: floor((dim - 3)/2) + 1 for each
    expected = (
        1,
        4,   # (10-3)//2 + 1 = 3 + 1 = 4
        9,   # (20-3)//2 + 1 = 8 + 1 = 9
        14,  # (30-3)//2 + 1 = 13 + 1 = 14
        19,  # (40-3)//2 + 1 = 18 + 1 = 19
        24,  # (50-3)//2 + 1 = 23 + 1 = 24
        8
    )
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
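
# Illustrative addition (not part of the generated tests): the generated tests rely on
# codeflash's output comparison rather than explicit asserts; an explicit check of the
# documented valid-padding formula could look like this (hypothetical test name):
def test_explicit_shape_assertion_example():
    # valid padding, stride 2: floor((28 - 3)/2) + 1 = 13 per spatial dim
    out = compute_conv_output_shape((4, 28, 28, 1), 32, (3, 3), strides=2, padding="valid")
    assert out == (4, 13, 13, 32)
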
import numpy as np
# imports
import pytest  # used for our unit tests
from keras.src.ops.operation_utils import compute_conv_output_shape

# unit tests

# 1. Basic Test Cases

def test_basic_1d_conv_channels_last_valid():
    # 1D convolution, channels_last, valid padding
    input_shape = (8, 32, 3)  # (batch, length, channels)
    filters = 16
    kernel_size = (5,)
    expected = (8, 28, 16)  # output length: (32-5)+1 = 28
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size) # 20.6μs -> 10.3μs (100% faster)

def test_basic_1d_conv_channels_last_same():
    # 1D convolution, channels_last, same padding
    input_shape = (8, 32, 3)
    filters = 16
    kernel_size = (5,)
    expected = (8, 32, 16)  # output length: 32
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, padding="same") # 16.5μs -> 9.89μs (67.2% faster)

def test_basic_2d_conv_channels_last_valid():
    # 2D convolution, channels_last, valid padding
    input_shape = (4, 28, 28, 1)
    filters = 32
    kernel_size = (3, 3)
    expected = (4, 26, 26, 32)  # output: (28-3)+1 = 26
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size) # 19.6μs -> 11.2μs (74.9% faster)

def test_basic_2d_conv_channels_first_valid():
    # 2D convolution, channels_first, valid padding
    input_shape = (4, 1, 28, 28)
    filters = 32
    kernel_size = (3, 3)
    expected = (4, 32, 26, 26)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, data_format="channels_first") # 21.1μs -> 10.8μs (95.8% faster)

def test_basic_2d_conv_channels_last_stride():
    # 2D convolution, stride=2, valid padding
    input_shape = (4, 28, 28, 1)
    filters = 32
    kernel_size = (3, 3)
    strides = 2
    expected = (4, 13, 13, 32)  # output: floor((28-3)/2) + 1 = 12 + 1 = 13
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, strides=strides) # 20.6μs -> 11.2μs (83.8% faster)

def test_basic_3d_conv_channels_last_valid():
    # 3D convolution, channels_last, valid padding
    input_shape = (2, 10, 20, 30, 3)
    filters = 8
    kernel_size = (2, 4, 6)
    expected = (2, 9, 17, 25, 8)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size) # 20.1μs -> 11.7μs (72.0% faster)

# 2. Edge Test Cases

def test_edge_kernel_size_equals_input():
    # kernel size equals input size, valid padding
    input_shape = (1, 5, 5, 1)
    filters = 1
    kernel_size = (5, 5)
    expected = (1, 1, 1, 1)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size) # 19.3μs -> 10.7μs (80.4% faster)

def test_edge_stride_greater_than_input():
    # stride greater than input size, valid padding
    input_shape = (1, 3, 3, 1)
    filters = 2
    kernel_size = (2, 2)
    strides = 4
    expected = (1, 1, 1, 2)  # output: floor((3-2)/4) + 1 = 0 + 1 = 1
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, strides=strides) # 31.8μs -> 18.0μs (76.9% faster)

def test_edge_dilation_rate():
    # 2D convolution with dilation_rate > 1
    input_shape = (1, 10, 10, 1)
    filters = 2
    kernel_size = (3, 3)
    dilation_rate = (2, 2)
    # output: floor((10 - 2*(3-1) - 1)/1) + 1 = floor((10-4-1)/1)+1 = floor(5)+1=6
    expected = (1, 6, 6, 2)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, dilation_rate=dilation_rate) # 21.4μs -> 11.8μs (81.2% faster)

def test_edge_dilation_rate_scalar():
    # Dilation rate as scalar
    input_shape = (1, 10, 10, 1)
    filters = 2
    kernel_size = (3, 3)
    dilation_rate = 2
    expected = (1, 6, 6, 2)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, dilation_rate=dilation_rate) # 20.6μs -> 11.3μs (82.5% faster)

def test_edge_invalid_dilation_length():
    # Dilation rate tuple of wrong length
    input_shape = (1, 10, 10, 1)
    filters = 2
    kernel_size = (3, 3)
    dilation_rate = (2,)
    with pytest.raises(ValueError):
        compute_conv_output_shape(input_shape, filters, kernel_size, dilation_rate=dilation_rate) # 4.02μs -> 4.29μs (6.11% slower)

def test_edge_invalid_padding():
    # Invalid padding value
    input_shape = (1, 10, 10, 1)
    filters = 2
    kernel_size = (3, 3)
    with pytest.raises(ValueError):
        compute_conv_output_shape(input_shape, filters, kernel_size, padding="invalid") # 8.51μs -> 4.19μs (103% faster)

def test_edge_none_spatial_dim():
    # Input shape with None spatial dimension
    input_shape = (1, None, 10, 1)
    filters = 2
    kernel_size = (3, 3)
    # Output shape should have None for the corresponding spatial dimension
    expected = (1, None, 8, 2)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size) # 27.7μs -> 13.1μs (111% faster)

def test_edge_none_multiple_spatial_dims():
    # Input shape with multiple None spatial dimensions
    input_shape = (1, None, None, 1)
    filters = 2
    kernel_size = (3, 3)
    expected = (1, None, None, 2)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size) # 22.1μs -> 5.74μs (285% faster)

def test_edge_channels_first_none_spatial_dim():
    # channels_first, None spatial dimension
    input_shape = (1, 1, None, 10)
    filters = 2
    kernel_size = (3, 3)
    expected = (1, 2, None, 8)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, data_format="channels_first") # 20.7μs -> 12.3μs (68.2% faster)

def test_edge_kernel_shape_length_mismatch():
    # Kernel shape length mismatch with input
    input_shape = (1, 10, 10, 1)
    filters = 2
    kernel_size = (3,)
    with pytest.raises(ValueError):
        compute_conv_output_shape(input_shape, filters, kernel_size) # 3.12μs -> 3.24μs (3.55% slower)

def test_edge_causal_padding():
    # 'causal' padding should behave like 'same'
    input_shape = (1, 32, 3)
    filters = 16
    kernel_size = (5,)
    expected = (1, 32, 16)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, padding="causal") # 21.7μs -> 12.4μs (74.7% faster)

# 3. Large Scale Test Cases

def test_large_scale_1d_conv():
    # Large 1D input, channels_last
    input_shape = (16, 1000, 8)
    filters = 64
    kernel_size = (7,)
    expected = (16, 994, 64)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size) # 20.4μs -> 10.5μs (95.1% faster)

def test_large_scale_2d_conv_stride():
    # Large 2D input, stride
    input_shape = (8, 512, 512, 3)
    filters = 128
    kernel_size = (5, 5)
    strides = (2, 2)
    # output: floor((512-5)/2) + 1 = 253 + 1 = 254
    expected = (8, 254, 254, 128)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, strides=strides) # 20.9μs -> 11.9μs (76.1% faster)

def test_large_scale_3d_conv_channels_first():
    # Large 3D input, channels_first
    input_shape = (4, 8, 100, 200, 300)
    filters = 32
    kernel_size = (3, 5, 7)
    expected = (4, 32, 98, 196, 294)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, data_format="channels_first") # 20.5μs -> 12.0μs (71.4% faster)

def test_large_scale_2d_conv_none_dim():
    # Large 2D input with None dimension
    input_shape = (32, None, 512, 3)
    filters = 64
    kernel_size = (5, 5)
    expected = (32, None, 508, 64)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size) # 23.4μs -> 10.0μs (134% faster)

def test_large_scale_2d_conv_same_padding():
    # Large 2D input, same padding
    input_shape = (16, 1000, 1000, 8)
    filters = 64
    kernel_size = (7, 7)
    expected = (16, 1000, 1000, 64)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, padding="same") # 17.3μs -> 10.6μs (62.6% faster)

def test_large_scale_2d_conv_stride_and_dilation():
    # Large 2D input, stride and dilation
    input_shape = (16, 1000, 1000, 8)
    filters = 64
    kernel_size = (7, 7)
    strides = (2, 2)
    dilation_rate = (2, 2)
    # output: floor((1000-2*(7-1)-1)/2)+1 = floor((1000-12-1)/2)+1 = floor(987/2)+1=493+1=494
    expected = (16, 494, 494, 64)
    codeflash_output = compute_conv_output_shape(input_shape, filters, kernel_size, strides=strides, dilation_rate=dilation_rate) # 20.1μs -> 11.2μs (79.6% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-compute_conv_output_shape-mirf96kr` and push.

@codeflash-ai codeflash-ai bot requested a review from mashraf-222 December 4, 2025 12:38
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash and 🎯 Quality: High labels Dec 4, 2025