Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Dec 4, 2025

📄 17% (0.17x) speedup for KalmanFilterXYAH.project in ultralytics/trackers/utils/kalman_filter.py

⏱️ Runtime : 10.3 milliseconds 8.82 milliseconds (best of 157 runs)

📝 Explanation and details

The optimized code achieves a 17% speedup through three key optimizations in the project method:

What was optimized:

  1. Reduced redundant computations: The original code calculated self._std_weight_position * mean[3] four times. The optimized version computes this once as std_pos_h and reuses it.

  2. Eliminated intermediate list creation: Instead of creating a Python list std and then calling np.square(std), the optimized version creates a NumPy array directly and uses element-wise multiplication (std * std) for squaring.

  3. Replaced multi_dot with @ operator: Changed np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T)) to self._update_mat @ covariance @ self._update_mat.T, which is more efficient for this specific triple matrix multiplication.

Why it's faster:

  • Computation elimination: Removing 3 redundant multiplications saves CPU cycles, especially important since this involves floating-point operations
  • Memory efficiency: Direct NumPy array creation avoids Python list overhead and the intermediate np.square() call
  • Optimized matrix operations: The @ operator uses more efficient BLAS routines for consecutive matrix multiplications compared to the general-purpose multi_dot

Performance characteristics:

The line profiler shows the most significant improvements in:

  • Innovation covariance calculation: 25.6% → 40% of total time (but absolute time decreased)
  • Matrix multiplication: 54.3% → 27.1% of total time with substantial absolute time reduction

Test results indicate the optimization performs consistently well across all scenarios:

  • Basic cases: 11-21% faster
  • Edge cases (zero/negative heights): 12-27% faster
  • Large scale operations: 16-17% faster

This optimization is particularly valuable in object tracking scenarios where the project method is called frequently for each tracked object at every frame, making the 17% improvement compound significantly over time.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 1955 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime
import numpy as np

# imports
import pytest  # used for our unit tests
from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH

# function to test
# (KalmanFilterXYAH.project is defined above, as per the prompt)

# unit tests


