|
| 1 | +import numpy as np |
| 2 | +from typing import Tuple |
| 3 | + |
def find_best_split(X: np.ndarray, y: np.ndarray) -> Tuple[int, float]:
    """
    Find the feature and threshold whose split has the lowest weighted Gini impurity.

    Every distinct value of every feature is tried as a threshold. Samples with
    ``X[:, feature] <= threshold`` go left, the rest go right, and the split is
    scored by the size-weighted average of the two children's Gini impurity.
    Ties keep the first candidate found (lowest feature index, then lowest
    threshold).

    The largest value of a feature is also a candidate; it sends every sample
    left, so the returned split may leave the right side empty (for example
    when a feature is constant).

    :param X: Feature matrix of shape (n_samples, n_features)
    :param y: Labels array of shape (n_samples,), binary (0 or 1)
    :return: (feature_index, threshold) as a plain ``int`` and ``float``;
        ``(-1, inf)`` when there is no candidate threshold at all
        (no features or no samples)
    """

    def gini_impurity(y_subset: np.ndarray) -> float:
        # An empty child contributes nothing to the weighted score.
        if len(y_subset) == 0:
            return 0.0
        # Fraction of samples labelled 1; everything else counts as the other class.
        p = np.mean(y_subset == 1)
        return 1.0 - (p ** 2 + (1 - p) ** 2)

    n_samples, n_features = X.shape
    best_feature = -1
    best_threshold = float('inf')
    best_gini = float('inf')

    for feature_index in range(n_features):
        # Slice the column once per feature instead of once per threshold.
        column = X[:, feature_index]
        for threshold in np.unique(column):
            left_mask = column <= threshold
            y_left, y_right = y[left_mask], y[~left_mask]

            weighted_gini = (
                len(y_left) * gini_impurity(y_left)
                + len(y_right) * gini_impurity(y_right)
            ) / n_samples

            # Strict comparison: on ties the earliest candidate is kept.
            if weighted_gini < best_gini:
                best_gini = weighted_gini
                best_feature = feature_index
                best_threshold = threshold

    # np.unique yields NumPy scalars (np.int64 for integer X); convert so the
    # result matches the annotated Tuple[int, float].
    return best_feature, float(best_threshold)
| 41 | + |
def test():
    """Run find_best_split on six small hand-built datasets and check each result.

    Each case lists the feature rows, the labels, the feature index that must
    be chosen, and a predicate the chosen threshold must satisfy.
    """
    cases = [
        # Test 1: Balanced binary split
        ([[2.5], [3.5], [1.0], [4.0]], [0, 1, 0, 1], 0, lambda t: 1.0 <= t <= 3.5),
        # Test 2: Pure set (Gini = 0)
        ([[1], [2], [3]], [1, 1, 1], 0, lambda t: t in [1, 2, 3]),
        # Test 3: Alternating labels
        ([[1], [2], [3], [4]], [0, 1, 0, 1], 0, lambda t: t in [1, 2, 3, 4]),
        # Test 4: No good split (non-separable)
        ([[1], [1], [1]], [0, 1, 0], 0, lambda t: t == 1),
        # Test 5: Two features, first one irrelevant
        ([[0, 1], [0, 2], [0, 3], [0, 4]], [0, 0, 1, 1], 1, lambda t: t in [1, 2, 3, 4]),
        # Test 6: Tiny dataset
        ([[1], [2]], [0, 1], 0, lambda t: t in [1, 2]),
    ]

    for rows, labels, expected_feature, threshold_ok in cases:
        feature, threshold = find_best_split(np.array(rows), np.array(labels))
        assert feature == expected_feature
        assert threshold_ok(threshold)

    print("All test cases passed.")
| 86 | + |
# Run the self-checks only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test()
0 commit comments