Skip to content

Commit 178f9d3

Browse files
authored
Merge pull request #448 from hardik1408/main
New problem: Gini impurity
2 parents 5bc3bbb + c0fe833 commit 178f9d3

File tree

2 files changed

+138
-0
lines changed

2 files changed

+138
-0
lines changed
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Learn: Gini Impurity and Best Split in Decision Trees
2+
3+
## Overview
4+
5+
A core concept in Decision Trees (and by extension, Random Forests) is how the model chooses where to split the data at each node. One popular criterion used for splitting is **Gini Impurity**.
6+
7+
In this task, you will implement:
8+
- Gini impurity computation
9+
- Finding the best feature and threshold to split on based on impurity reduction
10+
11+
This helps build the foundation for how trees grow in a Random Forest.
12+
13+
---
14+
15+
## Gini Impurity
16+
17+
For a set of samples with class labels \( y \), the Gini Impurity is defined as:
18+
19+
$$
20+
G(y) = 1 - \sum_{i=1}^{k} p_i^2
21+
$$
22+
23+
Where \( p_i \) is the proportion of samples belonging to class \( i \).
24+
25+
A pure node (all one class) has \( G = 0 \), and higher values indicate more class diversity.
26+
27+
---
28+
29+
## Gini Gain for a Split
30+
31+
Given a feature and a threshold to split the dataset into left and right subsets:
32+
33+
$$
34+
G_{\text{split}} = \frac{n_{\text{left}}}{n} G(y_{\text{left}}) + \frac{n_{\text{right}}}{n} G(y_{\text{right}})
35+
$$
36+
37+
We choose the split that **minimizes** $( G_{\text{split}} )$.
38+
39+
---
40+
41+
## Problem Statement
42+
43+
You are given a dataset $( X \in \mathbb{R}^{n \times d} )$ and labels $( y \in \{0, 1\}^n $). Implement the following functions:
44+
45+
### Functions to Implement
46+
47+
```python
48+
def find_best_split(X: np.ndarray, y: np.ndarray) -> Tuple[int, float]:
49+
...
50+
```
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import numpy as np
2+
from typing import Tuple
3+
4+
def find_best_split(X: np.ndarray, y: np.ndarray) -> Tuple[int, float]:
5+
"""
6+
Find the best feature and threshold to split the dataset based on Gini impurity.
7+
8+
:param X: Feature matrix of shape (n_samples, n_features)
9+
:param y: Labels array of shape (n_samples,), binary (0 or 1)
10+
:return: (feature_index, threshold) with lowest weighted Gini impurity
11+
"""
12+
13+
def gini_impurity(y_subset: np.ndarray) -> float:
14+
if len(y_subset) == 0:
15+
return 0.0
16+
p = np.mean(y_subset == 1)
17+
return 1.0 - (p ** 2 + (1 - p) ** 2)
18+
19+
n_samples, n_features = X.shape
20+
best_feature = -1
21+
best_threshold = float('inf')
22+
best_gini = float('inf')
23+
24+
for feature_index in range(n_features):
25+
thresholds = np.unique(X[:, feature_index])
26+
for threshold in thresholds:
27+
left_mask = X[:, feature_index] <= threshold
28+
right_mask = ~left_mask
29+
30+
y_left, y_right = y[left_mask], y[right_mask]
31+
g_left, g_right = gini_impurity(y_left), gini_impurity(y_right)
32+
33+
weighted_gini = (len(y_left) * g_left + len(y_right) * g_right) / n_samples
34+
35+
if weighted_gini < best_gini:
36+
best_gini = weighted_gini
37+
best_feature = feature_index
38+
best_threshold = threshold
39+
40+
return best_feature, best_threshold
41+
42+
def test():
43+
# Test 1: Balanced binary split
44+
X1 = np.array([[2.5], [3.5], [1.0], [4.0]])
45+
y1 = np.array([0, 1, 0, 1])
46+
f1, t1 = find_best_split(X1, y1)
47+
assert f1 == 0
48+
assert 1.0 <= t1 <= 3.5
49+
50+
# Test 2: Pure set (Gini = 0)
51+
X2 = np.array([[1], [2], [3]])
52+
y2 = np.array([1, 1, 1])
53+
f2, t2 = find_best_split(X2, y2)
54+
assert f2 == 0
55+
assert t2 in [1, 2, 3]
56+
57+
# Test 3: Alternating labels
58+
X3 = np.array([[1], [2], [3], [4]])
59+
y3 = np.array([0, 1, 0, 1])
60+
f3, t3 = find_best_split(X3, y3)
61+
assert f3 == 0
62+
assert t3 in [1, 2, 3, 4]
63+
64+
# Test 4: No good split (non-separable)
65+
X4 = np.array([[1], [1], [1]])
66+
y4 = np.array([0, 1, 0])
67+
f4, t4 = find_best_split(X4, y4)
68+
assert f4 == 0
69+
assert t4 == 1
70+
71+
# Test 5: Two features, first one irrelevant
72+
X5 = np.array([[0, 1], [0, 2], [0, 3], [0, 4]])
73+
y5 = np.array([0, 0, 1, 1])
74+
f5, t5 = find_best_split(X5, y5)
75+
assert f5 == 1
76+
assert t5 in [1, 2, 3, 4]
77+
78+
# Test 6: Tiny dataset
79+
X6 = np.array([[1], [2]])
80+
y6 = np.array([0, 1])
81+
f6, t6 = find_best_split(X6, y6)
82+
assert f6 == 0
83+
assert t6 in [1, 2]
84+
85+
print("All test cases passed.")
86+
87+
if __name__ == "__main__":
88+
test()

0 commit comments

Comments
 (0)