|
| 1 | +# /// script |
| 2 | +# requires-python = ">=3.12" |
| 3 | +# dependencies = [ |
| 4 | +# "marimo", |
| 5 | +# "numpy==2.2.2", |
| 6 | +# "plotly==6.0.0", |
| 7 | +# ] |
| 8 | +# /// |
| 9 | + |
| 10 | +import marimo |
| 11 | + |
| 12 | +__generated_with = "0.10.17" |
| 13 | +app = marimo.App(width="medium") |
| 14 | + |
| 15 | + |
| 16 | +@app.cell(hide_code=True) |
| 17 | +def _(mo): |
| 18 | + mo.md( |
| 19 | + r""" |
| 20 | + # Understanding Confusion Matrices in Binary Classification |
| 21 | +
|
| 22 | + The [Confusion Matrix](https://en.wikipedia.org/wiki/Confusion_matrix) is a fundamental tool for evaluating classification models. It provides a detailed breakdown of correct and incorrect predictions, helping us understand where our model succeeds and fails. Let's explore it interactively! |
| 23 | + """ |
| 24 | + ).center() |
| 25 | + return |
| 26 | + |
| 27 | + |
| 28 | +@app.cell(hide_code=True) |
| 29 | +def _(mo): |
| 30 | + definition = mo.md(r""" |
| 31 | + For binary classification, the confusion matrix is a $2 \times 2$ matrix: |
| 32 | +
|
| 33 | + \[ |
| 34 | + M = \begin{pmatrix} |
| 35 | + \text{TP} & \text{FN} \\ |
| 36 | + \text{FP} & \text{TN} |
| 37 | + \end{pmatrix} |
| 38 | + \] |
| 39 | +
|
| 40 | + where: |
| 41 | +
|
| 42 | + - TP (True Positives): Correctly predicted positive cases |
| 43 | +
|
| 44 | + - FN (False Negatives): Incorrectly predicted negative cases |
| 45 | +
|
| 46 | + - FP (False Positives): Incorrectly predicted positive cases |
| 47 | +
|
| 48 | + - TN (True Negatives): Correctly predicted negative cases |
| 49 | + """) |
| 50 | + |
| 51 | + mo.accordion({"### Mathematical Definition": definition}) |
| 52 | + return (definition,) |
| 53 | + |
| 54 | + |
| 55 | +@app.cell |
| 56 | +def _(flow, mo): |
| 57 | + mo.accordion({"Process Flow": flow.center()}) |
| 58 | + return |
| 59 | + |
| 60 | + |
| 61 | +@app.cell(hide_code=True) |
| 62 | +def _(mo): |
| 63 | + # flowchart showing confusion matrix computation |
| 64 | + flow = mo.mermaid(""" |
| 65 | + graph TD |
| 66 | + A[Input Data<br>y_true, y_pred] --> B[Count Predictions] |
| 67 | + B --> C[Organize in 2x2 Matrix] |
| 68 | + C --> D[Calculate Metrics] |
| 69 | + D --> E[Precision] |
| 70 | + D --> F[Recall] |
| 71 | + D --> G[Accuracy] |
| 72 | + D --> H[F1-Score] |
| 73 | + """) |
| 74 | + return (flow,) |
| 75 | + |
| 76 | + |
| 77 | +@app.cell |
| 78 | +def _(mo): |
| 79 | + mo.md( |
| 80 | + """ |
| 81 | + ### Input Data |
| 82 | + Enter the actual and predicted classifications for each individual (0 for negative, 1 for positive). |
| 83 | + """ |
| 84 | + ) |
| 85 | + return |
| 86 | + |
| 87 | + |
| 88 | +@app.cell |
| 89 | +def _(data_controls): |
| 90 | + data_controls |
| 91 | + return |
| 92 | + |
| 93 | + |
| 94 | +@app.cell |
| 95 | +def _(mo): |
| 96 | + # Create number inputs for 12 individuals |
| 97 | + n_samples = 12 |
| 98 | + actual_inputs = mo.ui.array([ |
| 99 | + mo.ui.number(value=0, start=0, stop=1, label=f"Actual {i+1}") |
| 100 | + for i in range(n_samples) |
| 101 | + ]) |
| 102 | + predicted_inputs = mo.ui.array([ |
| 103 | + mo.ui.number(value=0, start=0, stop=1, label=f"Predicted {i+1}") |
| 104 | + for i in range(n_samples) |
| 105 | + ]) |
| 106 | + |
| 107 | + # Create data table using markdown with LaTeX |
| 108 | + data_table = mo.md(r""" |
| 109 | + $$ |
| 110 | + \begin{array}{|c|c|c|c|c|c|c|c|c|c|c|c|c|} |
| 111 | + \hline |
| 112 | + \text{Individual} & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10 & 11 & 12 \\ |
| 113 | + \hline |
| 114 | + \text{Actual} & a_1 & a_2 & a_3 & a_4 & a_5 & a_6 & a_7 & a_8 & a_9 & a_{10} & a_{11} & a_{12} \\ |
| 115 | + \hline |
| 116 | + \text{Predicted} & p_1 & p_2 & p_3 & p_4 & p_5 & p_6 & p_7 & p_8 & p_9 & p_{10} & p_{11} & p_{12} \\ |
| 117 | + \hline |
| 118 | + \end{array} |
| 119 | + $$ |
| 120 | + """) |
| 121 | + |
| 122 | + # Stack inputs horizontally and data table below |
| 123 | + data_controls = mo.vstack([ |
| 124 | + mo.hstack([ |
| 125 | + mo.vstack([ |
| 126 | + mo.md("**Actual Classifications:**"), |
| 127 | + actual_inputs |
| 128 | + ]), |
| 129 | + mo.vstack([ |
| 130 | + mo.md("**Predicted Classifications:**"), |
| 131 | + predicted_inputs |
| 132 | + ]) |
| 133 | + ], justify="start", align="start"), |
| 134 | + mo.md("### Data Table:"), |
| 135 | + data_table |
| 136 | + ], gap=2) # Added gap for better spacing |
| 137 | + return ( |
| 138 | + actual_inputs, |
| 139 | + data_controls, |
| 140 | + data_table, |
| 141 | + n_samples, |
| 142 | + predicted_inputs, |
| 143 | + ) |
| 144 | + |
| 145 | + |
| 146 | +@app.cell |
| 147 | +def _(compute_button): |
| 148 | + compute_button.center() |
| 149 | + return |
| 150 | + |
| 151 | + |
| 152 | +@app.cell |
| 153 | +def _(mo): |
| 154 | + compute_button = mo.ui.run_button(label="Compute Confusion Matrix") |
| 155 | + return (compute_button,) |
| 156 | + |
| 157 | + |
| 158 | +@app.cell |
| 159 | +def _( |
| 160 | + actual_inputs, |
| 161 | + compute_button, |
| 162 | + explanation, |
| 163 | + mo, |
| 164 | + np, |
| 165 | + predicted_inputs, |
| 166 | + px, |
| 167 | +): |
| 168 | + results = None |
| 169 | + if compute_button.value: |
| 170 | + # get data from inputs |
| 171 | + actual_values = np.array([inp.value for inp in actual_inputs]) |
| 172 | + predicted_values = np.array([inp.value for inp in predicted_inputs]) |
| 173 | + |
| 174 | + # results for each individual |
| 175 | + results_array = [] |
| 176 | + for actual, pred in zip(actual_values, predicted_values): |
| 177 | + if actual == 1 and pred == 1: |
| 178 | + result = "TP" |
| 179 | + elif actual == 1 and pred == 0: |
| 180 | + result = "FN" |
| 181 | + elif actual == 0 and pred == 1: |
| 182 | + result = "FP" |
| 183 | + else: |
| 184 | + result = "TN" |
| 185 | + results_array.append(result) |
| 186 | + |
| 187 | + # confusion matrix calc |
| 188 | + tp = sum(1 for r in results_array if r == "TP") |
| 189 | + fn = sum(1 for r in results_array if r == "FN") |
| 190 | + fp = sum(1 for r in results_array if r == "FP") |
| 191 | + tn = sum(1 for r in results_array if r == "TN") |
| 192 | + |
| 193 | + conf_matrix = np.array([[tp, fn], [fp, tn]]) |
| 194 | + total = tp + fn + fp + tn |
| 195 | + |
| 196 | + # performance metrics calc |
| 197 | + accuracy = (tp + tn) / total if total > 0 else 0 |
| 198 | + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 |
| 199 | + recall = tp / (tp + fn) if (tp + fn) > 0 else 0 |
| 200 | + f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 |
| 201 | + |
| 202 | + # results table using markdown with LaTeX |
| 203 | + results_table = mo.md(r""" |
| 204 | + $$ |
| 205 | + \begin{array}{|c|c|c|c|c|c|c|c|c|c|c|c|c|} |
| 206 | + \hline |
| 207 | + \text{Individual} & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10 & 11 & 12 \\ |
| 208 | + \hline |
| 209 | + \text{Actual} & """ + " & ".join(str(v) for v in actual_values) + r""" \\ |
| 210 | + \hline |
| 211 | + \text{Predicted} & """ + " & ".join(str(v) for v in predicted_values) + r""" \\ |
| 212 | + \hline |
| 213 | + \text{Result} & """ + " & ".join(results_array) + r""" \\ |
| 214 | + \hline |
| 215 | + \end{array} |
| 216 | + $$ |
| 217 | + """) |
| 218 | + |
| 219 | + # Create confusion matrix visualization |
| 220 | + fig = px.imshow( |
| 221 | + conf_matrix, |
| 222 | + labels=dict(x="Predicted", y="Actual"), |
| 223 | + x=['Positive', 'Negative'], |
| 224 | + y=['Positive', 'Negative'], |
| 225 | + aspect="auto", |
| 226 | + title="Confusion Matrix Heatmap", |
| 227 | + color_continuous_scale="RdBu", |
| 228 | + width=500, |
| 229 | + height=500, |
| 230 | + text_auto=True |
| 231 | + ) |
| 232 | + |
| 233 | + fig.update_traces( |
| 234 | + texttemplate="%{z}", |
| 235 | + textfont={"size": 20}, |
| 236 | + hoverongaps=False, |
| 237 | + hovertemplate="<br>".join([ |
| 238 | + "Actual: %{y}", |
| 239 | + "Predicted: %{x}", |
| 240 | + "Count: %{z}", |
| 241 | + "<extra></extra>" |
| 242 | + ]) |
| 243 | + ) |
| 244 | + |
| 245 | + # matrix interpretation |
| 246 | + matrix_interpretation = mo.md(f""" |
| 247 | + ### Matrix Interpretation |
| 248 | + |
| 249 | + - True Positives (TP): {tp} (Actual: Positive, Predicted: Positive) |
| 250 | + |
| 251 | + - False Negatives (FN): {fn} (Actual: Positive, Predicted: Negative) |
| 252 | + |
| 253 | + - False Positives (FP): {fp} (Actual: Negative, Predicted: Positive) |
| 254 | + |
| 255 | + - True Negatives (TN): {tn} (Actual: Negative, Predicted: Negative) |
| 256 | +
|
| 257 | + **Metrics:** |
| 258 | + |
| 259 | + - Accuracy: {accuracy:.2f} |
| 260 | + |
| 261 | + - Precision: {precision:.2f} |
| 262 | + |
| 263 | + - Recall: {recall:.2f} |
| 264 | + |
| 265 | + - F1 Score: {f1:.2f} |
| 266 | + """) |
| 267 | + |
| 268 | + results = mo.vstack([ |
| 269 | + mo.md("### Results"), |
| 270 | + results_table, |
| 271 | + # confusion matrix and interpretation side-by-side |
| 272 | + mo.hstack([ |
| 273 | + fig, |
| 274 | + matrix_interpretation |
| 275 | + ], justify="start", align="start"), |
| 276 | + explanation, |
| 277 | + # final callout |
| 278 | + mo.callout( |
| 279 | + mo.md(""" |
| 280 | + 🎉 Congratulations! You've successfully: |
| 281 | + |
| 282 | + - Understood how confusion matrices work in binary classification |
| 283 | + |
| 284 | + - Learned to interpret TP, FN, FP, and TN |
| 285 | + |
| 286 | + - Explored key metrics like accuracy, precision, recall, and F1 score |
| 287 | + |
| 288 | + - Gained hands-on experience with interactive confusion matrix analysis |
| 289 | + """), |
| 290 | + kind="success" |
| 291 | + ) |
| 292 | + ]) |
| 293 | + results |
| 294 | + return ( |
| 295 | + accuracy, |
| 296 | + actual, |
| 297 | + actual_values, |
| 298 | + conf_matrix, |
| 299 | + f1, |
| 300 | + fig, |
| 301 | + fn, |
| 302 | + fp, |
| 303 | + matrix_interpretation, |
| 304 | + precision, |
| 305 | + pred, |
| 306 | + predicted_values, |
| 307 | + recall, |
| 308 | + result, |
| 309 | + results, |
| 310 | + results_array, |
| 311 | + results_table, |
| 312 | + tn, |
| 313 | + total, |
| 314 | + tp, |
| 315 | + ) |
| 316 | + |
| 317 | + |
| 318 | +@app.cell(hide_code=True) |
| 319 | +def _(mo): |
| 320 | + explanation = mo.accordion({ |
| 321 | + "🎯 Understanding the Results": mo.md(""" |
| 322 | + **Interpreting the Confusion Matrix:** |
| 323 | +
|
| 324 | + 1. **Top-left (TP)**: Correctly identified positive cases |
| 325 | + 2. **Top-right (FN)**: Missed positive cases |
| 326 | + 3. **Bottom-left (FP)**: False alarms |
| 327 | + 4. **Bottom-right (TN)**: Correctly identified negative cases |
| 328 | + """), |
| 329 | + |
| 330 | + "📊 Derived Metrics": mo.md(""" |
| 331 | + - **Accuracy**: Overall correctness (TP + TN) / Total |
| 332 | + - **Precision**: Positive predictive value TP / (TP + FP) |
| 333 | + - **Recall**: True positive rate TP / (TP + FN) |
| 334 | + - **F1 Score**: Harmonic mean of precision and recall |
| 335 | + """), |
| 336 | + |
| 337 | + "💡 Best Practices": mo.md(""" |
| 338 | + 1. Consider class imbalance |
| 339 | + 2. Look at all metrics, not just accuracy |
| 340 | + 3. Choose metrics based on your problem context |
| 341 | + 4. Use confusion matrix for model debugging |
| 342 | + """) |
| 343 | + }) |
| 344 | + return (explanation,) |
| 345 | + |
| 346 | + |
| 347 | +@app.cell |
| 348 | +def _(): |
| 349 | + import marimo as mo |
| 350 | + return (mo,) |
| 351 | + |
| 352 | + |
| 353 | +@app.cell |
| 354 | +def _(): |
| 355 | + import numpy as np |
| 356 | + import plotly.express as px |
| 357 | + return np, px |
| 358 | + |
| 359 | + |
| 360 | +if __name__ == "__main__": |
| 361 | + app.run() |
0 commit comments