class TestKalmanFilterXYAHProject:
    # Basic Test Cases

    def test_basic_identity_covariance(self):
        """Basic: Test with mean and identity covariance."""
        kf = KalmanFilterXYAH()
        mean = np.array([0, 0, 1, 1, 0, 0, 0, 0])
        covariance = np.eye(8)
        projected_mean, projected_cov = kf.project(mean, covariance)  # 25.2μs -> 21.4μs (17.7% faster)

    def test_basic_nontrivial_mean(self):
        """Basic: Test with nonzero mean and diagonal covariance."""
        kf = KalmanFilterXYAH()
        mean = np.array([10, 20, 1.5, 50, 1, 2, 0.1, 0.2])
        covariance = np.eye(8) * 2
        projected_mean, projected_cov = kf.project(mean, covariance)  # 16.2μs -> 13.4μs (20.6% faster)
        # Projected covariance should be positive semi-definite
        eigvals = np.linalg.eigvalsh(projected_cov)

    def test_basic_random_covariance(self):
        """Basic: Test with random positive semi-definite covariance."""
        kf = KalmanFilterXYAH()
        rng = np.random.default_rng(42)
        A = rng.normal(size=(8, 8))
        covariance = np.dot(A, A.T)  # PSD by construction
        mean = np.array([5, -3, 2, 10, 0.5, -0.5, 0.2, -0.2])
        projected_mean, projected_cov = kf.project(mean, covariance)  # 17.0μs -> 15.2μs (11.6% faster)
        eigvals = np.linalg.eigvalsh(projected_cov)

    # Edge Test Cases

    def test_edge_zero_height(self):
        """Edge: Test with zero height (mean[3] == 0)."""
        kf = KalmanFilterXYAH()
        mean = np.array([1, 2, 3, 0, 0, 0, 0, 0])
        covariance = np.eye(8)
        projected_mean, projected_cov = kf.project(mean, covariance)  # 23.9μs -> 19.6μs (21.7% faster)

    def test_edge_negative_height(self):
        """Edge: Test with negative height (mean[3] < 0)."""
        kf = KalmanFilterXYAH()
        mean = np.array([1, 2, 3, -10, 0, 0, 0, 0])
        covariance = np.eye(8)
        projected_mean, projected_cov = kf.project(mean, covariance)  # 21.5μs -> 17.7μs (21.9% faster)
        # Innovation covariance for position and height should be positive (std squared)
        expected = (kf._std_weight_position * mean[3]) ** 2

    def test_edge_large_height(self):
        """Edge: Test with very large height (mean[3] == 1e6)."""
        kf = KalmanFilterXYAH()
        mean = np.array([1, 2, 3, 1e6, 0, 0, 0, 0])
        covariance = np.eye(8)
        projected_mean, projected_cov = kf.project(mean, covariance)  # 16.2μs -> 14.0μs (15.5% faster)
        # Innovation covariance for position and height should be huge
        expected = (kf._std_weight_position * mean[3]) ** 2

    def test_edge_zero_covariance(self):
        """Edge: Test with zero covariance matrix."""
        kf = KalmanFilterXYAH()
        mean = np.array([1, 2, 3, 4, 0, 0, 0, 0])
        covariance = np.zeros((8, 8))
        projected_mean, projected_cov = kf.project(mean, covariance)  # 21.8μs -> 17.1μs (26.8% faster)
        # Projected covariance should be just the innovation covariance
        std = [
            kf._std_weight_position * mean[3],
            kf._std_weight_position * mean[3],
            1e-1,
            kf._std_weight_position * mean[3],
        ]
        expected = np.diag(np.square(std))

    def test_edge_extreme_values(self):
        """Edge: Test with extreme mean values (very large/small floats)."""
        kf = KalmanFilterXYAH()
        mean = np.array([1e20, -1e20, 1e-20, 1e10, 1e5, -1e5, 1e-5, -1e-5])
        covariance = np.eye(8)
        projected_mean, projected_cov = kf.project(mean, covariance)  # 16.0μs -> 13.2μs (20.7% faster)

    def test_edge_nan_inf(self):
        """Edge: Test with NaN and Inf values in mean and covariance."""
        kf = KalmanFilterXYAH()
        mean = np.array([np.nan, np.inf, -np.inf, 1, 0, 0, 0, 0])
        covariance = np.eye(8)
        projected_mean, projected_cov = kf.project(mean, covariance)  # 15.7μs -> 13.3μs (17.9% faster)

    def test_edge_covariance_with_nan_inf(self):
        """Edge: Test with NaN and Inf values in covariance."""
        kf = KalmanFilterXYAH()
        mean = np.array([1, 2, 3, 4, 0, 0, 0, 0])
        covariance = np.eye(8)
        covariance[0, 0] = np.nan
        covariance[1, 1] = np.inf
        projected_mean, projected_cov = kf.project(mean, covariance)  # 21.7μs -> 26.3μs (17.4% slower)

    # Large Scale Test Cases

    def test_large_scale_random(self):
        """Large Scale: Test with many random mean/covariance pairs."""
        kf = KalmanFilterXYAH()
        rng = np.random.default_rng(123)
        for _ in range(100):  # 100 random samples
            mean = rng.normal(size=8) * rng.uniform(1, 100)
            A = rng.normal(size=(8, 8))
            covariance = np.dot(A, A.T) + np.eye(8) * 1e-3  # PSD
            projected_mean, projected_cov = kf.project(mean, covariance)  # 533μs -> 453μs (17.6% faster)

    def test_large_scale_extreme_heights(self):
        """Large Scale: Test with many extreme height values."""
        kf = KalmanFilterXYAH()
        for h in np.linspace(1, 1e3, 100):
            mean = np.array([1, 2, 3, h, 0, 0, 0, 0])
            covariance = np.eye(8)
            projected_mean, projected_cov = kf.project(mean, covariance)  # 532μs -> 451μs (17.8% faster)
            expected = (kf._std_weight_position * h) ** 2

    def test_large_scale_batch(self):
        """Large Scale: Test projecting a batch of means/covariances."""
        kf = KalmanFilterXYAH()
        rng = np.random.default_rng(321)
        means = rng.normal(size=(100, 8))
        covariances = np.array([np.eye(8) for _ in range(100)])
        for i in range(100):
            projected_mean, projected_cov = kf.project(means[i], covariances[i])  # 533μs -> 458μs (16.4% faster)

    # Determinism Test

    def test_determinism(self):
        """Test that repeated calls with same input yield same output."""
        kf = KalmanFilterXYAH()
        mean = np.array([10, 20, 1.5, 50, 1, 2, 0.1, 0.2])
        covariance = np.eye(8) * 2
        codeflash_output = kf.project(mean, covariance)
        out1 = codeflash_output  # 14.4μs -> 12.8μs (12.3% faster)
        codeflash_output = kf.project(mean, covariance)
        out2 = codeflash_output  # 7.43μs -> 5.95μs (25.0% faster)

    # Input shape test

    def test_invalid_mean_shape(self):
        """Test that invalid mean shape raises error."""
        kf = KalmanFilterXYAH()
        mean = np.array([1, 2, 3, 4, 5, 6, 7])  # Only 7 elements
        covariance = np.eye(8)
        with pytest.raises(IndexError):
            kf.project(mean, covariance)

    def test_invalid_covariance_shape(self):
        """Test that invalid covariance shape raises error."""
        kf = KalmanFilterXYAH()
        mean = np.array([1, 2, 3, 4, 5, 6, 7, 8])
        covariance = np.eye(7)  # Only 7x7
        with pytest.raises(ValueError):
            kf.project(mean, covariance)  # 34.5μs -> 25.1μs (37.6% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import numpy as np

