From becf0bd91519525d47cafdb93ac3e9e3601b1613 Mon Sep 17 00:00:00 2001 From: John Walz Date: Tue, 10 Dec 2024 15:36:04 -0500 Subject: [PATCH 1/8] feat: adding support for post processing functions and raw data --- .../post_processing_functions.ipynb | 351 ++++++++++++++++++ validmind/__init__.py | 2 + .../model_validation/sklearn/ROCCurve.py | 49 +-- validmind/tests/output.py | 11 +- validmind/tests/run.py | 7 +- validmind/vm_models/result/__init__.py | 4 +- validmind/vm_models/result/result.py | 24 ++ 7 files changed, 421 insertions(+), 27 deletions(-) create mode 100644 notebooks/code_sharing/post_processing_functions.ipynb diff --git a/notebooks/code_sharing/post_processing_functions.ipynb b/notebooks/code_sharing/post_processing_functions.ipynb new file mode 100644 index 000000000..a415d3996 --- /dev/null +++ b/notebooks/code_sharing/post_processing_functions.ipynb @@ -0,0 +1,351 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import xgboost as xgb\n", + "import validmind as vm\n", + "from validmind.datasets.classification import customer_churn\n", + "\n", + "vm.init(\n", + " api_host=\"\",\n", + " api_key=\"\",\n", + " api_secret=\"\",\n", + " model=\"\",\n", + ")\n", + "\n", + "raw_df = customer_churn.load_data()\n", + "\n", + "train_df, validation_df, test_df = customer_churn.preprocess(raw_df)\n", + "\n", + "x_train = train_df.drop(customer_churn.target_column, axis=1)\n", + "y_train = train_df[customer_churn.target_column]\n", + "x_val = validation_df.drop(customer_churn.target_column, axis=1)\n", + "y_val = validation_df[customer_churn.target_column]\n", + "\n", + "model = xgb.XGBClassifier(early_stopping_rounds=10)\n", + "model.set_params(\n", + " eval_metric=[\"error\", \"logloss\", \"auc\"],\n", + ")\n", + "model.fit(\n", + " x_train,\n", + " y_train,\n", + " eval_set=[(x_val, y_val)],\n", + " verbose=False,\n", + ")\n", + "\n", + "vm_raw_dataset = vm.init_dataset(\n", + " dataset=raw_df,\n", + " input_id=\"raw_dataset\",\n", + " target_column=customer_churn.target_column,\n", + " class_labels=customer_churn.class_labels,\n", + ")\n", + "\n", + "vm_train_ds = vm.init_dataset(\n", + " dataset=train_df,\n", + " input_id=\"train_dataset\",\n", + " target_column=customer_churn.target_column,\n", + ")\n", + "\n", + "vm_test_ds = vm.init_dataset(\n", + " dataset=test_df, input_id=\"test_dataset\", target_column=customer_churn.target_column\n", + ")\n", + "\n", + "vm_model = vm.init_model(\n", + " model,\n", + " input_id=\"model\",\n", + ")\n", + "\n", + "vm_train_ds.assign_predictions(\n", + " model=vm_model,\n", + ")\n", + "\n", + "vm_test_ds.assign_predictions(\n", + " model=vm_model,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from validmind.tests import run_test\n", + "\n", + "result = run_test(\n", + " \"validmind.model_validation.sklearn.ClassifierPerformance\",\n", + " inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + " generate_description=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Post-processing functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple Tabular Updates" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from validmind.vm_models.result import TestResult\n", + "\n", + "\n", + "def add_class_labels(result: TestResult):\n", + " result.tables[0].data[\"Class\"] = (\n", + " result.tables[0]\n", + " .data[\"Class\"]\n", + " .map(lambda x: \"Churn\" if x == \"1\" else \"No Churn\" if x == \"0\" else x)\n", + " )\n", + "\n", + " return result\n", + "\n", + "\n", + "result = run_test(\n", + " \"validmind.model_validation.sklearn.ClassifierPerformance\",\n", + " inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + " generate_description=False,\n", + " post_process_fn=add_class_labels,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Adding Tables" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from validmind.vm_models.result import ResultTable\n", + "\n", + "def add_table(result: TestResult):\n", + " # add legend table to show map of class value to class label\n", + " result.add_table(\n", + " ResultTable(\n", + " title=\"Class Legend\",\n", + " data=[\n", + " {\"Class Value\": \"0\", \"Class Label\": \"No Churn\"},\n", + " {\"Class Value\": \"1\", \"Class Label\": \"Churn\"},\n", + " ],\n", + " )\n", + " )\n", + "\n", + " return result\n", + "\n", + "\n", + "result = run_test(\n", + " \"validmind.model_validation.sklearn.ClassifierPerformance\",\n", + " inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + " generate_description=False,\n", + " post_process_fn=add_table,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Removing Tables" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def remove_table(result: TestResult):\n", + " result.tables.pop(1)\n", + "\n", + " return result\n", + "\n", + "\n", + "result = run_test(\n", + " \"validmind.model_validation.sklearn.ClassifierPerformance\",\n", + " inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + " generate_description=False,\n", + " post_process_fn=remove_table,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating Figure from Tables" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from plotly_express import bar\n", + "from validmind.vm_models.figure import Figure\n", + "\n", + "\n", + "def create_figure(result: TestResult):\n", + " fig = bar(result.tables[0].data, x=\"Variable\", y=\"Total Count of Outliers\")\n", + "\n", + " if result.raw_data is not None:\n", + " # create a new figure from the raw data\n", + " fig = bar(result.raw_data, x=\"Variable\", y=\"Total Count of Outliers\")\n", + "\n", + " result.add_figure(\n", + " Figure(\n", + " figure=fig,\n", + " key=\"outlier_count_by_variable\",\n", + " ref_id=result.ref_id,\n", + " )\n", + " )\n", + "\n", + " return result\n", + "\n", + "\n", + "result = run_test(\n", + " \"validmind.data_validation.IQROutliersTable\",\n", + " inputs={\"dataset\": vm_test_ds},\n", + " generate_description=False,\n", + " post_process_fn=create_figure,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Creating Tables from Figures" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_table(result: TestResult):\n", + " for fig in result.figures:\n", + " data = fig.figure.data[0]\n", + "\n", + " table_data = [\n", + " {\"Percentile\": x, \"Outlier Count\": y}\n", + " for x, y in zip(data.x, data.y)\n", + " ]\n", + "\n", + " result.add_table(\n", + " ResultTable(\n", + " title=fig.figure.layout.title.text,\n", + " data=table_data,\n", + " )\n", + " )\n", + "\n", + " return result\n", + "\n", + "\n", + "result = run_test(\n", + " \"validmind.data_validation.IQROutliersBarPlot\",\n", + " inputs={\"dataset\": vm_test_ds},\n", + " generate_description=False,\n", + " post_process_fn=create_table,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Re-Draw Confusion Matrix" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "def re_draw_class_imbalance(result: TestResult):\n", + " data = result.tables[0].data\n", + " # Exited Percentage of Rows (%) Pass/Fail\n", + " # 0 0 80.25% Pass\n", + " # 1 1 19.75% Pass\n", + "\n", + " result.figures = []\n", + "\n", + " # use matplotlib to plot the confusion matrix\n", + " fig = plt.figure()\n", + "\n", + " # show a bar plot of the class imbalance with matplotlib\n", + " plt.bar(data[\"Exited\"], data[\"Percentage of Rows (%)\"])\n", + " plt.xlabel(\"Exited\")\n", + " plt.ylabel(\"Percentage of Rows (%)\")\n", + " plt.title(\"Class Imbalance\")\n", + "\n", + " result.add_figure(\n", + " Figure(\n", + " figure=fig,\n", + " key=\"confusion_matrix\",\n", + " ref_id=result.ref_id,\n", + " )\n", + " )\n", + "\n", + " plt.close()\n", + "\n", + " return result\n", + "\n", + "\n", + "result = run_test(\n", + " \"validmind.data_validation.ClassImbalance\",\n", + " inputs={\"dataset\": vm_test_ds},\n", + " generate_description=False,\n", + " post_process_fn=re_draw_class_imbalance,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "validmind-BbKYUwN1-py3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/validmind/__init__.py b/validmind/__init__.py index 0763c2b5f..36cd9250e 100644 --- a/validmind/__init__.py +++ b/validmind/__init__.py @@ -50,6 +50,7 @@ run_test_suite, ) from .tests.decorator import tags, tasks, test +from .vm_models.result import RawData __all__ = [ # noqa "__version__", @@ -62,6 +63,7 @@ "init_model", "init_r_model", "preview_template", + "RawData", "reload", "run_documentation_tests", "run_test_suite", diff --git a/validmind/tests/model_validation/sklearn/ROCCurve.py b/validmind/tests/model_validation/sklearn/ROCCurve.py index 9dfacd61c..7113d0bc1 100644 --- a/validmind/tests/model_validation/sklearn/ROCCurve.py +++ b/validmind/tests/model_validation/sklearn/ROCCurve.py @@ -6,7 +6,7 @@ import plotly.graph_objects as go from sklearn.metrics import roc_auc_score, roc_curve -from validmind import tags, tasks +from validmind import RawData, tags, tasks from validmind.errors import SkipTestError from validmind.vm_models import VMDataset, VMModel @@ -77,28 +77,31 @@ def ROCCurve(model: VMModel, dataset: VMDataset): fpr, tpr, _ = roc_curve(y_true, y_prob, drop_intermediate=False) auc = roc_auc_score(y_true, y_prob) - return go.Figure( - data=[ - go.Scatter( - x=fpr, - y=tpr, - mode="lines", - name=f"ROC curve (AUC = {auc:.2f})", - line=dict(color="#DE257E"), + return ( + RawData(fpr=fpr, tpr=tpr, auc=auc), + go.Figure( + data=[ + go.Scatter( + x=fpr, + y=tpr, + mode="lines", + name=f"ROC curve (AUC = {auc:.2f})", + line=dict(color="#DE257E"), + ), + go.Scatter( + x=[0, 1], + y=[0, 1], + mode="lines", + name="Random (AUC = 0.5)", + line=dict(color="grey", dash="dash"), + ), + ], + layout=go.Layout( + title=f"ROC Curve for {model.input_id} on {dataset.input_id}", + xaxis=dict(title="False Positive Rate"), + yaxis=dict(title="True Positive Rate"), + width=700, + height=500, ), - go.Scatter( - x=[0, 1], - y=[0, 1], - mode="lines", - name="Random (AUC = 0.5)", - line=dict(color="grey", dash="dash"), - ), - ], - layout=go.Layout( - title=f"ROC Curve for {model.input_id} on {dataset.input_id}", - xaxis=dict(title="False Positive Rate"), - yaxis=dict(title="True Positive Rate"), - width=700, - height=500, ), ) diff --git a/validmind/tests/output.py b/validmind/tests/output.py index 762be7bc1..d5afc3f3c 100644 --- a/validmind/tests/output.py +++ b/validmind/tests/output.py @@ -15,7 +15,7 @@ is_plotly_figure, is_png_image, ) -from validmind.vm_models.result import ResultTable, TestResult +from validmind.vm_models.result import RawData, ResultTable, TestResult class OutputHandler(ABC): @@ -103,6 +103,14 @@ def process( result.add_table(ResultTable(data=table_data, title=table_name or None)) +class RawDataOutputHandler(OutputHandler): + def can_handle(self, item: Any) -> bool: + return isinstance(item, RawData) + + def process(self, item: Any, result: TestResult) -> None: + result.raw_data = item + + def process_output(item: Any, result: TestResult) -> None: """Process a single test output item and update the TestResult.""" handlers = [ @@ -110,6 +118,7 @@ def process_output(item: Any, result: TestResult) -> None: MetricOutputHandler(), FigureOutputHandler(), TableOutputHandler(), + RawDataOutputHandler(), ] for handler in handlers: diff --git a/validmind/tests/run.py b/validmind/tests/run.py index 9c806f306..c3b28f050 100644 --- a/validmind/tests/run.py +++ b/validmind/tests/run.py @@ -7,7 +7,7 @@ import time from datetime import datetime from inspect import getdoc -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from uuid import uuid4 from validmind import __version__ @@ -283,6 +283,7 @@ def run_test( show: bool = True, generate_description: bool = True, title: Optional[str] = None, + post_process_fn: Union[Callable[[TestResult], None], None] = None, **kwargs, ) -> TestResult: """Run a ValidMind or custom test @@ -306,6 +307,7 @@ def run_test( show (bool, optional): Whether to display results. Defaults to True. generate_description (bool, optional): Whether to generate a description. Defaults to True. title (str, optional): Custom title for the test result + post_process_fn (Callable[[TestResult], None], optional): Function to post-process the test result Returns: TestResult: A TestResult object containing the test results @@ -394,6 +396,9 @@ def run_test( end_time = time.perf_counter() result.metadata = _get_run_metadata(duration_seconds=end_time - start_time) + if post_process_fn: + result = post_process_fn(result) + if show: result.show() diff --git a/validmind/vm_models/result/__init__.py b/validmind/vm_models/result/__init__.py index 721e2f5cd..aca6c17e6 100644 --- a/validmind/vm_models/result/__init__.py +++ b/validmind/vm_models/result/__init__.py @@ -2,6 +2,6 @@ # See the LICENSE file in the root of this repository for details. # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial -from .result import ErrorResult, Result, ResultTable, TestResult +from .result import ErrorResult, RawData, Result, ResultTable, TestResult -__all__ = ["ErrorResult", "Result", "ResultTable", "TestResult"] +__all__ = ["ErrorResult", "RawData", "Result", "ResultTable", "TestResult"] diff --git a/validmind/vm_models/result/result.py b/validmind/vm_models/result/result.py index 08c26a7b8..7d7247557 100644 --- a/validmind/vm_models/result/result.py +++ b/validmind/vm_models/result/result.py @@ -34,6 +34,29 @@ logger = get_logger(__name__) +class RawData: + """Holds raw data for a test result""" + + def __init__(self, log: bool = False, **kwargs): + """Create a new RawData object + + Args: + log (bool): If True, log the raw data to ValidMind + **kwargs: Keyword arguments to set as attributes e.g. + `RawData(log=True, dataset_duplicates=df_duplicates)` + """ + self.log = log + + for key, value in kwargs.items(): + setattr(self, key, value) + + def __repr__(self) -> str: + return f"RawData({', '.join(self.__dict__.keys())})" + + def serialize(self): + return {key: getattr(self, key) for key in self.__dict__} + + @dataclass class ResultTable: """ @@ -118,6 +141,7 @@ class TestResult(Result): description: Optional[Union[str, DescriptionFuture]] = None metric: Optional[Union[int, float]] = None tables: Optional[List[ResultTable]] = None + raw_data: Optional[RawData] = None figures: Optional[List[Figure]] = None passed: Optional[bool] = None params: Optional[Dict[str, Any]] = None From 32fbd32b282366e0284040c8da872fc60a5ca6f0 Mon Sep 17 00:00:00 2001 From: John Walz Date: Tue, 10 Dec 2024 15:36:15 -0500 Subject: [PATCH 2/8] 2.7.0 --- pyproject.toml | 2 +- validmind/__version__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 44f805ccf..9ded3eeb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ description = "ValidMind Library" license = "Commercial License" name = "validmind" readme = "README.pypi.md" -version = "2.6.7" +version = "2.7.0" [tool.poetry.dependencies] aiohttp = {extras = ["speedups"], version = "*"} diff --git a/validmind/__version__.py b/validmind/__version__.py index 492f7d9a9..2614ce9d9 100644 --- a/validmind/__version__.py +++ b/validmind/__version__.py @@ -1 +1 @@ -__version__ = "2.6.7" +__version__ = "2.7.0" From d9515821039d4ef140b859553f539d44dbc7781a Mon Sep 17 00:00:00 2001 From: John Walz Date: Wed, 11 Dec 2024 14:43:31 -0500 Subject: [PATCH 3/8] chore: saving changes to work on hotfix --- .../post_processing_functions.ipynb | 169 ++++++++++++++++-- validmind/tests/load.py | 6 +- .../model_validation/ragas/Faithfulness.py | 1 + validmind/vm_models/figure.py | 3 + validmind/vm_models/result/result.py | 6 +- 5 files changed, 170 insertions(+), 15 deletions(-) diff --git a/notebooks/code_sharing/post_processing_functions.ipynb b/notebooks/code_sharing/post_processing_functions.ipynb index a415d3996..1a597047a 100644 --- a/notebooks/code_sharing/post_processing_functions.ipynb +++ b/notebooks/code_sharing/post_processing_functions.ipynb @@ -10,12 +10,7 @@ "import validmind as vm\n", "from validmind.datasets.classification import customer_churn\n", "\n", - "vm.init(\n", - " api_host=\"\",\n", - " api_key=\"\",\n", - " api_secret=\"\",\n", - " model=\"\",\n", - ")\n", + "vm.init()\n", "\n", "raw_df = customer_churn.load_data()\n", "\n", @@ -209,10 +204,6 @@ "def create_figure(result: TestResult):\n", " fig = bar(result.tables[0].data, x=\"Variable\", y=\"Total Count of Outliers\")\n", "\n", - " if result.raw_data is not None:\n", - " # create a new figure from the raw data\n", - " fig = bar(result.raw_data, x=\"Variable\", y=\"Total Count of Outliers\")\n", - "\n", " result.add_figure(\n", " Figure(\n", " figure=fig,\n", @@ -268,10 +259,28 @@ " \"validmind.data_validation.IQROutliersBarPlot\",\n", " inputs={\"dataset\": vm_test_ds},\n", " generate_description=False,\n", - " post_process_fn=create_table,\n", + " # post_process_fn=create_table,\n", ")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "raise Exception(\"stop\")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -325,6 +334,144 @@ " post_process_fn=re_draw_class_imbalance,\n", ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = run_test(\n", + " \"validmind.data_validation.ClassImbalance\",\n", + " inputs={\"dataset\": vm_test_ds},\n", + " generate_description=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def post_process_class_imbalance(result: TestResult):\n", + " result.passed = None\n", + " result.figures = []\n", + "\n", + " return result\n", + "\n", + "\n", + "result = run_test(\n", + " \"validmind.data_validation.ClassImbalance\",\n", + " inputs={\"dataset\": vm_test_ds},\n", + " generate_description=False,\n", + " post_process_fn=post_process_class_imbalance,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result = run_test(\n", + " \"validmind.model_validation.sklearn.ROCCurve\",\n", + " inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + " generate_description=False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def post_process_roc_curve(result: TestResult):\n", + " result.raw_data.fpr" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "import pandas as pd\n", + "import numpy as np\n", + "from plotly_express import bar\n", + "from validmind.vm_models.figure import Figure\n", + "from validmind.vm_models.result import TestResult\n", + "import plotly.graph_objects as go\n", + "\n", + "\n", + "@vm.test(\"my_custom_tests.Sensitivity\")\n", + "def sensitivity_test(strike=None):\n", + " \"\"\"This is sensitivity test\"\"\"\n", + " price = strike * random.random()\n", + "\n", + " return pd.DataFrame({\"Option price\": [price]})\n", + "\n", + "\n", + "def process_results(result: TestResult):\n", + "\n", + " df = pd.DataFrame(result.tables[0].data)\n", + "\n", + " fig = go.Figure()\n", + "\n", + " fig.add_trace(\n", + " go.Scatter(x=df[\"strike\"].values, y=df[\"Option price\"].values, mode=\"lines\")\n", + " )\n", + "\n", + " fig.update_layout(\n", + " # title=params[\"title\"],\n", + " # xaxis_title=params[\"xlabel\"],\n", + " # yaxis_title=params[\"ylabel\"],\n", + " showlegend=True,\n", + " template=\"plotly_white\", # Adds a grid by default\n", + " )\n", + "\n", + " result.add_figure(\n", + " Figure(\n", + " figure=fig,\n", + " key=\"sensitivity_to_strike\",\n", + " ref_id=result.ref_id,\n", + " )\n", + " )\n", + "\n", + " return result\n", + "\n", + "\n", + "result = run_test(\n", + " \"my_custom_tests.Sensitivity:ToStrike\",\n", + " param_grid={\n", + " \"strike\": list(np.linspace(0, 100, 20)),\n", + " },\n", + " post_process_fn=process_results,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from validmind.tests import list_tests\n", + "\n", + "list_tests()" + ] } ], "metadata": { diff --git a/validmind/tests/load.py b/validmind/tests/load.py index cf756fa05..60df24862 100644 --- a/validmind/tests/load.py +++ b/validmind/tests/load.py @@ -185,7 +185,7 @@ def list_tags(): unique_tags = set() - for test in _load_tests(list_tests(pretty=False)): + for test in _load_tests(list_tests(pretty=False)).values(): unique_tags.update(test.__tags__) return list(unique_tags) @@ -201,7 +201,7 @@ def list_tasks_and_tags(): """ task_tags_dict = {} - for test in _load_tests(list_tests(pretty=False)): + for test in _load_tests(list_tests(pretty=False)).values(): for task in test.__tasks__: task_tags_dict.setdefault(task, set()).update(test.__tags__) @@ -222,7 +222,7 @@ def list_tasks(): unique_tasks = set() - for test in _load_tests(list_tests(pretty=False)): + for test in _load_tests(list_tests(pretty=False)).values(): unique_tasks.update(test.__tasks__) return list(unique_tasks) diff --git a/validmind/tests/model_validation/ragas/Faithfulness.py b/validmind/tests/model_validation/ragas/Faithfulness.py index e5331f559..f6b8363cb 100644 --- a/validmind/tests/model_validation/ragas/Faithfulness.py +++ b/validmind/tests/model_validation/ragas/Faithfulness.py @@ -123,6 +123,7 @@ def Faithfulness( fig_box = px.box(x=result_df[score_column].to_list()) return ( + RawData(scores=result_df), { # "Scores (will not be uploaded to ValidMind Platform)": result_df[ # ["retrieved_contexts", "response", "faithfulness"] diff --git a/validmind/vm_models/figure.py b/validmind/vm_models/figure.py index c8eeca5bf..db25e5627 100644 --- a/validmind/vm_models/figure.py +++ b/validmind/vm_models/figure.py @@ -55,6 +55,9 @@ def __post_init__(self): ): self.figure = go.FigureWidget(self.figure) + def __repr__(self): + return f"Figure(key={self.key}, ref_id={self.ref_id})" + def to_widget(self): """ Returns the ipywidget compatible representation of the figure. Ideally diff --git a/validmind/vm_models/result/result.py b/validmind/vm_models/result/result.py index 7d7247557..e3c40f8e9 100644 --- a/validmind/vm_models/result/result.py +++ b/validmind/vm_models/result/result.py @@ -147,7 +147,6 @@ class TestResult(Result): params: Optional[Dict[str, Any]] = None inputs: Optional[Dict[str, Union[List[VMInput], VMInput]]] = None metadata: Optional[Dict[str, Any]] = None - title: Optional[str] = None _was_description_generated: bool = False _unsafe: bool = False @@ -168,6 +167,11 @@ def __repr__(self) -> str: "passed", ] if getattr(self, attr) is not None + and ( + len(getattr(self, attr)) > 0 + if isinstance(getattr(self, attr), list) + else True + ) ] return f'TestResult("{self.result_id}", {", ".join(attrs)})' From 3e1c596078ef069015c17468d7bffebb30837c15 Mon Sep 17 00:00:00 2001 From: John Walz Date: Thu, 12 Dec 2024 13:35:46 -0500 Subject: [PATCH 4/8] feat: notebook updates --- .../post_processing_functions.ipynb | 167 ++++++++++-------- 1 file changed, 90 insertions(+), 77 deletions(-) diff --git a/notebooks/code_sharing/post_processing_functions.ipynb b/notebooks/code_sharing/post_processing_functions.ipynb index 1a597047a..c33f65060 100644 --- a/notebooks/code_sharing/post_processing_functions.ipynb +++ b/notebooks/code_sharing/post_processing_functions.ipynb @@ -1,5 +1,58 @@ { "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Post-Processing Functions in ValidMind\n", + "\n", + "Welcome! This notebook demonstrates how to use post-processing functions with ValidMind tests to customize test outputs. You'll learn various ways to modify test results including updating tables, adding/removing tables, creating figures from tables, and vice versa.\n", + "\n", + "## Contents\n", + "- [About Post-Processing Functions](#about-post-processing-functions)\n", + "- [Key Concepts](#key-concepts)\n", + "- [Setup and Prerequisites](#setup-and-prerequisites)\n", + "- [Simple Tabular Updates](#simple-tabular-updates)\n", + "- [Adding Tables](#adding-tables) \n", + "- [Removing Tables](#removing-tables)\n", + "- [Creating Figures from Tables](#creating-figures-from-tables)\n", + "- [Creating Tables from Figures](#creating-tables-from-figures)\n", + "- [Re-Drawing Confusion Matrix](#re-drawing-confusion-matrix)\n", + "- [Re-Drawing ROC Curve](#re-drawing-roc-curve)\n", + "- [Custom Test Example](#custom-test-example)\n", + "\n", + "## About Post-Processing Functions\n", + "\n", + "Post-processing functions allow you to customize the output of ValidMind tests before they are saved to the platform. These functions take a TestResult object as input and return a modified TestResult object.\n", + "\n", + "Common use cases include:\n", + "- Reformatting table data\n", + "- Adding or removing tables/figures\n", + "- Creating new visualizations from test data\n", + "- Customizing test pass/fail criteria\n", + "\n", + "### Key Concepts\n", + "\n", + "**TestResult object**: The main object that post-processing functions work with, containing:\n", + "- tables: List of ResultTable objects\n", + "- figures: List of Figure objects \n", + "- passed: Boolean indicating test pass/fail status\n", + "- raw_data: Additional data from test execution\n", + "\n", + "**ResultTable**: Object representing tabular data with:\n", + "- title: Table title\n", + "- data: Pandas DataFrame or list of dictionaries\n", + "\n", + "**Figure**: Object representing plots/visualizations with:\n", + "- figure: matplotlib or plotly figure object\n", + "- key: Unique identifier\n", + "- ref_id: Reference ID linking to test\n", + "\n", + "## Setup and Prerequisites\n", + "\n", + "First, we'll set up our environment and load sample data using the customer churn dataset:" + ] + }, { "cell_type": "code", "execution_count": null, @@ -63,6 +116,13 @@ ")" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a refresher, here is how we run a test normally, without any post-processing:" + ] + }, { "cell_type": "code", "execution_count": null, @@ -82,14 +142,16 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Post-processing functions" + "## Post-processing functions" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Simple Tabular Updates" + "### Simple Tabular Updates\n", + "\n", + "The simplest form of post-processing is modifying existing table data. Here we demonstrate updating class labels in a classification performance table:" ] }, { @@ -123,7 +185,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Adding Tables" + "### Adding Tables\n", + "\n", + "Sometimes you may want to add supplementary tables to provide additional context or information. This example shows how to add a legend table mapping class values to labels:" ] }, { @@ -161,7 +225,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Removing Tables" + "### Removing Tables \n", + "\n", + "You can also remove tables that may not be relevant for your use case. Here we demonstrate removing a specific table from the results:" ] }, { @@ -188,7 +254,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Creating Figure from Tables" + "### Creating Figures from Tables\n", + "\n", + "A powerful use of post-processing is creating visualizations from tabular data. This example shows creating a bar plot from an outliers table:" ] }, { @@ -227,7 +295,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Creating Tables from Figures" + "### Creating Tables from Figures\n", + "\n", + "The reverse operation - extracting tabular data from figures - is also possible. Here we demonstrate creating a table from figure data:" ] }, { @@ -263,29 +333,13 @@ ")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "raise Exception(\"stop\")" - ] - }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Re-Draw Confusion Matrix" + "### Re-Drawing Confusion Matrix\n", + "\n", + "Sometimes you may want to completely replace the default visualizations. This example shows how to redraw a confusion matrix using matplotlib:" ] }, { @@ -336,16 +390,12 @@ ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "result = run_test(\n", - " \"validmind.data_validation.ClassImbalance\",\n", - " inputs={\"dataset\": vm_test_ds},\n", - " generate_description=False,\n", - ")" + "### Re-Drawing ROC Curve\n", + "\n", + "Here is another example of re-drawing a figure. This time we are re-drawing the ROC curve:" ] }, { @@ -354,51 +404,25 @@ "metadata": {}, "outputs": [], "source": [ - "def post_process_class_imbalance(result: TestResult):\n", - " result.passed = None\n", - " result.figures = []\n", - "\n", - " return result\n", + "def post_process_roc_curve(result: TestResult):\n", + " result.raw_data.fpr\n", "\n", "\n", - "result = run_test(\n", - " \"validmind.data_validation.ClassImbalance\",\n", - " inputs={\"dataset\": vm_test_ds},\n", - " generate_description=False,\n", - " post_process_fn=post_process_class_imbalance,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ "result = run_test(\n", " \"validmind.model_validation.sklearn.ROCCurve\",\n", " inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", " generate_description=False,\n", + " post_process_fn=post_process_roc_curve,\n", ")" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "def post_process_roc_curve(result: TestResult):\n", - " result.raw_data.fpr" + "### Custom Test Example\n", + "\n", + "While we envision that post-processing functions are most useful for modifying built-in (ValidMind Library) tests, there are cases where you may want to use them for your own custom tests. Let's see an example of a situation where this is the case:" ] }, { @@ -461,17 +485,6 @@ " post_process_fn=process_results,\n", ")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from validmind.tests import list_tests\n", - "\n", - "list_tests()" - ] } ], "metadata": { From 5c442e8f6b1c227becdd031ac5fc7b2c07368e8b Mon Sep 17 00:00:00 2001 From: John Walz Date: Thu, 12 Dec 2024 13:36:48 -0500 Subject: [PATCH 5/8] feat: remove raw data from faithfulness --- validmind/tests/model_validation/ragas/Faithfulness.py | 1 - 1 file changed, 1 deletion(-) diff --git a/validmind/tests/model_validation/ragas/Faithfulness.py b/validmind/tests/model_validation/ragas/Faithfulness.py index f6b8363cb..e5331f559 100644 --- a/validmind/tests/model_validation/ragas/Faithfulness.py +++ b/validmind/tests/model_validation/ragas/Faithfulness.py @@ -123,7 +123,6 @@ def Faithfulness( fig_box = px.box(x=result_df[score_column].to_list()) return ( - RawData(scores=result_df), { # "Scores (will not be uploaded to ValidMind Platform)": result_df[ # ["retrieved_contexts", "response", "faithfulness"] From 017315828c58ed12b057b0c6d797ca47b53163f0 Mon Sep 17 00:00:00 2001 From: John Walz Date: Thu, 12 Dec 2024 14:07:07 -0500 Subject: [PATCH 6/8] fix: fixing workflow for patching main with prod --- .github/workflows/prod_patches_to_main.yaml | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/.github/workflows/prod_patches_to_main.yaml b/.github/workflows/prod_patches_to_main.yaml index 29fc41ac5..7379b7d2e 100644 --- a/.github/workflows/prod_patches_to_main.yaml +++ b/.github/workflows/prod_patches_to_main.yaml @@ -8,9 +8,9 @@ on: workflow_dispatch: inputs: custom_branch_name: - description: "Custom Branch Name (optional)" + description: 'Custom Branch Name (optional)' required: false - default: "" + default: '' jobs: release: @@ -19,16 +19,15 @@ jobs: - name: Checkout Prod Branch uses: actions/checkout@v3 with: - ref: "prod" + ref: 'prod' - name: Install poetry run: pipx install poetry - - name: Set up Python 3.8 + - name: Set up Python 3.11 uses: actions/setup-python@v3 with: - python-version: "3.8" - cache: "poetry" + python-version: '3.11' - name: Get Application Version id: get_version @@ -47,7 +46,7 @@ jobs: - name: Checkout Main Branch uses: actions/checkout@v3 with: - ref: "main" + ref: 'main' fetch-depth: 0 - name: Setup Git Config From 2a193df13577a0844ba291ef4e5cbc9d935694e0 Mon Sep 17 00:00:00 2001 From: John Walz Date: Thu, 12 Dec 2024 15:42:36 -0500 Subject: [PATCH 7/8] feat: adding more examples to the post processing notebooks and helper functions for adding and removing tables and figures --- .../post_processing_functions.ipynb | 316 ++++++++++++------ validmind/utils.py | 11 + validmind/vm_models/figure.py | 12 + validmind/vm_models/result/result.py | 92 ++++- 4 files changed, 329 insertions(+), 102 deletions(-) diff --git a/notebooks/code_sharing/post_processing_functions.ipynb b/notebooks/code_sharing/post_processing_functions.ipynb index c33f65060..3e8724eaf 100644 --- a/notebooks/code_sharing/post_processing_functions.ipynb +++ b/notebooks/code_sharing/post_processing_functions.ipynb @@ -33,21 +33,26 @@ "\n", "### Key Concepts\n", "\n", - "**TestResult object**: The main object that post-processing functions work with, containing:\n", - "- tables: List of ResultTable objects\n", - "- figures: List of Figure objects \n", - "- passed: Boolean indicating test pass/fail status\n", - "- raw_data: Additional data from test execution\n", - "\n", - "**ResultTable**: Object representing tabular data with:\n", - "- title: Table title\n", - "- data: Pandas DataFrame or list of dictionaries\n", - "\n", - "**Figure**: Object representing plots/visualizations with:\n", - "- figure: matplotlib or plotly figure object\n", - "- key: Unique identifier\n", - "- ref_id: Reference ID linking to test\n", - "\n", + "**`validmind.vm_models.result.TestResult`**: Whenever a test is run with the `run_test` function in ValidMind, the items returned/produced by the test are bundled into a single `TestResult` object. There are several attributes on this object that are useful to know about:\n", + "- `tables`: List of `validmind.vm_models.result.ResultTable` objects (see below)\n", + "- `figures`: List of `validmind.vm_models.figure.Figure` objects (see below)\n", + "- `passed`: Optional boolean indicating test pass/fail status. `None` indicates that the test is not a pass/fail test (previously known as a threshold test).\n", + "- `raw_data`: Optional `validmind.vm_models.result.RawData` object containing additional data from test execution. Some ValidMind tests will produce this raw data to be used in post-processing functions. This data is not displayed in the test result or sent to the ValidMind platform (currently). To view the available raw data, you can run `result.raw_data.inspect()` which will return a dictionary where the keys are the raw data attributes available and the values are string representations of the data.\n", + "\n", + "**`validmind.vm_models.result.ResultTable`**: ValidMind object representing tables displayed in the test result and sent to the ValidMind platform:\n", + "- `title`: Optional table title\n", + "- `data`: Pandas dataframe\n", + "\n", + "**`validmind.vm_models.figure.Figure`**: ValidMind object representing plots/visualizations displayed in the test result and sent to the ValidMind platform:\n", + "- `figure`: matplotlib or plotly figure object\n", + "- `key`: Unique identifier\n", + "- `ref_id`: Reference ID linking to test" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "## Setup and Prerequisites\n", "\n", "First, we'll set up our environment and load sample data using the customer churn dataset:" @@ -63,8 +68,6 @@ "import validmind as vm\n", "from validmind.datasets.classification import customer_churn\n", "\n", - "vm.init()\n", - "\n", "raw_df = customer_churn.load_data()\n", "\n", "train_df, validation_df, test_df = customer_churn.preprocess(raw_df)\n", @@ -90,21 +93,27 @@ " input_id=\"raw_dataset\",\n", " target_column=customer_churn.target_column,\n", " class_labels=customer_churn.class_labels,\n", + " __log=False,\n", ")\n", "\n", "vm_train_ds = vm.init_dataset(\n", " dataset=train_df,\n", " input_id=\"train_dataset\",\n", " target_column=customer_churn.target_column,\n", + " __log=False,\n", ")\n", "\n", "vm_test_ds = vm.init_dataset(\n", - " dataset=test_df, input_id=\"test_dataset\", target_column=customer_churn.target_column\n", + " dataset=test_df,\n", + " input_id=\"test_dataset\",\n", + " target_column=customer_churn.target_column,\n", + " __log=False,\n", ")\n", "\n", "vm_model = vm.init_model(\n", " model,\n", " input_id=\"model\",\n", + " __log=False,\n", ")\n", "\n", "vm_train_ds.assign_predictions(\n", @@ -151,7 +160,14 @@ "source": [ "### Simple Tabular Updates\n", "\n", - "The simplest form of post-processing is modifying existing table data. Here we demonstrate updating class labels in a classification performance table:" + "The simplest form of post-processing is modifying existing table data. Here we demonstrate updating class labels in a classification performance table.\n", + "\n", + "Some key concepts to keep in mind:\n", + "- Tables produced by a test are accessible via the `result.tables` attribute\n", + " - The `result.tables` attribute is a list of `ResultTable` objects which are simple data structures that contain a `data` attribute and an optional `title` attribute\n", + " - The `data` attribute is guaranteed to be a `pd.DataFrame` whether the test code itself returns a `pd.DataFrame` or a list of dictionaries\n", + " - The `title` attribute is optional and can be set by tests that return a dictionary where the keys are the table titles and the values are the table data (e.g. `{\"Classifier Performance\": performance_df, \"Class Legend\": [{\"Class Value\": \"0\", \"Class Label\": \"No Churn\"}, {\"Class Value\": \"1\", \"Class Label\": \"Churn\"}]}}`)\n", + "- Post-processing functions can directly modify any of the tables in the `result.tables` list and return the modified `TestResult` object... This can be useful for renaming columns, adding/removing rows, etc." ] }, { @@ -187,7 +203,7 @@ "source": [ "### Adding Tables\n", "\n", - "Sometimes you may want to add supplementary tables to provide additional context or information. This example shows how to add a legend table mapping class values to labels:" + "This example shows how to add a legend table mapping class values to labels using the `TestResult.add_table()` method:" ] }, { @@ -196,18 +212,14 @@ "metadata": {}, "outputs": [], "source": [ - "from validmind.vm_models.result import ResultTable\n", - "\n", "def add_table(result: TestResult):\n", " # add legend table to show map of class value to class label\n", " result.add_table(\n", - " ResultTable(\n", - " title=\"Class Legend\",\n", - " data=[\n", - " {\"Class Value\": \"0\", \"Class Label\": \"No Churn\"},\n", - " {\"Class Value\": \"1\", \"Class Label\": \"Churn\"},\n", - " ],\n", - " )\n", + " title=\"Class Legend\",\n", + " table=[\n", + " {\"Class Value\": \"0\", \"Class Label\": \"No Churn\"},\n", + " {\"Class Value\": \"1\", \"Class Label\": \"Churn\"},\n", + " ],\n", " )\n", "\n", " return result\n", @@ -227,7 +239,7 @@ "source": [ "### Removing Tables \n", "\n", - "You can also remove tables that may not be relevant for your use case. Here we demonstrate removing a specific table from the results:" + "If there are tables in the test result that you don't want to display or log to the ValidMind platform, you can remove them using the `TestResult.remove_table()` method." ] }, { @@ -237,7 +249,7 @@ "outputs": [], "source": [ "def remove_table(result: TestResult):\n", - " result.tables.pop(1)\n", + " result.remove_table(1)\n", "\n", " return result\n", "\n", @@ -256,7 +268,7 @@ "source": [ "### Creating Figures from Tables\n", "\n", - "A powerful use of post-processing is creating visualizations from tabular data. This example shows creating a bar plot from an outliers table:" + "A powerful use of post-processing is creating visualizations from tabular data. This example shows creating a bar plot from an outliers table using the `TestResult.add_figure()` method. This method can take a `matplotlib`, `plotly`, raw PNG `bytes`, or a `validmind.vm_models.figure.Figure` object." ] }, { @@ -266,18 +278,11 @@ "outputs": [], "source": [ "from plotly_express import bar\n", - "from validmind.vm_models.figure import Figure\n", "\n", "\n", "def create_figure(result: TestResult):\n", - " fig = bar(result.tables[0].data, x=\"Variable\", y=\"Total Count of Outliers\")\n", - "\n", " result.add_figure(\n", - " Figure(\n", - " figure=fig,\n", - " key=\"outlier_count_by_variable\",\n", - " ref_id=result.ref_id,\n", - " )\n", + " bar(result.tables[0].data, x=\"Variable\", y=\"Total Count of Outliers\")\n", " )\n", "\n", " return result\n", @@ -297,7 +302,7 @@ "source": [ "### Creating Tables from Figures\n", "\n", - "The reverse operation - extracting tabular data from figures - is also possible. Here we demonstrate creating a table from figure data:" + "The reverse operation - extracting tabular data from figures - is also possible. However, its recommended instead to use the raw data produced by the test (assuming it is available) as the approach below requires deeper knowledge of the underlying figure (e.g. `matplotlib` or `plotly`) and may not be as robust/maintainable." ] }, { @@ -310,16 +315,12 @@ " for fig in result.figures:\n", " data = fig.figure.data[0]\n", "\n", - " table_data = [\n", - " {\"Percentile\": x, \"Outlier Count\": y}\n", - " for x, y in zip(data.x, data.y)\n", - " ]\n", - "\n", " result.add_table(\n", - " ResultTable(\n", - " title=fig.figure.layout.title.text,\n", - " data=table_data,\n", - " )\n", + " title=fig.figure.layout.title.text,\n", + " table=[\n", + " {\"Percentile\": x, \"Outlier Count\": y}\n", + " for x, y in zip(data.x, data.y)\n", + " ],\n", " )\n", "\n", " return result\n", @@ -329,7 +330,7 @@ " \"validmind.data_validation.IQROutliersBarPlot\",\n", " inputs={\"dataset\": vm_test_ds},\n", " generate_description=False,\n", - " # post_process_fn=create_table,\n", + " post_process_fn=create_table,\n", ")" ] }, @@ -339,7 +340,7 @@ "source": [ "### Re-Drawing Confusion Matrix\n", "\n", - "Sometimes you may want to completely replace the default visualizations. This example shows how to redraw a confusion matrix using matplotlib:" + "A less common example is re-drawing a figure. This example uses the table produced by the test to create a matplotlib confusion matrix figure and removes the existing plotly figure." ] }, { @@ -353,29 +354,22 @@ "\n", "def re_draw_class_imbalance(result: TestResult):\n", " data = result.tables[0].data\n", - " # Exited Percentage of Rows (%) Pass/Fail\n", - " # 0 0 80.25% Pass\n", - " # 1 1 19.75% Pass\n", "\n", - " result.figures = []\n", + " # remove the existing figure\n", + " result.remove_figure(0)\n", "\n", " # use matplotlib to plot the confusion matrix\n", " fig = plt.figure()\n", "\n", - " # show a bar plot of the class imbalance with matplotlib\n", " plt.bar(data[\"Exited\"], data[\"Percentage of Rows (%)\"])\n", " plt.xlabel(\"Exited\")\n", " plt.ylabel(\"Percentage of Rows (%)\")\n", " plt.title(\"Class Imbalance\")\n", "\n", - " result.add_figure(\n", - " Figure(\n", - " figure=fig,\n", - " key=\"confusion_matrix\",\n", - " ref_id=result.ref_id,\n", - " )\n", - " )\n", + " # add the figure to the result\n", + " result.add_figure(fig)\n", "\n", + " # close the figure to avoid showing it in the test result\n", " plt.close()\n", "\n", " return result\n", @@ -395,7 +389,51 @@ "source": [ "### Re-Drawing ROC Curve\n", "\n", - "Here is another example of re-drawing a figure. This time we are re-drawing the ROC curve:" + "This example shows re-drawing the ROC curve using the raw data produced by the test. This is the recommended approach to reproducing figures or tables from test results as it allows you to get intermediate and other raw data that was originally used by the test to produce the figures or tables we want to reproduce." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's run the test without post-processing to see the original result." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# run the test without post-processing\n", + "result = run_test(\n", + " \"validmind.model_validation.sklearn.ROCCurve\",\n", + " inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + " generate_description=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have a `TestResult` object, we can inspect the raw data to see what is available." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result.raw_data.inspect()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we know what is available in the raw data, we can build a post-processing function that uses this raw data to reproduce the ROC curve." ] }, { @@ -405,7 +443,28 @@ "outputs": [], "source": [ "def post_process_roc_curve(result: TestResult):\n", - " result.raw_data.fpr\n", + " fpr = result.raw_data.fpr\n", + " tpr = result.raw_data.tpr\n", + " auc = result.raw_data.auc\n", + "\n", + " # remove the existing figure\n", + " result.remove_figure(0)\n", + "\n", + " # use matplotlib to plot the ROC curve\n", + " fig = plt.figure()\n", + "\n", + " plt.plot(fpr, tpr, label=f\"ROC Curve (AUC = {auc:.2f})\")\n", + " plt.xlabel(\"False Positive Rate\")\n", + " plt.ylabel(\"True Positive Rate\")\n", + " plt.title(\"ROC Curve\")\n", + "\n", + " plt.legend()\n", + "\n", + " plt.close()\n", + "\n", + " result.add_figure(fig)\n", + "\n", + " return result\n", "\n", "\n", "result = run_test(\n", @@ -422,7 +481,7 @@ "source": [ "### Custom Test Example\n", "\n", - "While we envision that post-processing functions are most useful for modifying built-in (ValidMind Library) tests, there are cases where you may want to use them for your own custom tests. Let's see an example of a situation where this is the case:" + "While we envision that post-processing functions are most useful for modifying built-in (ValidMind Library) tests, there are also cases where you may want to use them for your own custom tests. Let's see an example of this." ] }, { @@ -431,58 +490,121 @@ "metadata": {}, "outputs": [], "source": [ - "import random\n", "import pandas as pd\n", "import numpy as np\n", - "from plotly_express import bar\n", - "from validmind.vm_models.figure import Figure\n", - "from validmind.vm_models.result import TestResult\n", - "import plotly.graph_objects as go\n", + "from validmind import test\n", + "from validmind.tests import run_test\n", "\n", "\n", - "@vm.test(\"my_custom_tests.Sensitivity\")\n", - "def sensitivity_test(strike=None):\n", - " \"\"\"This is sensitivity test\"\"\"\n", - " price = strike * random.random()\n", + "@test(\"custom.CorrelationBetweenVariables\")\n", + "def CorrelationBetweenVariables(var1: str, var2: str):\n", + " \"\"\"This fake test shows the relationship between two variables\"\"\"\n", + " data = pd.DataFrame(\n", + " {\n", + " \"Variable 1\": np.random.rand(20),\n", + " \"Variable 2\": np.random.rand(20),\n", + " }\n", + " )\n", "\n", - " return pd.DataFrame({\"Option price\": [price]})\n", + " return [{\"Correlation between var1 and var2\": data.corr().iloc[0, 1]}]\n", "\n", "\n", - "def process_results(result: TestResult):\n", + "variables = [\"Age\", \"Balance\", \"CreditScore\", \"EstimatedSalary\"]\n", + "\n", + "result = run_test(\n", + " \"custom.CorrelationBetweenVariables\",\n", + " param_grid={\n", + " \"var1\": variables,\n", + " \"var2\": variables,\n", + " }, # this will automatically generate all combinations of variables for var1 and var2\n", + " generate_description=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see, the test result now contains a table with the correlation between each pair of variables like this:\n", + "\n", + "| var1 | var2 | Correlation between var1 and var2 |\n", + "|------|------|-----------------------------------|\n", + "| Age | Age | 0.3001 |\n", + "| Age | Balance | -0.4185 |\n", + "| Age | CreditScore | 0.2952 |\n", + "| Age | EstimatedSalary | -0.2855 |\n", + "| Balance | Age | 0.0141 |\n", + "| Balance | Balance | -0.1513 |\n", + "| Balance | CreditScore | 0.2401 |\n", + "| Balance | EstimatedSalary | 0.1198 |\n", + "| CreditScore | Age | -0.2320 |\n", + "| CreditScore | Balance | 0.4125 |\n", + "| CreditScore | CreditScore | 0.1726 |\n", + "| CreditScore | EstimatedSalary | 0.3187 |\n", + "| EstimatedSalary | Age | -0.1774 |\n", + "| EstimatedSalary | Balance | -0.1202 |\n", + "| EstimatedSalary | CreditScore | 0.1488 |\n", + "| EstimatedSalary | EstimatedSalary | 0.0524 |\n", + "\n", + "Now let's say we don't really want to see the big table of correlations. Instead, we want to see a heatmap of the correlations. We can use a post-processing function to create a heatmap from the table and add it to the test result while removing the table." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.graph_objects as go\n", "\n", - " df = pd.DataFrame(result.tables[0].data)\n", "\n", - " fig = go.Figure()\n", + "def create_heatmap(result: TestResult):\n", + " # get the data from the existing table\n", + " data = result.tables[0].data\n", "\n", - " fig.add_trace(\n", - " go.Scatter(x=df[\"strike\"].values, y=df[\"Option price\"].values, mode=\"lines\")\n", + " # remove the existing table\n", + " result.remove_table(0)\n", + " \n", + " # Create a pivot table from the data to get it in matrix form\n", + " matrix = pd.pivot_table(\n", + " data,\n", + " values='Correlation between var1 and var2',\n", + " index='var1',\n", + " columns='var2'\n", " )\n", "\n", + " # remove the existing figure \n", + " result.remove_figure(0)\n", + "\n", + " # use plotly to create a heatmap\n", + " fig = go.Figure(data=go.Heatmap(\n", + " z=matrix.values,\n", + " x=matrix.columns,\n", + " y=matrix.index,\n", + " colorscale='RdBu',\n", + " zmid=0, # Center the color scale at 0\n", + " ))\n", + "\n", " fig.update_layout(\n", - " # title=params[\"title\"],\n", - " # xaxis_title=params[\"xlabel\"],\n", - " # yaxis_title=params[\"ylabel\"],\n", - " showlegend=True,\n", - " template=\"plotly_white\", # Adds a grid by default\n", + " title=\"Correlation Heatmap\",\n", + " xaxis_title=\"Variable\",\n", + " yaxis_title=\"Variable\",\n", " )\n", "\n", - " result.add_figure(\n", - " Figure(\n", - " figure=fig,\n", - " key=\"sensitivity_to_strike\",\n", - " ref_id=result.ref_id,\n", - " )\n", - " )\n", + " # add the figure to the result\n", + " result.add_figure(fig)\n", "\n", " return result\n", "\n", "\n", "result = run_test(\n", - " \"my_custom_tests.Sensitivity:ToStrike\",\n", + " \"custom.CorrelationBetweenVariables\",\n", " param_grid={\n", - " \"strike\": list(np.linspace(0, 100, 20)),\n", + " \"var1\": variables,\n", + " \"var2\": variables,\n", " },\n", - " post_process_fn=process_results,\n", + " generate_description=False,\n", + " post_process_fn=create_heatmap,\n", ")" ] } diff --git a/validmind/utils.py b/validmind/utils.py index a2de3584e..42b3ec75b 100644 --- a/validmind/utils.py +++ b/validmind/utils.py @@ -168,6 +168,17 @@ def iterencode(self, obj, _one_shot: bool = ...): return super().iterencode(obj, _one_shot) +class HumanReadableEncoder(NumpyEncoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # truncate ndarrays to 10 items + self.type_handlers[self.is_numpy_ndarray] = lambda obj: ( + obj.tolist()[:5] + ["..."] + obj.tolist()[-5:] + if len(obj) > 10 + else obj.tolist() + ) + + def get_full_typename(o: Any) -> Any: """We determine types based on type names so we don't have to import (and therefore depend on) PyTorch, TensorFlow, etc. diff --git a/validmind/vm_models/figure.py b/validmind/vm_models/figure.py index db25e5627..d843889b8 100644 --- a/validmind/vm_models/figure.py +++ b/validmind/vm_models/figure.py @@ -33,6 +33,18 @@ def is_png_image(figure) -> bool: return isinstance(figure, bytes) +def create_figure( + figure: Union[matplotlib.figure.Figure, go.Figure, go.FigureWidget, bytes], + key: str, + ref_id: str, +) -> "Figure": + """Create a VM Figure object from a raw figure object""" + if is_matplotlib_figure(figure) or is_plotly_figure(figure) or is_png_image(figure): + return Figure(key=key, figure=figure, ref_id=ref_id) + + raise ValueError(f"Unsupported figure type: {type(figure)}") + + @dataclass class Figure: """ diff --git a/validmind/vm_models/result/result.py b/validmind/vm_models/result/result.py index e3c40f8e9..8cea3641b 100644 --- a/validmind/vm_models/result/result.py +++ b/validmind/vm_models/result/result.py @@ -12,14 +12,22 @@ from typing import Any, Dict, List, Optional, Union from uuid import uuid4 +import matplotlib import pandas as pd +import plotly.graph_objs as go from ipywidgets import HTML, VBox from ... import api_client from ...ai.utils import DescriptionFuture from ...logging import get_logger -from ...utils import NumpyEncoder, display, run_async, test_id_to_name -from ..figure import Figure +from ...utils import ( + HumanReadableEncoder, + NumpyEncoder, + display, + run_async, + test_id_to_name, +) +from ..figure import Figure, create_figure from ..input import VMInput from .utils import ( AI_REVISION_NAME, @@ -53,6 +61,19 @@ def __init__(self, log: bool = False, **kwargs): def __repr__(self) -> str: return f"RawData({', '.join(self.__dict__.keys())})" + def inspect(self, show: bool = True): + """Inspect the raw data""" + raw_data = { + key: getattr(self, key) + for key in self.__dict__ + if not key.startswith("_") and key != "log" + } + + if not show: + return raw_data + + print(json.dumps(raw_data, indent=2, cls=HumanReadableEncoder)) + def serialize(self): return {key: getattr(self, key) for key in self.__dict__} @@ -64,7 +85,7 @@ class ResultTable: """ data: Union[List[Any], pd.DataFrame] - title: str + title: Optional[str] = None def __repr__(self) -> str: return f'ResultTable(title="{self.title}")' if self.title else "ResultTable" @@ -192,21 +213,82 @@ def _get_flat_inputs(self): return list(inputs.values()) - def add_table(self, table: ResultTable): + def add_table( + self, + table: Union[ResultTable, pd.DataFrame, List[Dict[str, Any]]], + title: Optional[str] = None, + ): + """Add a new table to the result + + Args: + table (Union[ResultTable, pd.DataFrame, List[Dict[str, Any]]]): The table to add + title (Optional[str]): The title of the table (can optionally be provided for + pd.DataFrame and List[Dict[str, Any]] tables) + """ if self.tables is None: self.tables = [] + if isinstance(table, (pd.DataFrame, list)): + table = ResultTable(data=table, title=title) + self.tables.append(table) - def add_figure(self, figure: Figure): + def remove_table(self, index: int): + """Remove a table from the result by index + + Args: + index (int): The index of the table to remove (default is 0) + """ + if self.tables is None: + return + + self.tables.pop(index) + + def add_figure( + self, + figure: Union[ + matplotlib.figure.Figure, + go.Figure, + go.FigureWidget, + bytes, + Figure, + ], + ): + """Add a new figure to the result + + Args: + figure (Union[matplotlib.figure.Figure, go.Figure, go.FigureWidget, + bytes, Figure]): The figure to add (can be either a VM Figure object, + a raw figure object from the supported libraries, or a png image as + raw bytes) + """ if self.figures is None: self.figures = [] + if not isinstance(figure, Figure): + random_tag = str(uuid4())[:4] + figure = create_figure( + figure=figure, + ref_id=self.ref_id, + key=f"{self.result_id}:{random_tag}", + ) + if figure.ref_id != self.ref_id: figure.ref_id = self.ref_id self.figures.append(figure) + def remove_figure(self, index: int = 0): + """Remove a figure from the result by index + + Args: + index (int): The index of the figure to remove (default is 0) + """ + if self.figures is None: + return + + self.figures.pop(index) + def to_widget(self): if isinstance(self.description, DescriptionFuture): self.description = self.description.get_description() From 714df4680accd9a95e2da297e426eb72079fe00f Mon Sep 17 00:00:00 2001 From: John Walz Date: Thu, 12 Dec 2024 15:44:41 -0500 Subject: [PATCH 8/8] 2.6.11 --- pyproject.toml | 2 +- validmind/__version__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8588d6395..baeee3750 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ description = "ValidMind Library" license = "Commercial License" name = "validmind" readme = "README.pypi.md" -version = "2.6.10" +version = "2.6.11" [tool.poetry.dependencies] aiohttp = {extras = ["speedups"], version = "*"} diff --git a/validmind/__version__.py b/validmind/__version__.py index fcfff714f..396adecc4 100644 --- a/validmind/__version__.py +++ b/validmind/__version__.py @@ -1 +1 @@ -__version__ = "2.6.10" +__version__ = "2.6.11"