diff --git a/.github/workflows/prod_patches_to_main.yaml b/.github/workflows/prod_patches_to_main.yaml index 29fc41ac5..7379b7d2e 100644 --- a/.github/workflows/prod_patches_to_main.yaml +++ b/.github/workflows/prod_patches_to_main.yaml @@ -8,9 +8,9 @@ on: workflow_dispatch: inputs: custom_branch_name: - description: "Custom Branch Name (optional)" + description: 'Custom Branch Name (optional)' required: false - default: "" + default: '' jobs: release: @@ -19,16 +19,15 @@ jobs: - name: Checkout Prod Branch uses: actions/checkout@v3 with: - ref: "prod" + ref: 'prod' - name: Install poetry run: pipx install poetry - - name: Set up Python 3.8 + - name: Set up Python 3.11 uses: actions/setup-python@v3 with: - python-version: "3.8" - cache: "poetry" + python-version: '3.11' - name: Get Application Version id: get_version @@ -47,7 +46,7 @@ jobs: - name: Checkout Main Branch uses: actions/checkout@v3 with: - ref: "main" + ref: 'main' fetch-depth: 0 - name: Setup Git Config diff --git a/notebooks/code_sharing/post_processing_functions.ipynb b/notebooks/code_sharing/post_processing_functions.ipynb new file mode 100644 index 000000000..3e8724eaf --- /dev/null +++ b/notebooks/code_sharing/post_processing_functions.ipynb @@ -0,0 +1,633 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Post-Processing Functions in ValidMind\n", + "\n", + "Welcome! This notebook demonstrates how to use post-processing functions with ValidMind tests to customize test outputs. You'll learn various ways to modify test results, including updating tables, adding and removing tables, creating figures from tables, and creating tables from figures.\n", + "\n", + "## Contents\n", + "- [About Post-Processing Functions](#about-post-processing-functions)\n", + "- [Key Concepts](#key-concepts)\n", + "- [Setup and Prerequisites](#setup-and-prerequisites)\n", + "- [Simple Tabular Updates](#simple-tabular-updates)\n", + "- [Adding Tables](#adding-tables)\n", + "- [Removing Tables](#removing-tables)\n", + "- [Creating Figures from Tables](#creating-figures-from-tables)\n", + "- [Creating Tables from Figures](#creating-tables-from-figures)\n", + "- [Re-Drawing the Class Imbalance Plot](#re-drawing-the-class-imbalance-plot)\n", + "- [Re-Drawing ROC Curve](#re-drawing-roc-curve)\n", + "- [Custom Test Example](#custom-test-example)\n", + "\n", + "## About Post-Processing Functions\n", + "\n", + "Post-processing functions allow you to customize the output of ValidMind tests before they are saved to the platform. These functions take a TestResult object as input and return a modified TestResult object.\n", + "\n", + "Common use cases include:\n", + "- Reformatting table data\n", + "- Adding or removing tables/figures\n", + "- Creating new visualizations from test data\n", + "- Customizing test pass/fail criteria\n", + "\n", + "### Key Concepts\n", + "\n", + "**`validmind.vm_models.result.TestResult`**: Whenever a test is run with the `run_test` function in ValidMind, the items returned/produced by the test are bundled into a single `TestResult` object. There are several attributes on this object that are useful to know about:\n", + "- `tables`: List of `validmind.vm_models.result.ResultTable` objects (see below)\n", + "- `figures`: List of `validmind.vm_models.figure.Figure` objects (see below)\n", + "- `passed`: Optional boolean indicating test pass/fail status. 
`None` indicates that the test is not a pass/fail test (previously known as a threshold test).\n", + "- `raw_data`: Optional `validmind.vm_models.result.RawData` object containing additional data from test execution. Some ValidMind tests produce this raw data for use in post-processing functions. This data is not displayed in the test result or sent to the ValidMind platform (currently). To see what raw data is available, call `result.raw_data.inspect()`, which prints the available raw data attributes along with truncated string representations of their values; pass `show=False` to get the underlying dictionary back instead.\n", + "\n", + "**`validmind.vm_models.result.ResultTable`**: ValidMind object representing tables displayed in the test result and sent to the ValidMind platform:\n", + "- `title`: Optional table title\n", + "- `data`: Pandas dataframe\n", + "\n", + "**`validmind.vm_models.figure.Figure`**: ValidMind object representing plots/visualizations displayed in the test result and sent to the ValidMind platform:\n", + "- `figure`: matplotlib or plotly figure object\n", + "- `key`: Unique identifier\n", + "- `ref_id`: Reference ID linking to test" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup and Prerequisites\n", + "\n", + "First, we'll set up our environment and load sample data using the customer churn dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import xgboost as xgb\n", + "import validmind as vm\n", + "from validmind.datasets.classification import customer_churn\n", + "\n", + "raw_df = customer_churn.load_data()\n", + "\n", + "train_df, validation_df, test_df = customer_churn.preprocess(raw_df)\n", + "\n", + "x_train = train_df.drop(customer_churn.target_column, axis=1)\n", + "y_train = train_df[customer_churn.target_column]\n", + "x_val = validation_df.drop(customer_churn.target_column, axis=1)\n", + "y_val = validation_df[customer_churn.target_column]\n", + "\n", + "model = xgb.XGBClassifier(early_stopping_rounds=10)\n", + "model.set_params(\n", + "    eval_metric=[\"error\", \"logloss\", \"auc\"],\n", + ")\n", + "model.fit(\n", + "    x_train,\n", + "    y_train,\n", + "    eval_set=[(x_val, y_val)],\n", + "    verbose=False,\n", + ")\n", + "\n", + "vm_raw_dataset = vm.init_dataset(\n", + "    dataset=raw_df,\n", + "    input_id=\"raw_dataset\",\n", + "    target_column=customer_churn.target_column,\n", + "    class_labels=customer_churn.class_labels,\n", + "    __log=False,\n", + ")\n", + "\n", + "vm_train_ds = vm.init_dataset(\n", + "    dataset=train_df,\n", + "    input_id=\"train_dataset\",\n", + "    target_column=customer_churn.target_column,\n", + "    __log=False,\n", + ")\n", + "\n", + "vm_test_ds = vm.init_dataset(\n", + "    dataset=test_df,\n", + "    input_id=\"test_dataset\",\n", + "    target_column=customer_churn.target_column,\n", + "    __log=False,\n", + ")\n", + "\n", + "vm_model = vm.init_model(\n", + "    model,\n", + "    input_id=\"model\",\n", + "    __log=False,\n", + ")\n", + "\n", + "vm_train_ds.assign_predictions(\n", + "    model=vm_model,\n", + ")\n", + "\n", + "vm_test_ds.assign_predictions(\n", + "    model=vm_model,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As a refresher, here is how we run a test normally, without any post-processing:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from validmind.tests import run_test\n", + "\n", + "result = run_test(\n", + "    \"validmind.model_validation.sklearn.ClassifierPerformance\",\n", + "    inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + "    generate_description=False,\n", + ")" + ] + },
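 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `result` returned above is a `TestResult` object, so we can poke at the attributes described in [Key Concepts](#key-concepts) before modifying anything. A minimal sketch; the exact contents will vary by test, data, and model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# inspect the attributes of the TestResult object\n", + "print(result.tables)  # list of ResultTable objects\n", + "print(result.figures)  # list of Figure objects (None/empty for table-only tests)\n", + "print(result.passed)  # None here, since this is not a pass/fail test\n", + "\n", + "# the underlying pandas DataFrame of the first table\n", + "result.tables[0].data" + ] + },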
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Post-processing functions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Simple Tabular Updates\n", + "\n", + "The simplest form of post-processing is modifying existing table data. Here we demonstrate updating class labels in a classification performance table.\n", + "\n", + "Some key concepts to keep in mind:\n", + "- Tables produced by a test are accessible via the `result.tables` attribute\n", + "  - The `result.tables` attribute is a list of `ResultTable` objects which are simple data structures that contain a `data` attribute and an optional `title` attribute\n", + "  - The `data` attribute is guaranteed to be a `pd.DataFrame` whether the test code itself returns a `pd.DataFrame` or a list of dictionaries\n", + "  - The `title` attribute is optional and can be set by tests that return a dictionary where the keys are the table titles and the values are the table data (e.g. `{\"Classifier Performance\": performance_df, \"Class Legend\": [{\"Class Value\": \"0\", \"Class Label\": \"No Churn\"}, {\"Class Value\": \"1\", \"Class Label\": \"Churn\"}]}`)\n", + "- Post-processing functions can directly modify any of the tables in the `result.tables` list and return the modified `TestResult` object. This can be useful for renaming columns, adding or removing rows, and similar tweaks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from validmind.vm_models.result import TestResult\n", + "\n", + "\n", + "def add_class_labels(result: TestResult):\n", + "    result.tables[0].data[\"Class\"] = (\n", + "        result.tables[0]\n", + "        .data[\"Class\"]\n", + "        .map(lambda x: \"Churn\" if x == \"1\" else \"No Churn\" if x == \"0\" else x)\n", + "    )\n", + "\n", + "    return result\n", + "\n", + "\n", + "result = run_test(\n", + "    \"validmind.model_validation.sklearn.ClassifierPerformance\",\n", + "    inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + "    generate_description=False,\n", + "    post_process_fn=add_class_labels,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Adding Tables\n", + "\n", + "This example shows how to add a legend table mapping class values to labels using the `TestResult.add_table()` method:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def add_table(result: TestResult):\n", + "    # add legend table to show map of class value to class label\n", + "    result.add_table(\n", + "        title=\"Class Legend\",\n", + "        table=[\n", + "            {\"Class Value\": \"0\", \"Class Label\": \"No Churn\"},\n", + "            {\"Class Value\": \"1\", \"Class Label\": \"Churn\"},\n", + "        ],\n", + "    )\n", + "\n", + "    return result\n", + "\n", + "\n", + "result = run_test(\n", + "    \"validmind.model_validation.sklearn.ClassifierPerformance\",\n", + "    inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + "    generate_description=False,\n", + "    post_process_fn=add_table,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Removing Tables\n", + "\n", + "If there are tables in the test result that you don't want to display or log to the ValidMind platform, you can remove them using the `TestResult.remove_table()` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def remove_table(result: TestResult):\n", + "    result.remove_table(1)\n", + "\n", + "    return result\n", + "\n", + "\n", + "result = run_test(\n", + "    \"validmind.model_validation.sklearn.ClassifierPerformance\",\n", + "    inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + "    generate_description=False,\n", + "    post_process_fn=remove_table,\n", + ")" + ] + },
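 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`remove_table()` works by index. If you would rather remove a table by its title, a small helper can look the index up first. A minimal sketch (the title passed below is hypothetical; if no table matches, nothing is removed):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def remove_table_by_title(result: TestResult, title: str):\n", + "    # find the first table whose title matches and remove it by index\n", + "    for i, table in enumerate(result.tables or []):\n", + "        if table.title == title:\n", + "            result.remove_table(i)\n", + "            break\n", + "\n", + "    return result\n", + "\n", + "\n", + "result = run_test(\n", + "    \"validmind.model_validation.sklearn.ClassifierPerformance\",\n", + "    inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + "    generate_description=False,\n", + "    # hypothetical title: replace with one of the titles shown in your result\n", + "    post_process_fn=lambda r: remove_table_by_title(r, \"Class Legend\"),\n", + ")" + ] + },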
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating Figures from Tables\n", + "\n", + "A powerful use of post-processing is creating visualizations from tabular data. This example shows creating a bar plot from an outliers table using the `TestResult.add_figure()` method. This method can take a `matplotlib`, `plotly`, raw PNG `bytes`, or a `validmind.vm_models.figure.Figure` object." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from plotly_express import bar\n", + "\n", + "\n", + "def create_figure(result: TestResult):\n", + "    result.add_figure(\n", + "        bar(result.tables[0].data, x=\"Variable\", y=\"Total Count of Outliers\")\n", + "    )\n", + "\n", + "    return result\n", + "\n", + "\n", + "result = run_test(\n", + "    \"validmind.data_validation.IQROutliersTable\",\n", + "    inputs={\"dataset\": vm_test_ds},\n", + "    generate_description=False,\n", + "    post_process_fn=create_figure,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Creating Tables from Figures\n", + "\n", + "The reverse operation, extracting tabular data from figures, is also possible. However, it's recommended to use the raw data produced by the test instead (assuming it is available), as the approach below requires deeper knowledge of the underlying figure library (e.g. `matplotlib` or `plotly`) and may not be as robust or maintainable." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_table(result: TestResult):\n", + "    for fig in result.figures:\n", + "        data = fig.figure.data[0]\n", + "\n", + "        result.add_table(\n", + "            title=fig.figure.layout.title.text,\n", + "            table=[\n", + "                {\"Percentile\": x, \"Outlier Count\": y}\n", + "                for x, y in zip(data.x, data.y)\n", + "            ],\n", + "        )\n", + "\n", + "    return result\n", + "\n", + "\n", + "result = run_test(\n", + "    \"validmind.data_validation.IQROutliersBarPlot\",\n", + "    inputs={\"dataset\": vm_test_ds},\n", + "    generate_description=False,\n", + "    post_process_fn=create_table,\n", + ")" + ] + },
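 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Before reaching into figure internals like this, it's worth checking whether the test exposes raw data you could build the table from instead. A quick, safe check (which attributes are available depends on the test):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# prefer raw data over figure scraping when it is available\n", + "if result.raw_data is not None:\n", + "    result.raw_data.inspect()\n", + "else:\n", + "    print(\"No raw data available for this test\")" + ] + },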
+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Re-Drawing the Class Imbalance Plot\n", + "\n", + "A less common example is re-drawing one of a test's figures from scratch. This example uses the table produced by the `ClassImbalance` test to create a matplotlib bar chart of the class distribution and removes the existing plotly figure." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "\n", + "def re_draw_class_imbalance(result: TestResult):\n", + "    data = result.tables[0].data\n", + "\n", + "    # remove the existing figure\n", + "    result.remove_figure(0)\n", + "\n", + "    # use matplotlib to plot the class distribution as a bar chart\n", + "    fig = plt.figure()\n", + "\n", + "    plt.bar(data[\"Exited\"], data[\"Percentage of Rows (%)\"])\n", + "    plt.xlabel(\"Exited\")\n", + "    plt.ylabel(\"Percentage of Rows (%)\")\n", + "    plt.title(\"Class Imbalance\")\n", + "\n", + "    # add the figure to the result\n", + "    result.add_figure(fig)\n", + "\n", + "    # close the figure so matplotlib doesn't also display it in the notebook\n", + "    plt.close()\n", + "\n", + "    return result\n", + "\n", + "\n", + "result = run_test(\n", + "    \"validmind.data_validation.ClassImbalance\",\n", + "    inputs={\"dataset\": vm_test_ds},\n", + "    generate_description=False,\n", + "    post_process_fn=re_draw_class_imbalance,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Re-Drawing ROC Curve\n", + "\n", + "This example shows re-drawing the ROC curve using the raw data produced by the test. This is the recommended approach to reproducing figures or tables from test results, since the raw data exposes the intermediate values that the test originally used to produce them." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "First, let's run the test without post-processing to see the original result." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# run the test without post-processing\n", + "result = run_test(\n", + "    \"validmind.model_validation.sklearn.ROCCurve\",\n", + "    inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + "    generate_description=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we have a `TestResult` object, we can inspect the raw data to see what is available." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "result.raw_data.inspect()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now that we know what is available in the raw data, we can build a post-processing function that uses this raw data to reproduce the ROC curve."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def post_process_roc_curve(result: TestResult):\n", + " fpr = result.raw_data.fpr\n", + " tpr = result.raw_data.tpr\n", + " auc = result.raw_data.auc\n", + "\n", + " # remove the existing figure\n", + " result.remove_figure(0)\n", + "\n", + " # use matplotlib to plot the ROC curve\n", + " fig = plt.figure()\n", + "\n", + " plt.plot(fpr, tpr, label=f\"ROC Curve (AUC = {auc:.2f})\")\n", + " plt.xlabel(\"False Positive Rate\")\n", + " plt.ylabel(\"True Positive Rate\")\n", + " plt.title(\"ROC Curve\")\n", + "\n", + " plt.legend()\n", + "\n", + " plt.close()\n", + "\n", + " result.add_figure(fig)\n", + "\n", + " return result\n", + "\n", + "\n", + "result = run_test(\n", + " \"validmind.model_validation.sklearn.ROCCurve\",\n", + " inputs={\"dataset\": vm_test_ds, \"model\": vm_model},\n", + " generate_description=False,\n", + " post_process_fn=post_process_roc_curve,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Custom Test Example\n", + "\n", + "While we envision that post-processing functions are most useful for modifying built-in (ValidMind Library) tests, there are also cases where you may want to use them for your own custom tests. Let's see an example of this." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from validmind import test\n", + "from validmind.tests import run_test\n", + "\n", + "\n", + "@test(\"custom.CorrelationBetweenVariables\")\n", + "def CorrelationBetweenVariables(var1: str, var2: str):\n", + " \"\"\"This fake test shows the relationship between two variables\"\"\"\n", + " data = pd.DataFrame(\n", + " {\n", + " \"Variable 1\": np.random.rand(20),\n", + " \"Variable 2\": np.random.rand(20),\n", + " }\n", + " )\n", + "\n", + " return [{\"Correlation between var1 and var2\": data.corr().iloc[0, 1]}]\n", + "\n", + "\n", + "variables = [\"Age\", \"Balance\", \"CreditScore\", \"EstimatedSalary\"]\n", + "\n", + "result = run_test(\n", + " \"custom.CorrelationBetweenVariables\",\n", + " param_grid={\n", + " \"var1\": variables,\n", + " \"var2\": variables,\n", + " }, # this will automatically generate all combinations of variables for var1 and var2\n", + " generate_description=False,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As you can see, the test result now contains a table with the correlation between each pair of variables like this:\n", + "\n", + "| var1 | var2 | Correlation between var1 and var2 |\n", + "|------|------|-----------------------------------|\n", + "| Age | Age | 0.3001 |\n", + "| Age | Balance | -0.4185 |\n", + "| Age | CreditScore | 0.2952 |\n", + "| Age | EstimatedSalary | -0.2855 |\n", + "| Balance | Age | 0.0141 |\n", + "| Balance | Balance | -0.1513 |\n", + "| Balance | CreditScore | 0.2401 |\n", + "| Balance | EstimatedSalary | 0.1198 |\n", + "| CreditScore | Age | -0.2320 |\n", + "| CreditScore | Balance | 0.4125 |\n", + "| CreditScore | CreditScore | 0.1726 |\n", + "| CreditScore | EstimatedSalary | 0.3187 |\n", + "| EstimatedSalary | Age | -0.1774 |\n", + "| EstimatedSalary | Balance | -0.1202 |\n", + "| EstimatedSalary | CreditScore | 0.1488 |\n", + "| EstimatedSalary | EstimatedSalary | 0.0524 |\n", + "\n", + "Now let's say we don't really want to see the big table of correlations. 
Instead, we want to see a heatmap of the correlations. We can use a post-processing function to create a heatmap from the table and add it to the test result while removing the table." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import plotly.graph_objects as go\n", + "\n", + "\n", + "def create_heatmap(result: TestResult):\n", + "    # get the data from the existing table\n", + "    data = result.tables[0].data\n", + "\n", + "    # remove the existing table\n", + "    result.remove_table(0)\n", + "\n", + "    # create a pivot table from the data to get it in matrix form\n", + "    matrix = pd.pivot_table(\n", + "        data,\n", + "        values=\"Correlation between var1 and var2\",\n", + "        index=\"var1\",\n", + "        columns=\"var2\",\n", + "    )\n", + "\n", + "    # use plotly to create a heatmap\n", + "    fig = go.Figure(data=go.Heatmap(\n", + "        z=matrix.values,\n", + "        x=matrix.columns,\n", + "        y=matrix.index,\n", + "        colorscale=\"RdBu\",\n", + "        zmid=0,  # center the color scale at 0\n", + "    ))\n", + "\n", + "    fig.update_layout(\n", + "        title=\"Correlation Heatmap\",\n", + "        xaxis_title=\"Variable\",\n", + "        yaxis_title=\"Variable\",\n", + "    )\n", + "\n", + "    # add the figure to the result\n", + "    result.add_figure(fig)\n", + "\n", + "    return result\n", + "\n", + "\n", + "result = run_test(\n", + "    \"custom.CorrelationBetweenVariables\",\n", + "    param_grid={\n", + "        \"var1\": variables,\n", + "        \"var2\": variables,\n", + "    },\n", + "    generate_description=False,\n", + "    post_process_fn=create_heatmap,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "validmind-BbKYUwN1-py3.11", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/pyproject.toml b/pyproject.toml index 8588d6395..baeee3750 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,7 +10,7 @@ description = "ValidMind Library" license = "Commercial License" name = "validmind" readme = "README.pypi.md" -version = "2.6.10" +version = "2.6.11" [tool.poetry.dependencies] aiohttp = {extras = ["speedups"], version = "*"} diff --git a/validmind/__init__.py b/validmind/__init__.py index 0763c2b5f..36cd9250e 100644 --- a/validmind/__init__.py +++ b/validmind/__init__.py @@ -50,6 +50,7 @@ run_test_suite, ) from .tests.decorator import tags, tasks, test +from .vm_models.result import RawData __all__ = [ # noqa "__version__", @@ -62,6 +63,7 @@ "init_model", "init_r_model", "preview_template", + "RawData", "reload", "run_documentation_tests", "run_test_suite", diff --git a/validmind/__version__.py b/validmind/__version__.py index fcfff714f..396adecc4 100644 --- a/validmind/__version__.py +++ b/validmind/__version__.py @@ -1 +1 @@ -__version__ = "2.6.10" +__version__ = "2.6.11" diff --git a/validmind/tests/model_validation/sklearn/ROCCurve.py b/validmind/tests/model_validation/sklearn/ROCCurve.py index 9dfacd61c..7113d0bc1 100644 --- a/validmind/tests/model_validation/sklearn/ROCCurve.py +++ b/validmind/tests/model_validation/sklearn/ROCCurve.py @@ -6,7 +6,7 @@ import plotly.graph_objects as go from sklearn.metrics import roc_auc_score, roc_curve -from validmind import tags, tasks +from validmind import 
RawData, tags, tasks from validmind.errors import SkipTestError from validmind.vm_models import VMDataset, VMModel @@ -77,28 +77,31 @@ def ROCCurve(model: VMModel, dataset: VMDataset): fpr, tpr, _ = roc_curve(y_true, y_prob, drop_intermediate=False) auc = roc_auc_score(y_true, y_prob) - return go.Figure( - data=[ - go.Scatter( - x=fpr, - y=tpr, - mode="lines", - name=f"ROC curve (AUC = {auc:.2f})", - line=dict(color="#DE257E"), + return ( + RawData(fpr=fpr, tpr=tpr, auc=auc), + go.Figure( + data=[ + go.Scatter( + x=fpr, + y=tpr, + mode="lines", + name=f"ROC curve (AUC = {auc:.2f})", + line=dict(color="#DE257E"), + ), + go.Scatter( + x=[0, 1], + y=[0, 1], + mode="lines", + name="Random (AUC = 0.5)", + line=dict(color="grey", dash="dash"), + ), + ], + layout=go.Layout( + title=f"ROC Curve for {model.input_id} on {dataset.input_id}", + xaxis=dict(title="False Positive Rate"), + yaxis=dict(title="True Positive Rate"), + width=700, + height=500, ), - go.Scatter( - x=[0, 1], - y=[0, 1], - mode="lines", - name="Random (AUC = 0.5)", - line=dict(color="grey", dash="dash"), - ), - ], - layout=go.Layout( - title=f"ROC Curve for {model.input_id} on {dataset.input_id}", - xaxis=dict(title="False Positive Rate"), - yaxis=dict(title="True Positive Rate"), - width=700, - height=500, ), ) diff --git a/validmind/tests/output.py b/validmind/tests/output.py index 762be7bc1..d5afc3f3c 100644 --- a/validmind/tests/output.py +++ b/validmind/tests/output.py @@ -15,7 +15,7 @@ is_plotly_figure, is_png_image, ) -from validmind.vm_models.result import ResultTable, TestResult +from validmind.vm_models.result import RawData, ResultTable, TestResult class OutputHandler(ABC): @@ -103,6 +103,14 @@ def process( result.add_table(ResultTable(data=table_data, title=table_name or None)) +class RawDataOutputHandler(OutputHandler): + def can_handle(self, item: Any) -> bool: + return isinstance(item, RawData) + + def process(self, item: Any, result: TestResult) -> None: + result.raw_data = item + + def process_output(item: Any, result: TestResult) -> None: """Process a single test output item and update the TestResult.""" handlers = [ @@ -110,6 +118,7 @@ def process_output(item: Any, result: TestResult) -> None: MetricOutputHandler(), FigureOutputHandler(), TableOutputHandler(), + RawDataOutputHandler(), ] for handler in handlers: diff --git a/validmind/tests/run.py b/validmind/tests/run.py index 9c806f306..c3b28f050 100644 --- a/validmind/tests/run.py +++ b/validmind/tests/run.py @@ -7,7 +7,7 @@ import time from datetime import datetime from inspect import getdoc -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from uuid import uuid4 from validmind import __version__ @@ -283,6 +283,7 @@ def run_test( show: bool = True, generate_description: bool = True, title: Optional[str] = None, + post_process_fn: Union[Callable[[TestResult], TestResult], None] = None, **kwargs, ) -> TestResult: """Run a ValidMind or custom test @@ -306,6 +307,7 @@ def run_test( show (bool, optional): Whether to display results. Defaults to True. generate_description (bool, optional): Whether to generate a description. Defaults to True. 
title (str, optional): Custom title for the test result + post_process_fn (Callable[[TestResult], TestResult], optional): Function that takes the TestResult, modifies it as needed, and returns it Returns: TestResult: A TestResult object containing the test results @@ -394,6 +396,9 @@ def run_test( end_time = time.perf_counter() result.metadata = _get_run_metadata(duration_seconds=end_time - start_time) + if post_process_fn: + result = post_process_fn(result) + if show: result.show() diff --git a/validmind/utils.py b/validmind/utils.py index a2de3584e..42b3ec75b 100644 --- a/validmind/utils.py +++ b/validmind/utils.py @@ -168,6 +168,17 @@ def iterencode(self, obj, _one_shot: bool = ...): return super().iterencode(obj, _one_shot) +class HumanReadableEncoder(NumpyEncoder): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # truncate ndarrays longer than 10 items to their first and last 5 + self.type_handlers[self.is_numpy_ndarray] = lambda obj: ( + obj.tolist()[:5] + ["..."] + obj.tolist()[-5:] + if len(obj) > 10 + else obj.tolist() + ) + + def get_full_typename(o: Any) -> Any: """We determine types based on type names so we don't have to import (and therefore depend on) PyTorch, TensorFlow, etc. diff --git a/validmind/vm_models/figure.py b/validmind/vm_models/figure.py index c8eeca5bf..d843889b8 100644 --- a/validmind/vm_models/figure.py +++ b/validmind/vm_models/figure.py @@ -33,6 +33,18 @@ def is_png_image(figure) -> bool: return isinstance(figure, bytes) +def create_figure( + figure: Union[matplotlib.figure.Figure, go.Figure, go.FigureWidget, bytes], + key: str, + ref_id: str, +) -> "Figure": + """Create a VM Figure object from a raw figure object""" + if is_matplotlib_figure(figure) or is_plotly_figure(figure) or is_png_image(figure): + return Figure(key=key, figure=figure, ref_id=ref_id) + + raise ValueError(f"Unsupported figure type: {type(figure)}") + + @dataclass class Figure: """ @@ -55,6 +67,9 @@ def __post_init__(self): ): self.figure = go.FigureWidget(self.figure) + def __repr__(self): + return f"Figure(key={self.key}, ref_id={self.ref_id})" + def to_widget(self): """ Returns the ipywidget compatible representation of the figure. Ideally diff --git a/validmind/vm_models/result/__init__.py b/validmind/vm_models/result/__init__.py index 721e2f5cd..aca6c17e6 100644 --- a/validmind/vm_models/result/__init__.py +++ b/validmind/vm_models/result/__init__.py @@ -2,6 +2,6 @@ # See the LICENSE file in the root of this repository for details. # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial -from .result import ErrorResult, Result, ResultTable, TestResult +from .result import ErrorResult, RawData, Result, ResultTable, TestResult -__all__ = ["ErrorResult", "Result", "ResultTable", "TestResult"] +__all__ = ["ErrorResult", "RawData", "Result", "ResultTable", "TestResult"] diff --git a/validmind/vm_models/result/result.py b/validmind/vm_models/result/result.py index 08c26a7b8..8cea3641b 100644 --- a/validmind/vm_models/result/result.py +++ b/validmind/vm_models/result/result.py @@ -12,14 +12,22 @@ from typing import Any, Dict, List, Optional, Union from uuid import uuid4 +import matplotlib import pandas as pd +import plotly.graph_objs as go from ipywidgets import HTML, VBox from ... 
import api_client from ...ai.utils import DescriptionFuture from ...logging import get_logger -from ...utils import NumpyEncoder, display, run_async, test_id_to_name -from ..figure import Figure +from ...utils import ( + HumanReadableEncoder, + NumpyEncoder, + display, + run_async, + test_id_to_name, +) +from ..figure import Figure, create_figure from ..input import VMInput from .utils import ( AI_REVISION_NAME, @@ -34,6 +42,42 @@ logger = get_logger(__name__) +class RawData: + """Holds raw data for a test result""" + + def __init__(self, log: bool = False, **kwargs): + """Create a new RawData object + + Args: + log (bool): If True, log the raw data to ValidMind + **kwargs: Keyword arguments to set as attributes e.g. + `RawData(log=True, dataset_duplicates=df_duplicates)` + """ + self.log = log + + for key, value in kwargs.items(): + setattr(self, key, value) + + def __repr__(self) -> str: + return f"RawData({', '.join(self.__dict__.keys())})" + + def inspect(self, show: bool = True): + """Inspect the raw data + + Args: + show (bool): If True (default), print a truncated preview of the + raw data; if False, return the raw data as a dictionary + """ + raw_data = { + key: getattr(self, key) + for key in self.__dict__ + if not key.startswith("_") and key != "log" + } + + if not show: + return raw_data + + print(json.dumps(raw_data, indent=2, cls=HumanReadableEncoder)) + + def serialize(self): + return {key: getattr(self, key) for key in self.__dict__} + + @dataclass class ResultTable: """ @@ -41,7 +85,7 @@ class ResultTable: """ data: Union[List[Any], pd.DataFrame] - title: str + title: Optional[str] = None def __repr__(self) -> str: return f'ResultTable(title="{self.title}")' if self.title else "ResultTable" @@ -118,12 +162,12 @@ class TestResult(Result): description: Optional[Union[str, DescriptionFuture]] = None metric: Optional[Union[int, float]] = None tables: Optional[List[ResultTable]] = None + raw_data: Optional[RawData] = None figures: Optional[List[Figure]] = None passed: Optional[bool] = None params: Optional[Dict[str, Any]] = None inputs: Optional[Dict[str, Union[List[VMInput], VMInput]]] = None metadata: Optional[Dict[str, Any]] = None - title: Optional[str] = None _was_description_generated: bool = False _unsafe: bool = False @@ -144,6 +188,11 @@ def __repr__(self) -> str: "passed", ] if getattr(self, attr) is not None + and ( + len(getattr(self, attr)) > 0 + if isinstance(getattr(self, attr), list) + else True + ) ] return f'TestResult("{self.result_id}", {", ".join(attrs)})' @@ -164,21 +213,82 @@ def _get_flat_inputs(self): return list(inputs.values()) - def add_table(self, table: ResultTable): + def add_table( + self, + table: Union[ResultTable, pd.DataFrame, List[Dict[str, Any]]], + title: Optional[str] = None, + ): + """Add a new table to the result + + Args: + table (Union[ResultTable, pd.DataFrame, List[Dict[str, Any]]]): The table to add + title (Optional[str]): The title of the table (can optionally be provided for + pd.DataFrame and List[Dict[str, Any]] tables) + """ if self.tables is None: self.tables = [] + if isinstance(table, (pd.DataFrame, list)): + table = ResultTable(data=table, title=title) + self.tables.append(table) - def add_figure(self, figure: Figure): + def remove_table(self, index: int = 0): + """Remove a table from the result by index + + Args: + index (int): The index of the table to remove (default is 0) + """ + if self.tables is None: + return + + self.tables.pop(index) + + def add_figure( + self, + figure: Union[ + matplotlib.figure.Figure, + go.Figure, + go.FigureWidget, + bytes, + Figure, + ], + ): + """Add a new figure to the result + + Args: + figure (Union[matplotlib.figure.Figure, 
go.Figure, go.FigureWidget, + bytes, Figure]): The figure to add (can be either a VM Figure object, + a raw figure object from the supported libraries, or a png image as + raw bytes) + """ if self.figures is None: self.figures = [] + if not isinstance(figure, Figure): + random_tag = str(uuid4())[:4] + figure = create_figure( + figure=figure, + ref_id=self.ref_id, + key=f"{self.result_id}:{random_tag}", + ) + if figure.ref_id != self.ref_id: figure.ref_id = self.ref_id self.figures.append(figure) + def remove_figure(self, index: int = 0): + """Remove a figure from the result by index + + Args: + index (int): The index of the figure to remove (default is 0) + """ + if self.figures is None: + return + + self.figures.pop(index) + def to_widget(self): if isinstance(self.description, DescriptionFuture): self.description = self.description.get_description()