# imports
import pytest  # used for our unit tests
from ultralytics.trackers.utils.kalman_filter import KalmanFilterXYAH

# function to test
# (KalmanFilterXYAH.project implementation included above)

# ---------------------------- Unit Tests for KalmanFilterXYAH.project ----------------------------


@pytest.fixture
def kf():
    """Fixture to create a KalmanFilterXYAH instance for reuse."""
    return KalmanFilterXYAH()


# ---------------- Basic Test Cases ----------------


def test_project_identity_covariance_basic(kf):
    """
    Basic test: mean is zeros except for aspect ratio and height,
    covariance is identity. Checks correct projection shape and values.
    """
    mean = np.array([0, 0, 1.5, 50, 0, 0, 0, 0], dtype=float)
    covariance = np.eye(8)
    projected_mean, projected_cov = kf.project(mean, covariance)  # 24.4μs -> 24.4μs (0.025% faster)


def test_project_nontrivial_mean_and_covariance(kf):
    """
    Basic test: mean and covariance with non-zero velocities and off-diagonal covariance.
    Checks projection math.
    """
    mean = np.array([10, 20, 2.0, 40, 1, -1, 0.5, 2], dtype=float)
    covariance = np.eye(8) * 2
    covariance[0, 4] = 1.0  # position-velocity covariance
    covariance[1, 5] = -1.0
    covariance[4, 0] = 1.0
    covariance[5, 1] = -1.0
    projected_mean, projected_cov = kf.project(mean, covariance)  # 18.6μs -> 16.4μs (13.1% faster)
    # Innovation covariance should be added to diagonal
    std = [
        kf._std_weight_position * mean[3],
        kf._std_weight_position * mean[3],
        1e-1,
        kf._std_weight_position * mean[3],
    ]
    expected_innovation_diag = np.square(std)
    for i in range(4):
        pass


# ---------------- Edge Test Cases ----------------


