diff --git a/splink/internals/blocking_analysis.py b/splink/internals/blocking_analysis.py index 4511faf95a..565eefc063 100644 --- a/splink/internals/blocking_analysis.py +++ b/splink/internals/blocking_analysis.py @@ -30,7 +30,7 @@ from splink.internals.blocking_rule_creator_utils import to_blocking_rule_creator from splink.internals.charts import ( ChartReturnType, - cumulative_blocking_rule_comparisons_generated, + CumulativeBlockingRuleComparisonsGeneratedChart, ) from splink.internals.database_api import DatabaseAPISubClass from splink.internals.duckdb.duckdb_helpers import record_dicts_from_relation @@ -721,7 +721,9 @@ def cumulative_comparisons_to_be_scored_from_blocking_rules_chart( ) ) - return cumulative_blocking_rule_comparisons_generated(cumulative_comparison_records) + return CumulativeBlockingRuleComparisonsGeneratedChart( + cumulative_comparison_records + ).chart def n_largest_blocks( diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 9f1a4a31fa..31dcf3bd99 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -3,17 +3,37 @@ import json import math import os -from typing import TYPE_CHECKING, Any, Dict, Union +from abc import ABC, abstractmethod +from collections.abc import Iterable +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + Protocol, + Sequence, + TypeVar, + Union, +) from splink.internals.misc import read_resource -from splink.internals.waterfall_chart import records_to_waterfall_data if TYPE_CHECKING: from altair import SchemaBase + + from splink.internals.comparison_level import ComparisonLevelDetailedRecord + from splink.internals.em_training_session import ( + ModelParameterIterationDetailedRecord, + ) + from splink.internals.settings import ModelParameterDetailedRecord else: SchemaBase = None + + ComparisonLevelDetailedRecord = None + ModelParameterDetailedRecord = None + ModelParameterIterationDetailedRecord = None # type alias: -ChartReturnType = Union[Dict[Any, Any], SchemaBase] +ChartReturnType = Union[dict[Any, Any], SchemaBase] def load_chart_definition(filename): @@ -41,6 +61,18 @@ def altair_or_json( return chart_dict +class AsDictable(Protocol): + def as_dict(self) -> dict[str, Any]: ... + + +def list_items_as_dicts( + lst: Iterable[AsDictable | dict[str, Any]], +) -> list[dict[str, Any]]: + return list( + map(lambda item: item if isinstance(item, dict) else item.as_dict(), lst) + ) + + iframe_message = """ To view in Jupyter you can use the following command: @@ -69,9 +101,6 @@ def save_offline_chart( f"The path {filename} already exists. Please provide a different path." ) - if type(chart_dict).__name__ == "VegaliteNoValidate": - chart_dict = chart_dict.spec - template = read_resource("internals/files/templates/single_chart_template.html") fmt_dict = _load_external_libs() @@ -86,348 +115,503 @@ def save_offline_chart( print(iframe_message.format(filename=filename)) # noqa: T201 -def match_weights_chart(records, as_dict=False): - chart_path = "match_weights_interactive_history.json" - chart = load_chart_definition(chart_path) - - # Remove iteration history since this is a static chart - del chart["params"] - del chart["transform"] - - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records - - bayes_factors = [ - abs(l2bf) - for r in records - if (l2bf := r["log2_bayes_factor"]) is not None and not math.isinf(l2bf) - ] - max_value = math.ceil(max(bayes_factors)) - - chart["vconcat"][0]["encoding"]["x"]["scale"]["domain"] = [-max_value, max_value] - chart["vconcat"][1]["encoding"]["x"]["scale"]["domain"] = [-max_value, max_value] - - return altair_or_json(chart, as_dict=as_dict) - - -def comparison_match_weights_chart(records, as_dict=False): - chart_path = "match_weights_interactive_history.json" - chart = load_chart_definition(chart_path) - - # Remove iteration history since this is a static chart - del chart["vconcat"][0] - del chart["params"] - del chart["transform"] - - chart["title"]["text"] = "Comparison summary" - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records - return altair_or_json(chart, as_dict=as_dict) - - -def m_u_parameters_chart(records, as_dict=False): - chart_path = "m_u_parameters_interactive_history.json" - chart = load_chart_definition(chart_path) - - # Remove iteration history since this is a static chart - del chart["params"] - del chart["transform"] - - records = [ - r - for r in records - if r["comparison_vector_value"] != -1 - and r["comparison_name"] != "probability_two_random_records_match" - ] - chart["data"]["values"] = records - return altair_or_json(chart, as_dict=as_dict) - - -def probability_two_random_records_match_iteration_chart(records, as_dict=False): - chart_path = "probability_two_random_records_match_iteration.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - return altair_or_json(chart, as_dict=as_dict) - - -def match_weights_interactive_history_chart(records, as_dict=False, blocking_rule=None): - chart_path = "match_weights_interactive_history.json" - chart = load_chart_definition(chart_path) - - chart["title"]["subtitle"] = f"Training session blocked on {blocking_rule}" - - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records - - max_iteration = 0 - for r in records: - max_iteration = max(r["iteration"], max_iteration) +# TODO: we can have more detailed subclasses to hint the fields needed per chart +class ChartRecord(Protocol): ... + + +T = TypeVar("T", bound=ChartRecord) + + +class SplinkChart(ABC, Generic[T]): + def __init__(self, records: Sequence[T], as_dict: bool = False): + # TODO: as_dict only in methods rather than on object + self.raw_records = records + self.as_dict = as_dict + self.width: float | None = self.default_width + self.height: float | None = self.default_height + + @property + @abstractmethod + def chart_spec_file(self) -> str: + pass + + @property + def chart_data(self): + # TODO: cache + return self.alter_data(self.raw_records) + + @property + def default_width(self) -> float | None: + return None + + @property + def default_height(self) -> float | None: + return None + + @property + def chart_spec(self) -> dict[str, Any]: + chart_spec = load_chart_definition(self.chart_spec_file) + chart_spec = self.alter_spec_directly(chart_spec) + chart_spec = self.alter_spec_height_width(chart_spec) + chart_spec = self.alter_spec_from_data(chart_spec) + return chart_spec + + @property + def chart_dict(self) -> dict[str, Any]: + chart = self.chart_spec + chart["data"]["values"] = list_items_as_dicts(self.chart_data) + return chart + + @property + def chart(self) -> ChartReturnType: + # TODO: split into separate methods + # also save etc + return altair_or_json(self.chart_dict, as_dict=self.as_dict) + + @staticmethod + def alter_data(records): + return records + + @staticmethod + def alter_spec_directly(chart_spec): + return chart_spec + + def alter_spec_from_data(self, chart_spec): + return chart_spec + + def alter_spec_height_width(self, chart_spec): + if width := self.width: + chart_spec["width"] = width + if height := self.height: + chart_spec["height"] = height + return chart_spec + + def set_width_height( + self, *, width: float | None = None, height: float | None = None + ) -> None: + # TODO: do we need to be careful with charts that can't have these set + # such as threshold_selection_tool + self.width = width + self.height = height + # TODO: return self? + + +class MatchWeightsChart(SplinkChart[ComparisonLevelDetailedRecord]): + @property + def chart_spec_file(self) -> str: + return "match_weights_interactive_history.json" + + @staticmethod + def alter_data(records): + return [r for r in records if r.comparison_vector_value != -1] + + @staticmethod + def alter_spec_directly(chart_spec): + # Remove iteration history since this is a static chart + del chart_spec["params"] + del chart_spec["transform"] + return chart_spec + + def alter_spec_from_data(self, chart_spec): + bayes_factors = [ + abs(l2bf) + for r in self.chart_data + if (l2bf := r.log2_bayes_factor) is not None and not math.isinf(l2bf) + ] + max_value = math.ceil(max(bayes_factors)) + chart_spec["vconcat"][0]["encoding"]["x"]["scale"]["domain"] = [ + -max_value, + max_value, + ] + chart_spec["vconcat"][1]["encoding"]["x"]["scale"]["domain"] = [ + -max_value, + max_value, + ] + return chart_spec + + +class ComparisonMatchWeightsChart(SplinkChart[ComparisonLevelDetailedRecord]): + @property + def chart_spec_file(self) -> str: + return "match_weights_interactive_history.json" + + @staticmethod + def alter_data(records): + return [r for r in records if r.comparison_vector_value != -1] + + @staticmethod + def alter_spec_directly(chart_spec): + # Remove iteration history since this is a static chart + # TODO: some render issue if we remove empty top panel, so leave for now + # del chart["vconcat"][0] + del chart_spec["params"] + del chart_spec["transform"] + chart_spec["title"]["text"] = "Comparison summary" + return chart_spec + + +class MUParametersChart(SplinkChart[ModelParameterDetailedRecord]): + @property + def chart_spec_file(self) -> str: + return "m_u_parameters_interactive_history.json" + + @staticmethod + def alter_data(records): + return [ + r + for r in records + if r.comparison_vector_value != -1 + and r.comparison_name != "probability_two_random_records_match" + ] - chart["params"][0]["bind"]["max"] = max_iteration - chart["params"][0]["value"] = max_iteration - return altair_or_json(chart, as_dict=as_dict) + @staticmethod + def alter_spec_directly(chart_spec): + # Remove iteration history since this is a static chart + del chart_spec["params"] + del chart_spec["transform"] + return chart_spec -def m_u_parameters_interactive_history_chart(records, as_dict=False): - chart_path = "m_u_parameters_interactive_history.json" - chart = load_chart_definition(chart_path) - records = [ - r - for r in records - if r["comparison_vector_value"] != -1 - and r["comparison_name"] != "probability_two_random_records_match" - ] - chart["data"]["values"] = records - - max_iteration = 0 - for r in records: - max_iteration = max(r["iteration"], max_iteration) - - chart["params"][0]["bind"]["max"] = max_iteration - chart["params"][0]["value"] = max_iteration - return altair_or_json(chart, as_dict=as_dict) +class ProbabilityTwoRandomRecordsMatchIterationChart(SplinkChart[ChartRecord]): + @property + def chart_spec_file(self) -> str: + return "probability_two_random_records_match_iteration.json" -def waterfall_chart( - records, - settings_obj, - filter_nulls=True, - remove_sensitive_data=False, - as_dict=False, +class MatchWeightsInteractiveHistoryChart( + SplinkChart[ModelParameterIterationDetailedRecord] ): - data = records_to_waterfall_data(records, settings_obj, remove_sensitive_data) - chart_path = "match_weights_waterfall.json" - chart = load_chart_definition(chart_path) - chart["data"]["values"] = data - chart["params"][0]["bind"]["max"] = len(records) - 1 - if filter_nulls: - chart["transform"].insert(1, {"filter": "(datum.bayes_factor !== 1.0)"}) - - return altair_or_json(chart, as_dict=as_dict) - - -def roc_chart(records, width=400, height=400, as_dict=False): - chart_path = "roc.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - # If 'curve_label' not in records, remove colour coding - # This is for if you want to compare roc curves - r = records[0] - if "curve_label" not in r.keys(): - del chart["encoding"]["color"] - - chart["height"] = height - chart["width"] = width - - return altair_or_json(chart, as_dict=as_dict) - - -def precision_recall_chart(records, width=400, height=400, as_dict=False): - chart_path = "precision_recall.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - # If 'curve_label' not in records, remove colour coding - # This is for if you want to compare roc curves - r = records[0] - if "curve_label" not in r.keys(): - del chart["encoding"]["color"] - - chart["height"] = height - chart["width"] = width - - return altair_or_json(chart, as_dict=as_dict) - - -def accuracy_chart(records, width=400, height=400, as_dict=False, add_metrics=[]): - chart_path = "accuracy_chart.json" - chart = load_chart_definition(chart_path) - - # User-specified metrics to include - metrics = ["precision", "recall", *add_metrics] - chart["transform"][0]["fold"] = metrics - chart["transform"][1]["calculate"] = chart["transform"][1]["calculate"].replace( - "__metrics__", str(metrics) - ) - chart["layer"][0]["encoding"]["color"]["sort"] = metrics - chart["layer"][1]["layer"][1]["encoding"]["color"]["sort"] = metrics - - # Metric-label mapping - mapping = { - "precision": "Precision (PPV)", - "recall": "Recall (TPR)", - "specificity": "Specificity (TNR)", - "accuracy": "Accuracy", - "npv": "NPV", - "f1": "F1", - "f2": "F2", - "f0_5": "F0.5", - "p4": "P4", - "phi": "\u03c6 (MCC)", - } - chart["transform"][2]["calculate"] = chart["transform"][2]["calculate"].replace( - "__mapping__", str(mapping) - ) - chart["layer"][0]["encoding"]["color"]["legend"]["labelExpr"] = chart["layer"][0][ - "encoding" - ]["color"]["legend"]["labelExpr"].replace("__mapping__", str(mapping)) - - chart["data"]["values"] = records - - chart["height"] = height - chart["width"] = width - - return altair_or_json(chart, as_dict=as_dict) - - -def threshold_selection_tool(records, as_dict=False, add_metrics=[]): - chart_path = "threshold_selection_tool.json" - chart = load_chart_definition(chart_path) - - # Remove extremes with low precision and recall - records = [d for d in records if d["precision"] > 0.5 and d["recall"] > 0.5] - - # User-specified metrics to include - metrics = ["precision", "recall", *add_metrics] - - chart["hconcat"][1]["transform"][0]["fold"] = metrics - chart["hconcat"][1]["transform"][1]["calculate"] = chart["hconcat"][1]["transform"][ - 1 - ]["calculate"].replace("__metrics__", str(metrics)) - chart["hconcat"][1]["layer"][0]["encoding"]["color"]["sort"] = metrics - chart["hconcat"][1]["layer"][1]["layer"][1]["encoding"]["color"]["sort"] = metrics - - # Metric-label mapping - mapping = { - "precision": "Precision (PPV)", - "recall": "Recall (TPR)", - "specificity": "Specificity (TNR)", - "accuracy": "Accuracy", - "npv": "NPV", - "f1": "F1", - "f2": "F2", - "f0_5": "F0.5", - "p4": "P4", - "phi": "\u03c6 (MCC)", - } - chart["hconcat"][1]["transform"][2]["calculate"] = chart["hconcat"][1]["transform"][ - 2 - ]["calculate"].replace("__mapping__", str(mapping)) - chart["hconcat"][1]["layer"][0]["encoding"]["color"]["legend"]["labelExpr"] = chart[ - "hconcat" - ][1]["layer"][0]["encoding"]["color"]["legend"]["labelExpr"].replace( - "__mapping__", str(mapping) - ) - - chart["data"]["values"] = records - - return altair_or_json(chart, as_dict=as_dict) - - -def match_weights_histogram(records, width=500, height=250, as_dict=False): - chart_path = "match_weight_histogram.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - chart["height"] = height - chart["width"] = width - - return altair_or_json(chart, as_dict=as_dict) - - -def parameter_estimate_comparisons(records, as_dict=False): - chart_path = "parameter_estimate_comparisons.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - return altair_or_json(chart, as_dict=as_dict) - - -def missingness_chart(records, as_dict=False): - chart_path = "missingness.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - record_count = records[0]["total_record_count"] - - for c in chart["layer"]: - c["title"] = f"Missingness per column out of {record_count:,.0f} records" - - return altair_or_json(chart, as_dict=as_dict) + def __init__( + self, + records: Sequence[ModelParameterIterationDetailedRecord], + blocking_rule_text: str, + as_dict: bool = False, + ): + super().__init__(records, as_dict=as_dict) + self.blocking_rule_text = blocking_rule_text + + @property + def chart_spec_file(self) -> str: + return "match_weights_interactive_history.json" + + @staticmethod + def alter_data(records): + return [r for r in records if r.comparison_vector_value != -1] + + def alter_spec_from_data(self, chart_spec): + records = self.chart_data + max_iteration = 0 + for r in records: + max_iteration = max(r.iteration, max_iteration) + + chart_spec["params"][0]["bind"]["max"] = max_iteration + chart_spec["params"][0]["value"] = max_iteration + chart_spec["title"]["subtitle"] = ( + f"Training session blocked on {self.blocking_rule_text}" + ) + return chart_spec -def unlinkables_chart( - records, - x_col="match_weight", - source_dataset=None, - as_dict=False, +class MUParametersInteractiveHistoryChart( + SplinkChart[ModelParameterIterationDetailedRecord] ): - if x_col not in ["match_weight", "match_probability"]: - raise ValueError( - f"{x_col} must be 'match_weight' (default) or 'match_probability'." - ) - - chart_path = "unlinkables_chart_def.json" - unlinkables_chart_def = load_chart_definition(chart_path) - unlinkables_chart_def["data"]["values"] = records + @property + def chart_spec_file(self) -> str: + return "m_u_parameters_interactive_history.json" + + @staticmethod + def alter_data(records): + return [ + r + for r in records + if r.comparison_vector_value != -1 + and r.comparison_name != "probability_two_random_records_match" + ] - if x_col == "match_probability": - unlinkables_chart_def["layer"][0]["encoding"]["x"]["field"] = ( - "match_probability" + def alter_spec_from_data(self, chart_spec): + records = self.chart_data + max_iteration = 0 + for r in records: + max_iteration = max(r.iteration, max_iteration) + + chart_spec["params"][0]["bind"]["max"] = max_iteration + chart_spec["params"][0]["value"] = max_iteration + return chart_spec + + +class WaterfallChart(SplinkChart[ChartRecord]): + def __init__( + self, + records: Sequence[ChartRecord], + filter_nulls: bool = True, + as_dict: bool = False, + ): + super().__init__(records, as_dict=as_dict) + self.filter_nulls = filter_nulls + + @property + def chart_spec_file(self) -> str: + return "match_weights_waterfall.json" + + def alter_spec_from_data(self, chart_spec): + records = self.chart_data + chart_spec["params"][0]["bind"]["max"] = len(records) - 1 + if self.filter_nulls: + chart_spec["transform"].insert( + 1, {"filter": "(datum.bayes_factor !== 1.0)"} + ) + return chart_spec + + +class ROCChart(SplinkChart[ChartRecord]): + @property + def chart_spec_file(self) -> str: + return "roc.json" + + @property + def default_width(self) -> float: + return 400 + + @property + def default_height(self) -> float: + return 400 + + def alter_spec_from_data(self, chart_spec): + records = self.chart_data + # If 'curve_label' not in records, remove colour coding + # This is for if you want to compare roc curves + r = records[0] + if "curve_label" not in r.keys(): + del chart_spec["encoding"]["color"] + return chart_spec + + +class PrecisionRecallChart(SplinkChart[ChartRecord]): + @property + def chart_spec_file(self) -> str: + return "precision_recall.json" + + @property + def default_width(self) -> float: + return 400 + + @property + def default_height(self) -> float: + return 400 + + def alter_spec_from_data(self, chart_spec): + records = self.chart_data + # If 'curve_label' not in records, remove colour coding + # This is for if you want to compare roc curves + r = records[0] + if "curve_label" not in r.keys(): + del chart_spec["encoding"]["color"] + return chart_spec + + +class AccuracyChart(SplinkChart[ChartRecord]): + def __init__( + self, + records: Sequence[ChartRecord], + add_metrics: Sequence[str] = [], # TODO + as_dict: bool = False, + ): + super().__init__(records, as_dict=as_dict) + # User-specified metrics to include + self.additional_metrics = add_metrics + + @property + def chart_spec_file(self) -> str: + return "accuracy_chart.json" + + @property + def default_width(self) -> float: + return 400 + + @property + def default_height(self) -> float: + return 400 + + @property + def metrics(self): + return ["precision", "recall", *self.additional_metrics] + + @staticmethod + def alter_spec_directly(chart_spec): + mapping = { + "precision": "Precision (PPV)", + "recall": "Recall (TPR)", + "specificity": "Specificity (TNR)", + "accuracy": "Accuracy", + "npv": "NPV", + "f1": "F1", + "f2": "F2", + "f0_5": "F0.5", + "p4": "P4", + "phi": "\u03c6 (MCC)", + } + chart_spec["transform"][2]["calculate"] = chart_spec["transform"][2][ + "calculate" + ].replace("__mapping__", str(mapping)) + chart_spec["layer"][0]["encoding"]["color"]["legend"]["labelExpr"] = chart_spec[ + "layer" + ][0]["encoding"]["color"]["legend"]["labelExpr"].replace( + "__mapping__", str(mapping) ) - unlinkables_chart_def["layer"][0]["encoding"]["x"]["axis"]["title"] = ( + + return chart_spec + + def alter_spec_from_data(self, chart_spec): + metrics = self.metrics + chart_spec["transform"][0]["fold"] = metrics + chart_spec["transform"][1]["calculate"] = chart_spec["transform"][1][ + "calculate" + ].replace("__metrics__", str(metrics)) + chart_spec["layer"][0]["encoding"]["color"]["sort"] = metrics + chart_spec["layer"][1]["layer"][1]["encoding"]["color"]["sort"] = metrics + return chart_spec + + +class ThresholdSelectionToolChart(SplinkChart[ChartRecord]): + def __init__( + self, + records: Sequence[ChartRecord], + add_metrics: Sequence[str] = [], # TODO + as_dict: bool = False, + ): + super().__init__(records, as_dict=as_dict) + # User-specified metrics to include + self.additional_metrics = add_metrics + + @property + def chart_spec_file(self) -> str: + return "threshold_selection_tool.json" + + @property + def metrics(self): + return ["precision", "recall", *self.additional_metrics] + + @staticmethod + def alter_data(records): + return [d for d in records if d["precision"] > 0.5 and d["recall"] > 0.5] + + @staticmethod + def alter_spec_directly(chart_spec): + mapping = { + "precision": "Precision (PPV)", + "recall": "Recall (TPR)", + "specificity": "Specificity (TNR)", + "accuracy": "Accuracy", + "npv": "NPV", + "f1": "F1", + "f2": "F2", + "f0_5": "F0.5", + "p4": "P4", + "phi": "\u03c6 (MCC)", + } + chart_spec["hconcat"][1]["transform"][2]["calculate"] = chart_spec["hconcat"][ + 1 + ]["transform"][2]["calculate"].replace("__mapping__", str(mapping)) + chart_spec["hconcat"][1]["layer"][0]["encoding"]["color"]["legend"][ + "labelExpr" + ] = chart_spec["hconcat"][1]["layer"][0]["encoding"]["color"]["legend"][ + "labelExpr" + ].replace("__mapping__", str(mapping)) + + return chart_spec + + def alter_spec_from_data(self, chart_spec): + metrics = self.metrics + chart_spec["hconcat"][1]["transform"][0]["fold"] = metrics + chart_spec["hconcat"][1]["transform"][1]["calculate"] = chart_spec["hconcat"][ + 1 + ]["transform"][1]["calculate"].replace("__metrics__", str(metrics)) + chart_spec["hconcat"][1]["layer"][0]["encoding"]["color"]["sort"] = metrics + chart_spec["hconcat"][1]["layer"][1]["layer"][1]["encoding"]["color"][ + "sort" + ] = metrics + + return chart_spec + + +class MatchWeightsHistogramChart(SplinkChart[ChartRecord]): + @property + def chart_spec_file(self) -> str: + return "match_weight_histogram.json" + + @property + def default_width(self) -> float: + return 500 + + @property + def default_height(self) -> float: + return 250 + + +class ParameterEstimateComparisonsChart(SplinkChart[ChartRecord]): + @property + def chart_spec_file(self) -> str: + return "parameter_estimate_comparisons.json" + + +class UnlinkablesChart(SplinkChart[ChartRecord]): + def __init__( + self, + records: Sequence[ChartRecord], + x_col: Literal["match_weight", "match_probability"] = "match_weight", + source_dataset: str | None = None, + as_dict: bool = False, + ): + if x_col not in ["match_weight", "match_probability"]: + raise ValueError( + f"{x_col} must be 'match_weight' (default) or 'match_probability'." + ) + super().__init__(records, as_dict=as_dict) + self.x_col = x_col + self.source_dataset = source_dataset + + @property + def chart_spec_file(self) -> str: + return "unlinkables_chart_def.json" + + def alter_spec_from_data(self, chart_spec): + if source_dataset := self.source_dataset: + chart_spec["title"]["text"] += f" - {source_dataset}" + if self.x_col == "match_weight": + return chart_spec + # if we have match_probability we need to update spec to match: + chart_spec["layer"][0]["encoding"]["x"]["field"] = "match_probability" + chart_spec["layer"][0]["encoding"]["x"]["axis"]["title"] = ( "Threshold match probability" ) - unlinkables_chart_def["layer"][0]["encoding"]["x"]["axis"]["format"] = ".2" + chart_spec["layer"][0]["encoding"]["x"]["axis"]["format"] = ".2" - unlinkables_chart_def["layer"][1]["encoding"]["x"]["field"] = ( - "match_probability" - ) - unlinkables_chart_def["layer"][1]["selection"]["selector112"]["fields"] = [ + chart_spec["layer"][1]["encoding"]["x"]["field"] = "match_probability" + chart_spec["layer"][1]["selection"]["selector112"]["fields"] = [ "match_probability", "cum_prop", ] - unlinkables_chart_def["layer"][2]["encoding"]["x"]["field"] = ( - "match_probability" - ) - unlinkables_chart_def["layer"][2]["encoding"]["x"]["axis"]["title"] = ( + chart_spec["layer"][2]["encoding"]["x"]["field"] = "match_probability" + chart_spec["layer"][2]["encoding"]["x"]["axis"]["title"] = ( "Threshold match probability" ) - unlinkables_chart_def["layer"][3]["encoding"]["x"]["field"] = ( - "match_probability" - ) - - if source_dataset: - unlinkables_chart_def["title"]["text"] += f" - {source_dataset}" + chart_spec["layer"][3]["encoding"]["x"]["field"] = "match_probability" - return altair_or_json(unlinkables_chart_def, as_dict=as_dict) + return chart_spec -def completeness_chart(records, as_dict=False): - chart_path = "completeness.json" - chart = load_chart_definition(chart_path) +class CompletenessChart(SplinkChart[ChartRecord]): + @property + def chart_spec_file(self) -> str: + return "completeness.json" - chart["data"]["values"] = records - return altair_or_json(chart, as_dict=as_dict) - - -def cumulative_blocking_rule_comparisons_generated(records, as_dict=False): - chart_path = "blocking_rule_generated_comparisons.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - return altair_or_json(chart, as_dict=as_dict) +class CumulativeBlockingRuleComparisonsGeneratedChart(SplinkChart[ChartRecord]): + @property + def chart_spec_file(self) -> str: + return "blocking_rule_generated_comparisons.json" def _comparator_score_chart(similarity_records, distance_records, as_dict=False): diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index c7e69492d6..00eef7f3ea 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -10,7 +10,12 @@ join_list_with_commas_final_and, ) -from .comparison_level import ComparisonLevel, _default_m_values, _default_u_values +from .comparison_level import ( + ComparisonLevel, + ComparisonLevelDetailedRecord, + _default_m_values, + _default_u_values, +) # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports if TYPE_CHECKING: @@ -360,15 +365,11 @@ def _is_trained(self): return self._all_m_are_trained and self._all_u_are_trained @property - def _as_detailed_records(self) -> list[dict[str, Any]]: + def _as_detailed_records(self) -> list[ComparisonLevelDetailedRecord]: records = [] for cl in self.comparison_levels: - record = {} - record["comparison_name"] = self.output_column_name - record = { - **record, - **cl._as_detailed_record(self._num_levels, self.comparison_levels), - } + record = cl._as_detailed_record(self._num_levels, self.comparison_levels) + record.comparison_name = self.output_column_name records.append(record) return records @@ -464,7 +465,7 @@ def human_readable_description(self): def match_weights_chart(self, as_dict=False): """Display a chart of comparison levels of the comparison""" - from splink.internals.charts import comparison_match_weights_chart + from splink.internals.charts import ComparisonMatchWeightsChart records = self._as_detailed_records - return comparison_match_weights_chart(records, as_dict=as_dict) + return ComparisonMatchWeightsChart(records, as_dict=as_dict).chart diff --git a/splink/internals/comparison_level.py b/splink/internals/comparison_level.py index 61aa669711..0f4de99597 100644 --- a/splink/internals/comparison_level.py +++ b/splink/internals/comparison_level.py @@ -4,6 +4,7 @@ import math import re from copy import copy +from dataclasses import asdict, dataclass from statistics import median from textwrap import dedent from typing import Any, Optional, Union, cast @@ -116,6 +117,34 @@ def _default_u_values(num_levels: int) -> list[float]: return u_vals +@dataclass +class ComparisonLevelDetailedRecord: + sql_condition: str | None + label_for_charts: str + + has_tf_adjustments: bool + tf_adjustment_column: str | None + tf_adjustment_weight: float | None + + is_null_level: bool + + m_probability: float | None + u_probability: float | None + m_probability_description: str | None + u_probability_description: str | None + + bayes_factor: float | None + log2_bayes_factor: float + bayes_factor_description: str + + comparison_vector_value: int + max_comparison_vector_value: int + comparison_name: str | None + + def as_dict(self): + return asdict(self) + + class ComparisonLevel: """Each ComparisonLevel defines a gradation (category) of similarity within a `Comparison`. @@ -735,39 +764,31 @@ def _as_completed_dict(self): def _as_detailed_record( self, comparison_num_levels: int, comparison_levels: list[ComparisonLevel] - ) -> dict[str, Any]: + ) -> ComparisonLevelDetailedRecord: "A detailed representation of this level to describe it in charting outputs" - output: dict[str, Any] = {} - output["sql_condition"] = self.sql_condition - output["label_for_charts"] = self._label_for_charts_no_duplicates( - comparison_levels + return ComparisonLevelDetailedRecord( + sql_condition=self.sql_condition, + label_for_charts=self._label_for_charts_no_duplicates(comparison_levels), + has_tf_adjustments=self._has_tf_adjustments, + tf_adjustment_column=( + self._tf_adjustment_input_column.input_name + if self._has_tf_adjustments + else None + ), + tf_adjustment_weight=self._tf_adjustment_weight, + is_null_level=self.is_null_level, + m_probability=self.m_probability if not self.is_null_level else None, + u_probability=self.u_probability if not self.is_null_level else None, + m_probability_description=self._m_probability_description, + u_probability_description=self._u_probability_description, + bayes_factor=self._bayes_factor, + log2_bayes_factor=self._log2_bayes_factor, + bayes_factor_description=self._bayes_factor_description, + comparison_vector_value=self.comparison_vector_value, + max_comparison_vector_value=comparison_num_levels - 1, + comparison_name=None, ) - if not self._is_null_level: - output["m_probability"] = self.m_probability - output["u_probability"] = self.u_probability - - output["m_probability_description"] = self._m_probability_description - output["u_probability_description"] = self._u_probability_description - - output["has_tf_adjustments"] = self._has_tf_adjustments - if self._has_tf_adjustments: - output["tf_adjustment_column"] = self._tf_adjustment_input_column.input_name - else: - output["tf_adjustment_column"] = None - output["tf_adjustment_weight"] = self._tf_adjustment_weight - - output["is_null_level"] = self.is_null_level - output["bayes_factor"] = self._bayes_factor - output["log2_bayes_factor"] = self._log2_bayes_factor - output["comparison_vector_value"] = self.comparison_vector_value - output["max_comparison_vector_value"] = comparison_num_levels - 1 - output["bayes_factor_description"] = self._bayes_factor_description - output["m_probability_description"] = self._m_probability_description - output["u_probability_description"] = self._u_probability_description - - return output - def _parameter_estimates_as_records( self, comparison_num_levels: int, comparison_levels: list[ComparisonLevel] ) -> list[dict[str, Any]]: @@ -786,9 +807,9 @@ def _parameter_estimates_as_records( else: record["estimated_probability_as_log_odds"] = None - record["sql_condition"] = cl_record["sql_condition"] - record["comparison_level_label"] = cl_record["label_for_charts"] - record["comparison_vector_value"] = cl_record["comparison_vector_value"] + record["sql_condition"] = cl_record.sql_condition + record["comparison_level_label"] = cl_record.label_for_charts + record["comparison_vector_value"] = cl_record.comparison_vector_value output_records.append(record) return output_records diff --git a/splink/internals/completeness.py b/splink/internals/completeness.py index 111d9999b6..c7f92b3bad 100644 --- a/splink/internals/completeness.py +++ b/splink/internals/completeness.py @@ -4,9 +4,7 @@ from splink.internals.charts import ( ChartReturnType, -) -from splink.internals.charts import ( - completeness_chart as records_to_completeness_chart, + CompletenessChart, ) from splink.internals.database_api import DatabaseAPISubClass from splink.internals.input_column import InputColumn @@ -131,4 +129,4 @@ def completeness_chart( db_api = get_db_api_from_inputs(splink_dataframe_or_dataframes) splink_df_dict = splink_dataframes_to_dict(splink_dataframe_or_dataframes) records = completeness_data(splink_df_dict, db_api, cols, table_names_for_chart) - return records_to_completeness_chart(records) + return CompletenessChart(records).chart diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index 866d7023ef..44bbfdf4c4 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -1,14 +1,15 @@ from __future__ import annotations import logging +from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, List from splink.internals.blocking import BlockingRule, block_using_rules_sqls from splink.internals.charts import ( ChartReturnType, - m_u_parameters_interactive_history_chart, - match_weights_interactive_history_chart, - probability_two_random_records_match_iteration_chart, + MatchWeightsInteractiveHistoryChart, + MUParametersInteractiveHistoryChart, + ProbabilityTwoRandomRecordsMatchIterationChart, ) from splink.internals.comparison import Comparison from splink.internals.comparison_vector_values import ( @@ -22,6 +23,7 @@ from splink.internals.settings import ( ComparisonAndLevelDict, CoreModelSettings, + ModelParameterDetailedRecord, Settings, TrainingSettings, ) @@ -39,6 +41,27 @@ from splink.internals.splink_dataframe import SplinkDataFrame +@dataclass +class ModelParameterIterationDetailedRecord(ModelParameterDetailedRecord): + iteration: int + + @classmethod + def from_settings_param_detailed_record( + cls, + cl_rec: ModelParameterDetailedRecord, + *, + probability_two_random_records_match: float, + iteration: int, + ) -> ModelParameterIterationDetailedRecord: + cl_rec.probability_two_random_records_match = ( + probability_two_random_records_match + ) + return cls( + **asdict(cl_rec), + iteration=iteration, + ) + + class EMTrainingSession: """Manages training models using the Expectation Maximisation algorithm, and holds statistics on the evolution of parameter estimates. Plots diagnostic charts @@ -320,20 +343,20 @@ def _blocking_adjusted_probability_two_random_records_match(self): return adjusted_prop_m @property - def _iteration_history_records(self): + def _iteration_history_records(self) -> list[ModelParameterIterationDetailedRecord]: output_records = [] for iteration, core_model_settings in enumerate( self._core_model_settings_history ): - records = core_model_settings.parameters_as_detailed_records - - for r in records: - r["iteration"] = iteration - # TODO: why lambda from current settings, not history? - r["probability_two_random_records_match"] = ( - self.core_model_settings.probability_two_random_records_match + records = [ + ModelParameterIterationDetailedRecord.from_settings_param_detailed_record( + r, + probability_two_random_records_match=self.core_model_settings.probability_two_random_records_match, + iteration=iteration, ) + for r in core_model_settings.parameters_as_detailed_records + ] output_records.extend(records) return output_records @@ -361,7 +384,7 @@ def probability_two_random_records_match_iteration_chart(self) -> ChartReturnTyp An interactive Altair chart. """ records = self._lambda_history_records - return probability_two_random_records_match_iteration_chart(records) + return ProbabilityTwoRandomRecordsMatchIterationChart(records).chart def match_weights_interactive_history_chart(self) -> ChartReturnType: """ @@ -370,10 +393,10 @@ def match_weights_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = self._iteration_history_records - return match_weights_interactive_history_chart( - records, blocking_rule=self._blocking_rule_for_training.blocking_rule_sql - ) + return MatchWeightsInteractiveHistoryChart( + self._iteration_history_records, + blocking_rule_text=self._blocking_rule_for_training.blocking_rule_sql, + ).chart def m_u_values_interactive_history_chart(self) -> ChartReturnType: """ @@ -382,8 +405,9 @@ def m_u_values_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = self._iteration_history_records - return m_u_parameters_interactive_history_chart(records) + return MUParametersInteractiveHistoryChart( + self._iteration_history_records + ).chart def __repr__(self): deactivated_cols = ", ".join( diff --git a/splink/internals/files/chart_defs/missingness.json b/splink/internals/files/chart_defs/missingness.json deleted file mode 100644 index de18f660fb..0000000000 --- a/splink/internals/files/chart_defs/missingness.json +++ /dev/null @@ -1,80 +0,0 @@ -{ - "config": { - "view": { - "continuousWidth": 400, - "continuousHeight": 300 - }, - "axis": { - "labelFontSize": 11 - } - }, - "title": "", - "layer": [ - { - "mark": "bar", - "encoding": { - "color": { - "type": "quantitative", - "field": "null_proportion", - "legend": { - "format": ".0%", - "offset": 30 - }, - "scale": {"domain": [0, 1], - "range": "heatmap" - }, - "title": "Missingness" - }, - "tooltip": [ - { - "type": "nominal", - "field": "column_name", - "title": "Column" - }, - { - "type": "quantitative", - "field": "null_count", - "format": ",.0f", - "title": "Count of nulls" - }, - { - "type": "quantitative", - "field": "null_proportion", - "format": ".2%", - "title": "Percentage of nulls" - }, - { - "type": "quantitative", - "field": "total_record_count", - "format": ",.0f", - "title": "Total record count" - } - ], - "x": { - "type": "quantitative", - "scale": {"domain": [0, 1]}, - "axis": { - "labelAlign": "center", - "format": "%", - "title": "Percentage of nulls" - }, - "field": "null_proportion" - }, - "y": { - "type": "nominal", - "axis": { - "title": "" - }, - "field": "column_name", - "sort": "-x" - } - }, - "title": "" - } - ], - "data": { - "values": "", - "name": "data-0e7bce5a1d2f132e282789d6ef7780fe" - }, - "$schema": "https://vega.github.io/schema/vega-lite/v5.9.3.json" -} diff --git a/splink/internals/linker_components/evaluation.py b/splink/internals/linker_components/evaluation.py index 080df0a49b..4d90fa29e3 100644 --- a/splink/internals/linker_components/evaluation.py +++ b/splink/internals/linker_components/evaluation.py @@ -9,12 +9,12 @@ truth_space_table_from_labels_table, ) from splink.internals.charts import ( + AccuracyChart, ChartReturnType, - accuracy_chart, - precision_recall_chart, - roc_chart, - threshold_selection_tool, - unlinkables_chart, + PrecisionRecallChart, + ROCChart, + ThresholdSelectionToolChart, + UnlinkablesChart, ) from splink.internals.labelling_tool import ( generate_labelling_tool_comparisons, @@ -168,13 +168,13 @@ def accuracy_analysis_from_labels_column( recs = df_truth_space.as_record_dict() if output_type == "threshold_selection": - return threshold_selection_tool(recs, add_metrics=add_metrics) + return ThresholdSelectionToolChart(recs, add_metrics=add_metrics).chart elif output_type == "accuracy": - return accuracy_chart(recs, add_metrics=add_metrics) + return AccuracyChart(recs, add_metrics=add_metrics).chart elif output_type == "roc": - return roc_chart(recs) + return ROCChart(recs).chart elif output_type == "precision_recall": - return precision_recall_chart(recs) + return PrecisionRecallChart(recs).chart elif output_type == "table": return df_truth_space else: @@ -281,13 +281,13 @@ def accuracy_analysis_from_labels_table( recs = df_truth_space.as_record_dict() if output_type == "threshold_selection": - return threshold_selection_tool(recs, add_metrics=add_metrics) + return ThresholdSelectionToolChart(recs, add_metrics=add_metrics).chart elif output_type == "accuracy": - return accuracy_chart(recs, add_metrics=add_metrics) + return AccuracyChart(recs, add_metrics=add_metrics).chart elif output_type == "roc": - return roc_chart(recs) + return ROCChart(recs).chart elif output_type == "precision_recall": - return precision_recall_chart(recs) + return PrecisionRecallChart(recs).chart elif output_type == "table": return df_truth_space else: @@ -337,7 +337,7 @@ def prediction_errors_from_labels_column( def unlinkables_chart( self, - x_col: str = "match_weight", + x_col: Literal["match_weight", "match_probability"] = "match_weight", name_of_data_in_title: str | None = None, as_dict: bool = False, ) -> ChartReturnType: @@ -349,6 +349,7 @@ def unlinkables_chart( Args: x_col (str, optional): Column to use for the x-axis. + Must be either "match_weight" or "match_probability". Defaults to "match_weight". name_of_data_in_title (str, optional): Name of the source dataset to use for the title of the output chart. @@ -367,7 +368,7 @@ def unlinkables_chart( # Link our initial df on itself and calculate the % of unlinkable entries records = unlinkables_data(self._linker) - return unlinkables_chart(records, x_col, name_of_data_in_title, as_dict) + return UnlinkablesChart(records, x_col, name_of_data_in_title, as_dict).chart def labelling_tool_for_specific_record( self, diff --git a/splink/internals/linker_components/visualisations.py b/splink/internals/linker_components/visualisations.py index efc7d08c89..2867bf3029 100644 --- a/splink/internals/linker_components/visualisations.py +++ b/splink/internals/linker_components/visualisations.py @@ -4,9 +4,9 @@ from splink.internals.charts import ( ChartReturnType, - match_weights_histogram, - parameter_estimate_comparisons, - waterfall_chart, + MatchWeightsHistogramChart, + ParameterEstimateComparisonsChart, + WaterfallChart, ) from splink.internals.cluster_studio import ( SamplingMethods, @@ -26,6 +26,7 @@ from splink.internals.term_frequencies import ( tf_adjustment_chart, ) +from splink.internals.waterfall_chart import records_to_waterfall_data if TYPE_CHECKING: from splink.internals.linker import Linker @@ -139,9 +140,9 @@ def match_weights_histogram( """ df = histogram_data(self._linker, df_predict, target_bins) recs = df.as_record_dict() - return match_weights_histogram( - recs, width=width, height=height, as_dict=as_dict - ) + chart = MatchWeightsHistogramChart(recs, as_dict=as_dict) + chart.set_width_height(width=width, height=height) + return chart.chart def parameter_estimate_comparisons_chart( self, include_m: bool = True, include_u: bool = False, as_dict: bool = False @@ -187,8 +188,8 @@ def parameter_estimate_comparisons_chart( to_retain.append("u") records = [r for r in records if r["m_or_u"] in to_retain] - - return parameter_estimate_comparisons(records, as_dict) + # TODO: this logic into chart object + return ParameterEstimateComparisonsChart(records, as_dict).chart def tf_adjustment_chart( self, @@ -288,13 +289,15 @@ def waterfall_chart( """ self._linker._raise_error_if_necessary_waterfall_columns_not_computed() - return waterfall_chart( - records, - self._linker._settings_obj, + data = records_to_waterfall_data( + records, self._linker._settings_obj, remove_sensitive_data + ) + + return WaterfallChart( + data, filter_nulls, - remove_sensitive_data, as_dict, - ) + ).chart def comparison_viewer_dashboard( self, diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 9a4adea43b..244c66f5a6 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -6,9 +6,12 @@ from typing import Any, List, Literal, TypedDict from splink.internals.blocking import BlockingRule -from splink.internals.charts import m_u_parameters_chart, match_weights_chart +from splink.internals.charts import MatchWeightsChart, MUParametersChart from splink.internals.comparison import Comparison -from splink.internals.comparison_level import ComparisonLevel +from splink.internals.comparison_level import ( + ComparisonLevel, + ComparisonLevelDetailedRecord, +) from splink.internals.dialects import SplinkDialect from splink.internals.input_column import InputColumn from splink.internals.misc import ( @@ -27,6 +30,26 @@ class ComparisonAndLevelDict(TypedDict): comparison: Comparison +@dataclass +class ModelParameterDetailedRecord(ComparisonLevelDetailedRecord): + probability_two_random_records_match: float + comparison_sort_order: int + + @classmethod + def from_cl_detailed_record( + cls, + cl_rec: ComparisonLevelDetailedRecord, + *, + probability_two_random_records_match: float, + comparison_sort_order: int, + ) -> ModelParameterDetailedRecord: + return cls( + **asdict(cl_rec), + probability_two_random_records_match=probability_two_random_records_match, + comparison_sort_order=comparison_sort_order, + ) + + @dataclass(frozen=True) class ColumnInfoSettings: match_weight_column_prefix: str @@ -109,14 +132,18 @@ def copy(self): return deepcopy(self) @property - def parameters_as_detailed_records(self): + def parameters_as_detailed_records(self) -> list[ModelParameterDetailedRecord]: output = [] rr_match = self.probability_two_random_records_match for i, cc in enumerate(self.comparisons): - records = cc._as_detailed_records - for r in records: - r["probability_two_random_records_match"] = rr_match - r["comparison_sort_order"] = i + records = [ + ModelParameterDetailedRecord.from_cl_detailed_record( + r, + probability_two_random_records_match=rr_match, + comparison_sort_order=i, + ) + for r in cc._as_detailed_records + ] output.extend(records) prior_description = ( @@ -128,26 +155,26 @@ def parameters_as_detailed_records(self): ) # Finally add a record for probability_two_random_records_match - prop_record = { - "comparison_name": "probability_two_random_records_match", - "sql_condition": None, - "label_for_charts": "", - "m_probability": None, - "u_probability": None, - "m_probability_description": None, - "u_probability_description": None, - "has_tf_adjustments": False, - "tf_adjustment_column": None, - "tf_adjustment_weight": None, - "is_null_level": False, - "bayes_factor": prob_to_bayes_factor(rr_match), - "log2_bayes_factor": prob_to_match_weight(rr_match), - "comparison_vector_value": 0, - "max_comparison_vector_value": 0, - "bayes_factor_description": prior_description, - "probability_two_random_records_match": rr_match, - "comparison_sort_order": -1, - } + prop_record = ModelParameterDetailedRecord( + comparison_name="probability_two_random_records_match", + sql_condition=None, + label_for_charts="", + m_probability=None, + u_probability=None, + m_probability_description=None, + u_probability_description=None, + has_tf_adjustments=False, + tf_adjustment_column=None, + tf_adjustment_weight=None, + is_null_level=False, + bayes_factor=prob_to_bayes_factor(rr_match), + log2_bayes_factor=prob_to_match_weight(rr_match), + comparison_vector_value=0, + max_comparison_vector_value=0, + bayes_factor_description=prior_description, + probability_two_random_records_match=rr_match, + comparison_sort_order=-1, + ) output.insert(0, prop_record) return output @@ -568,13 +595,15 @@ def _as_completed_dict(self): } def match_weights_chart(self, as_dict=False): - records = self._parameters_as_detailed_records - - return match_weights_chart(records, as_dict=as_dict) + return MatchWeightsChart( + self._parameters_as_detailed_records, + as_dict=as_dict, + ).chart def m_u_parameters_chart(self, as_dict=False): - records = self._parameters_as_detailed_records - return m_u_parameters_chart(records, as_dict=as_dict) + return MUParametersChart( + self._parameters_as_detailed_records, as_dict=as_dict + ).chart def _columns_without_estimated_parameters_message(self): message_lines = [] diff --git a/splink/internals/term_frequencies.py b/splink/internals/term_frequencies.py index da12b12699..85034f6474 100644 --- a/splink/internals/term_frequencies.py +++ b/splink/internals/term_frequencies.py @@ -239,7 +239,9 @@ def tf_adjustment_chart( # Data for chart comparison = linker._settings_obj._get_comparison_by_output_column_name(col) - comparison_records = comparison._as_detailed_records + comparison_records = list( + map(lambda rec: rec.as_dict(), comparison._as_detailed_records) + ) keys_to_retain = [ "comparison_vector_value", diff --git a/splink/internals/waterfall_chart.py b/splink/internals/waterfall_chart.py index a8a080957b..07ef8b3992 100644 --- a/splink/internals/waterfall_chart.py +++ b/splink/internals/waterfall_chart.py @@ -57,9 +57,9 @@ def _comparison_records( cl = c._get_comparison_level_by_comparison_vector_value(cv_value) waterfall_record = { field: value - for field, value in cl._as_detailed_record( - c._num_levels, c.comparison_levels - ).items() + for field, value in cl._as_detailed_record(c._num_levels, c.comparison_levels) + .as_dict() + .items() if field in [ "label_for_charts", diff --git a/tests/test_charts.py b/tests/test_charts.py index 455b6ac837..b85cb33114 100644 --- a/tests/test_charts.py +++ b/tests/test_charts.py @@ -142,6 +142,24 @@ def test_m_u_charts(dialect, test_helpers): linker.visualisations.match_weights_chart() +@mark_with_dialects_excluding() +def test_comparison_match_weight_charts(dialect, test_helpers): + settings = { + "link_type": "dedupe_only", + "comparisons": [ + cl.ExactMatch("gender"), + cl.ExactMatch("tm_partial"), + cl.LevenshteinAtThresholds("surname", [1]), + ], + } + helper = test_helpers[dialect] + + linker = helper.linker_with_registration(df, settings) + + comp = linker._settings_obj.comparisons[0] + comp.match_weights_chart() + + @mark_with_dialects_excluding() def test_parameter_estimate_charts(dialect, test_helpers): settings = {