
Conversation

@codeflash-ai codeflash-ai bot commented Dec 5, 2025

📄 12% (0.12x) speedup for DPTImageProcessor._preprocess_segmentation_map in src/transformers/models/dpt/image_processing_dpt.py

⏱️ Runtime: 3.09 milliseconds → 2.77 milliseconds (best of 19 runs)

📝 Explanation and details

The optimization adds a fast-path early return for NumPy arrays in the to_numpy_array function, which is frequently called during image preprocessing in the DPT model pipeline.

Key optimization:

  • Added if isinstance(img, np.ndarray) and is_valid_image(img): return img as the first check
  • This bypasses the expensive to_numpy() function call for arrays that are already NumPy arrays

Why this works:
In the original code, even when img was already a np.ndarray, it still went through the to_numpy() function which performs type checking and potential conversions. The line profiler shows that 49 out of 55 calls to to_numpy_array were hitting the to_numpy(img) path, taking 791,614 nanoseconds (41.5% of total time).
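
For illustration, here is a minimal sketch of the fast-path pattern described above. The validity check is a simplified stand-in for transformers' is_valid_image, and the slow path is condensed, so treat this as an outline under those assumptions rather than the library's actual code:

import numpy as np
from PIL import Image


def _is_valid_image(img) -> bool:
    # Simplified stand-in for transformers.image_utils.is_valid_image,
    # used only to keep this sketch self-contained.
    return isinstance(img, (np.ndarray, Image.Image, list, tuple))


def to_numpy_array_sketch(img) -> np.ndarray:
    # Fast path (the added check): an input that is already a valid NumPy
    # array is returned untouched, skipping the generic conversion below.
    if isinstance(img, np.ndarray) and _is_valid_image(img):
        return img
    if not _is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")
    # Slow path, kept from the original flow: PIL images and other
    # array-likes are still converted explicitly.
    if isinstance(img, Image.Image):
        return np.array(img)
    return np.asarray(img)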

Performance impact:

  • The optimization cuts to_numpy_array execution time from 1.91ms to 1.09ms (a 43% reduction)
  • In _preprocess_segmentation_map, the time spent in to_numpy_array drops from 2.03ms to 1.34ms (a 34% reduction)
  • Overall pipeline speedup of about 11% (see the timing sketch below)
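
A rough way to reproduce this kind of timing locally (a sketch only: _preprocess_segmentation_map is a private method, and absolute numbers will vary by machine and transformers version):

import timeit

import numpy as np

from transformers.models.dpt.image_processing_dpt import DPTImageProcessor

processor = DPTImageProcessor()
seg_map = np.random.randint(0, 150, size=(512, 512), dtype=np.uint8)

# Time the segmentation-map preprocessing path that the fast path speeds up.
elapsed = timeit.timeit(lambda: processor._preprocess_segmentation_map(seg_map), number=1000)
print(f"avg per call: {elapsed / 1000 * 1e6:.1f} µs")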

Test case benefits:
The annotated tests show consistent improvements across all scenarios, with 60-80% speedups for basic NumPy array inputs (which are the most common case in image processing workflows). PIL image inputs see minimal impact since they still follow the original conversion path.

This optimization is particularly effective because image preprocessing often works with arrays that are already in NumPy format from previous pipeline stages, making the fast-path the common case rather than the exception.
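
As a quick usage illustration of the two input paths discussed above (adapted from the generated tests below, not part of the PR diff itself):

import numpy as np
from PIL import Image

from transformers.models.dpt.image_processing_dpt import DPTImageProcessor

processor = DPTImageProcessor()
labels = np.random.randint(0, 150, size=(64, 64), dtype=np.uint8)

# NumPy input: to_numpy_array now returns the array immediately (fast path).
out_from_array = processor._preprocess_segmentation_map(labels)

# PIL input: still follows the original conversion path, so it sees little speedup.
out_from_pil = processor._preprocess_segmentation_map(Image.fromarray(labels))

# Called directly with default arguments, neither input is resized, so both
# paths should yield the same int64 label map.
assert np.array_equal(out_from_array, out_from_pil)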

Correctness verification report:

⚙️ Existing Unit Tests: 🔘 None Found
🌀 Generated Regression Tests: 66 Passed
⏪ Replay Tests: 🔘 None Found
🔎 Concolic Coverage Tests: 🔘 None Found
📊 Tests Coverage: 100.0%
🌀 Generated Regression Tests and Runtime
import numpy as np

# imports
import pytest
from PIL import Image

from transformers.models.dpt.image_processing_dpt import DPTImageProcessor


# ----------- UNIT TESTS ------------


@pytest.fixture
def processor():
    return DPTImageProcessor()


# ------------------ BASIC TEST CASES ------------------


def test_basic_2d_numpy_array_identity(processor):
    # Basic 2D numpy array, no resize or reduce_labels
    arr = np.array([[1, 2], [3, 4]], dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr)
    result = codeflash_output  # 24.1μs -> 14.7μs (63.3% faster)


def test_basic_3d_numpy_array_channels_first(processor):
    # 3D array, shape (1, H, W)
    arr = np.array([[[1, 2], [3, 4]]], dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr)
    result = codeflash_output  # 22.1μs -> 12.7μs (73.7% faster)


def test_basic_pil_image_input(processor):
    # Input as PIL Image
    arr = np.array([[1, 2], [3, 4]], dtype=np.uint8)
    img = Image.fromarray(arr)
    codeflash_output = processor._preprocess_segmentation_map(img)
    result = codeflash_output  # 44.9μs -> 43.9μs (2.23% faster)


def test_basic_reduce_labels(processor):
    # Input with zeros, do_reduce_labels=True
    arr = np.array([[0, 1], [2, 0]], dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr, do_reduce_labels=True)
    result = codeflash_output  # 49.7μs -> 37.3μs (33.2% faster)
    # 0 -> 255, then -1, so 255->254, then 254->255
    expected = np.array([[255, 0], [1, 255]], dtype=np.int64)


def test_edge_empty_array(processor):
    # Empty array
    arr = np.array([[]], dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr)
    result = codeflash_output  # 24.1μs -> 14.7μs (64.0% faster)


def test_edge_single_pixel(processor):
    # Single pixel
    arr = np.array([[42]], dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr)
    result = codeflash_output  # 23.1μs -> 14.2μs (63.3% faster)


def test_edge_all_zeros_reduce_labels(processor):
    # All zeros, reduce_labels
    arr = np.zeros((5, 5), dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr, do_reduce_labels=True)
    result = codeflash_output  # 50.1μs -> 38.5μs (30.4% faster)


def test_edge_all_255_reduce_labels(processor):
    # All 255, reduce_labels
    arr = np.full((5, 5), 255, dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr, do_reduce_labels=True)
    result = codeflash_output  # 46.0μs -> 33.6μs (37.2% faster)


def test_edge_2d_vs_3d_equivalence(processor):
    # 2D and 3D (channels_first) should yield same after squeeze
    arr2d = np.array([[5, 6], [7, 8]], dtype=np.uint8)
    arr3d = arr2d[None, ...]
    codeflash_output = processor._preprocess_segmentation_map(arr2d)
    out2d = codeflash_output  # 22.4μs -> 13.6μs (65.5% faster)
    codeflash_output = processor._preprocess_segmentation_map(arr3d)
    out3d = codeflash_output  # 7.92μs -> 4.71μs (68.3% faster)


def test_edge_invalid_ndim(processor):
    # Invalid shape (4D)
    arr = np.zeros((1, 2, 2, 2), dtype=np.uint8)
    with pytest.raises(ValueError):
        processor._preprocess_segmentation_map(arr)  # 16.4μs -> 7.41μs (121% faster)


def test_edge_negative_values_reduce_labels(processor):
    # Negative values, reduce_labels
    arr = np.array([[0, -1], [-2, 255]], dtype=np.int32)
    codeflash_output = processor._preprocess_segmentation_map(arr, do_reduce_labels=True)
    result = codeflash_output  # 46.3μs -> 34.3μs (35.1% faster)
    # 0->255, -1->254, -2->253, 255->254->255
    expected = np.array([[255, 254], [253, 255]], dtype=np.int64)


def test_edge_large_values_reduce_labels(processor):
    # Large values, reduce_labels
    arr = np.array([[300, 0], [1, 254]], dtype=np.int32)
    codeflash_output = processor._preprocess_segmentation_map(arr, do_reduce_labels=True)
    result = codeflash_output  # 45.2μs -> 33.6μs (34.7% faster)
    # 300->299, 0->255, 1->0, 254->253
    expected = np.array([[299, 255], [0, 253]], dtype=np.int64)


# ------------------ LARGE SCALE TEST CASES ------------------


def test_large_scale_512x512(processor):
    # 512x512 array
    arr = np.random.randint(0, 256, (512, 512), dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr)
    result = codeflash_output  # 116μs -> 106μs (9.31% faster)


def test_large_scale_512x512_reduce_labels(processor):
    # 512x512 array, reduce_labels
    arr = np.random.randint(0, 256, (512, 512), dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr, do_reduce_labels=True)
    result = codeflash_output  # 278μs -> 265μs (4.66% faster)


def test_large_scale_3d(processor):
    # 3D array (1, 512, 512)
    arr = np.random.randint(0, 256, (1, 512, 512), dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr)
    result = codeflash_output  # 112μs -> 104μs (7.72% faster)


def test_large_scale_maximum_allowed(processor):
    # Maximum allowed: 1000x1000, dtype uint8, ~1MB
    arr = np.random.randint(0, 256, (1000, 1000), dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(arr)
    result = codeflash_output  # 415μs -> 392μs (5.82% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import numpy as np

# imports
import pytest
from PIL import Image

from transformers.models.dpt.image_processing_dpt import DPTImageProcessor


# ---------------------- UNIT TESTS ----------------------

# 1. Basic Test Cases


def test_basic_2d_array_no_resize_no_reduce():
    # Basic: 2D array, no resize, no label reduction
    processor = DPTImageProcessor()
    seg_map = np.array([[1, 2], [3, 4]], dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(seg_map)
    out = codeflash_output  # 26.6μs -> 16.2μs (63.8% faster)


def test_basic_2d_array_with_reduce_labels():
    # Basic: 2D array, label reduction
    processor = DPTImageProcessor()
    seg_map = np.array([[0, 1], [2, 0]], dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(seg_map, do_reduce_labels=True)
    out = codeflash_output  # 53.1μs -> 39.3μs (35.2% faster)
    # 0 -> 255, 1 -> 0, 2 -> 1, then -1 for 255 -> 254, then 254->255
    expected = np.array([[255, 0], [1, 255]], dtype=np.int64)


def test_basic_3d_array_channels_first():
    # Basic: 3D array (1, H, W)
    processor = DPTImageProcessor()
    seg_map = np.array([[[1, 2], [3, 4]]], dtype=np.uint8)  # shape (1,2,2)
    codeflash_output = processor._preprocess_segmentation_map(seg_map)
    out = codeflash_output  # 28.1μs -> 15.7μs (79.5% faster)


def test_basic_pil_image_input():
    # Basic: PIL Image input
    processor = DPTImageProcessor()
    seg_map = np.array([[1, 2], [3, 4]], dtype=np.uint8)
    pil_img = Image.fromarray(seg_map)
    codeflash_output = processor._preprocess_segmentation_map(pil_img)
    out = codeflash_output  # 45.5μs -> 46.2μs (1.44% slower)


# 2. Edge Test Cases


def test_edge_empty_array():
    # Edge: Empty array
    processor = DPTImageProcessor()
    seg_map = np.array([[]], dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(seg_map)
    out = codeflash_output  # 26.8μs -> 16.1μs (66.1% faster)


def test_edge_single_pixel():
    # Edge: Single pixel
    processor = DPTImageProcessor()
    seg_map = np.array([[42]], dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(seg_map)
    out = codeflash_output  # 26.5μs -> 15.5μs (71.0% faster)


def test_edge_all_zero_reduce_labels():
    # Edge: All zeros, reduce labels
    processor = DPTImageProcessor()
    seg_map = np.zeros((3, 3), dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(seg_map, do_reduce_labels=True)
    out = codeflash_output  # 53.9μs -> 40.0μs (34.7% faster)
    expected = np.full((3, 3), 255, dtype=np.int64)


def test_edge_max_label_reduce_labels():
    # Edge: Max label value (255), reduce labels
    processor = DPTImageProcessor()
    seg_map = np.full((2, 2), 255, dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(seg_map, do_reduce_labels=True)
    out = codeflash_output  # 50.4μs -> 34.7μs (45.0% faster)
    # 255 - 1 = 254, then 254 -> 255
    expected = np.full((2, 2), 255, dtype=np.int64)


def test_edge_invalid_shape_raises():
    # Edge: Invalid shape (3D with wrong channel)
    processor = DPTImageProcessor()
    seg_map = np.ones((2, 2, 2), dtype=np.uint8)  # Not (1, H, W)
    with pytest.raises(ValueError):
        processor._preprocess_segmentation_map(seg_map)  # 18.8μs -> 8.26μs (128% faster)


def test_edge_dtype_conversion():
    # Edge: Input dtype float, output should be int64
    processor = DPTImageProcessor()
    seg_map = np.array([[1.1, 2.9], [3.5, 4.0]], dtype=np.float32)
    codeflash_output = processor._preprocess_segmentation_map(seg_map)
    out = codeflash_output  # 30.7μs -> 17.4μs (76.1% faster)


# 3. Large Scale Test Cases


def test_large_scale_512x512():
    # Large scale: 512x512 array
    processor = DPTImageProcessor()
    seg_map = np.random.randint(0, 10, size=(512, 512), dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(seg_map)
    out = codeflash_output  # 118μs -> 108μs (8.87% faster)


def test_large_scale_512x512_with_reduce_labels():
    # Large scale: 512x512 array, reduce labels
    processor = DPTImageProcessor()
    seg_map = np.random.randint(0, 10, size=(512, 512), dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(seg_map, do_reduce_labels=True)
    out = codeflash_output  # 1.15ms -> 1.13ms (2.11% faster)


def test_large_scale_3d_array_channels_first():
    # Large scale: 3D array (1, H, W)
    processor = DPTImageProcessor()
    seg_map = np.random.randint(0, 100, size=(1, 512, 512), dtype=np.uint8)
    codeflash_output = processor._preprocess_segmentation_map(seg_map)
    out = codeflash_output  # 115μs -> 105μs (9.22% faster)

To edit these changes, run git checkout codeflash/optimize-DPTImageProcessor._preprocess_segmentation_map-mishac5l and push your commits to that branch.


@codeflash-ai codeflash-ai bot requested a review from mashraf-222 December 5, 2025 06:23
@codeflash-ai codeflash-ai bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels on Dec 5, 2025