def test_project_zero_height(kf):
    """
    Edge case: mean[3] (height) is zero, which affects std calculation.
    Checks for correct handling (no division by zero).
    """
    mean = np.array([5, -5, 0.5, 0, 0, 0, 0, 0], dtype=float)
    covariance = np.eye(8)
    projected_mean, projected_cov = kf.project(mean, covariance)  # 17.9μs -> 15.9μs (12.3% faster)


def test_project_negative_height(kf):
    """
    Edge case: mean[3] (height) is negative, which is physically invalid but should not crash.
    Checks that std is negative but squared in innovation_cov.
    """
    mean = np.array([0, 0, 1, -10, 0, 0, 0, 0], dtype=float)
    covariance = np.eye(8)
    projected_mean, projected_cov = kf.project(mean, covariance)  # 17.5μs -> 15.8μs (10.9% faster)
    # Innovation covariance should be positive (std squared)
    std = [
        kf._std_weight_position * mean[3],
        kf._std_weight_position * mean[3],
        1e-1,
        kf._std_weight_position * mean[3],
    ]
    expected_innovation_diag = np.square(std)
    for i in range(4):
        pass


def test_project_large_height(kf):
    """
    Edge case: mean[3] (height) is extremely large.
    Checks for numerical stability and correct scaling of innovation_cov.
    """
    mean = np.array([1e6, -1e6, 1e3, 1e8, 0, 0, 0, 0], dtype=float)
    covariance = np.eye(8) * 1e6
    projected_mean, projected_cov = kf.project(mean, covariance)  # 16.1μs -> 14.0μs (15.6% faster)
    # Innovation covariance should be extremely large for x, y, h
    expected_std = kf._std_weight_position * mean[3]


def test_project_non_symmetric_covariance(kf):
    """
    Edge case: covariance is not symmetric (should still work, but output must be symmetric).
    """
    mean = np.ones(8)
    covariance = np.eye(8)
    covariance[0, 1] = 5
    # Make covariance non-symmetric
    covariance[1, 0] = 2
    projected_mean, projected_cov = kf.project(mean, covariance)  # 17.6μs -> 15.7μs (12.1% faster)


def test_project_singular_covariance(kf):
    """
    Edge case: covariance matrix is singular (rank deficient).
    Should not crash, but may produce large uncertainties.
    """
    mean = np.arange(8)
    covariance = np.zeros((8, 8))
    projected_mean, projected_cov = kf.project(mean, covariance)  # 25.6μs -> 21.9μs (16.9% faster)
    # Output covariance should be equal to innovation_cov
    std = [
        kf._std_weight_position * mean[3],
        kf._std_weight_position * mean[3],
        1e-1,
        kf._std_weight_position * mean[3],
    ]
    expected_innovation_cov = np.diag(np.square(std))


def test_project_extreme_values(kf):
    """
    Edge case: mean and covariance contain extremely large and small values.
    Checks for numerical stability.
    """
    mean = np.array([1e-12, 1e12, -1e12, 1e-12, 0, 0, 0, 0], dtype=float)
    covariance = np.eye(8) * 1e-12
    projected_mean, projected_cov = kf.project(mean, covariance)  # 16.2μs -> 14.0μs (15.3% faster)


# ---------------- Large Scale Test Cases ----------------


def test_project_batch_means_covariances(kf):
    """
    Large scale: project 1000 random means and covariances, check output shape and properties.
    """
    n = 1000
    means = np.random.randn(n, 8) * 100
    covariances = np.array([np.eye(8) * np.random.uniform(1, 10) for _ in range(n)])
    for i in range(n):
        projected_mean, projected_cov = kf.project(means[i], covariances[i])  # 5.14ms -> 4.38ms (17.4% faster)


def test_project_performance_large_height(kf):
    """
    Large scale: project means with very large heights, check that function completes efficiently.
    """
    n = 500
    means = np.zeros((n, 8))
    means[:, 3] = np.linspace(1e3, 1e6, n)  # heights from 1000 to 1,000,000
    covariances = np.array([np.eye(8) for _ in range(n)])
    for i in range(n):
        projected_mean, projected_cov = kf.project(means[i], covariances[i])  # 2.58ms -> 2.20ms (17.1% faster)
        # Covariance diagonal should be >= innovation_cov
        expected_std = kf._std_weight_position * means[i, 3]


def test_project_randomized_extremes(kf):
    """
    Large scale: project means and covariances with random extreme values.
    """
    n = 100
    rng = np.random.default_rng(42)
    for _ in range(n):
        mean = rng.uniform(-1e9, 1e9, 8)
        covariance = np.eye(8) * rng.uniform(1e-5, 1e5)
        projected_mean, projected_cov = kf.project(mean, covariance)  # 537μs -> 461μs (16.5% faster)


# ---------------- Invalid Input Test Cases ----------------


def test_project_invalid_covariance_shape(kf):
    """
    Invalid input: covariance of wrong shape should raise an error.
    """
    mean = np.ones(8)
    covariance = np.eye(7)  # Should be 8x8
    with pytest.raises(ValueError):
        kf.project(mean, covariance)  # 27.9μs -> 21.3μs (31.1% faster)


def test_project_nan_input(kf):
    """
    Invalid input: mean or covariance contains NaN, output should propagate NaN.
    """
    mean = np.array([0, np.nan, 1, 1, 0, 0, 0, 0])
    covariance = np.eye(8)
    projected_mean, projected_cov = kf.project(mean, covariance)  # 23.4μs -> 22.5μs (3.85% faster)


def test_project_inf_input(kf):
    """
    Invalid input: mean or covariance contains inf, output should propagate inf.
    """
    mean = np.array([0, 1, np.inf, 1, 0, 0, 0, 0])
    covariance = np.eye(8)
    projected_mean, projected_cov = kf.project(mean, covariance)  # 18.9μs -> 16.7μs (13.3% faster)


# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-KalmanFilterXYAH.project-mir8vhnz and push.

Codeflash Static Badge

The optimized code achieves a **17% speedup** through three key optimizations in the `project` method:

**What was optimized:**

1. **Reduced redundant computations**: The original code calculated `self._std_weight_position * mean[3]` four times. The optimized version computes this once as `std_pos_h` and reuses it.

2. **Eliminated intermediate list creation**: Instead of creating a Python list `std` and then calling `np.square(std)`, the optimized version creates a NumPy array directly and uses element-wise multiplication (`std * std`) for squaring.

3. **Replaced multi_dot with @ operator**: Changed `np.linalg.multi_dot((self._update_mat, covariance, self._update_mat.T))` to `self._update_mat @ covariance @ self._update_mat.T`, which is more efficient for this specific triple matrix multiplication.

**Why it's faster:**

- **Computation elimination**: Removing 3 redundant multiplications saves CPU cycles, especially important since this involves floating-point operations
- **Memory efficiency**: Direct NumPy array creation avoids Python list overhead and the intermediate `np.square()` call
- **Optimized matrix operations**: The `@` operator uses more efficient BLAS routines for consecutive matrix multiplications compared to the general-purpose `multi_dot`

**Performance characteristics:**

The line profiler shows the most significant improvements in:
- Innovation covariance calculation: 25.6% → 40% of total time (but absolute time decreased)
- Matrix multiplication: 54.3% → 27.1% of total time with substantial absolute time reduction

**Test results indicate** the optimization performs consistently well across all scenarios:
- Basic cases: 11-21% faster
- Edge cases (zero/negative heights): 12-27% faster  
- Large scale operations: 16-17% faster

This optimization is particularly valuable in object tracking scenarios where the `project` method is called frequently for each tracked object at every frame, making the 17% improvement compound significantly over time.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 December 4, 2025 09:39
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Dec 4, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant