From 65f10630e49d589e8b7e804c025950393e0c751d Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 11:04:47 +0000 Subject: [PATCH 01/24] remove orphaned chart gone since https://github.com/moj-analytical-services/splink/pull/2157 --- splink/internals/charts.py | 14 ---- .../files/chart_defs/missingness.json | 80 ------------------- 2 files changed, 94 deletions(-) delete mode 100644 splink/internals/files/chart_defs/missingness.json diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 9f1a4a31fa..d48b1981de 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -349,20 +349,6 @@ def parameter_estimate_comparisons(records, as_dict=False): return altair_or_json(chart, as_dict=as_dict) -def missingness_chart(records, as_dict=False): - chart_path = "missingness.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - record_count = records[0]["total_record_count"] - - for c in chart["layer"]: - c["title"] = f"Missingness per column out of {record_count:,.0f} records" - - return altair_or_json(chart, as_dict=as_dict) - - def unlinkables_chart( records, x_col="match_weight", diff --git a/splink/internals/files/chart_defs/missingness.json b/splink/internals/files/chart_defs/missingness.json deleted file mode 100644 index de18f660fb..0000000000 --- a/splink/internals/files/chart_defs/missingness.json +++ /dev/null @@ -1,80 +0,0 @@ -{ - "config": { - "view": { - "continuousWidth": 400, - "continuousHeight": 300 - }, - "axis": { - "labelFontSize": 11 - } - }, - "title": "", - "layer": [ - { - "mark": "bar", - "encoding": { - "color": { - "type": "quantitative", - "field": "null_proportion", - "legend": { - "format": ".0%", - "offset": 30 - }, - "scale": {"domain": [0, 1], - "range": "heatmap" - }, - "title": "Missingness" - }, - "tooltip": [ - { - "type": "nominal", - "field": "column_name", - "title": "Column" - }, - { - "type": "quantitative", - "field": "null_count", - "format": ",.0f", - "title": "Count of nulls" - }, - { - "type": "quantitative", - "field": "null_proportion", - "format": ".2%", - "title": "Percentage of nulls" - }, - { - "type": "quantitative", - "field": "total_record_count", - "format": ",.0f", - "title": "Total record count" - } - ], - "x": { - "type": "quantitative", - "scale": {"domain": [0, 1]}, - "axis": { - "labelAlign": "center", - "format": "%", - "title": "Percentage of nulls" - }, - "field": "null_proportion" - }, - "y": { - "type": "nominal", - "axis": { - "title": "" - }, - "field": "column_name", - "sort": "-x" - } - }, - "title": "" - } - ], - "data": { - "values": "", - "name": "data-0e7bce5a1d2f132e282789d6ef7780fe" - }, - "$schema": "https://vega.github.io/schema/vega-lite/v5.9.3.json" -} From c37090e9e8e895f4ce72832cc6d45b664ba32243 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 13:08:54 +0000 Subject: [PATCH 02/24] cl detailed record as dataclass --- splink/internals/comparison.py | 4 +- splink/internals/comparison_level.py | 85 +++++++++++++++++----------- splink/internals/waterfall_chart.py | 6 +- 3 files changed, 58 insertions(+), 37 deletions(-) diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index c7e69492d6..0f8d134937 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -367,7 +367,9 @@ def _as_detailed_records(self) -> list[dict[str, Any]]: record["comparison_name"] = self.output_column_name record = { **record, - **cl._as_detailed_record(self._num_levels, self.comparison_levels), + **cl._as_detailed_record( + self._num_levels, self.comparison_levels + ).as_dict(), } records.append(record) return records diff --git a/splink/internals/comparison_level.py b/splink/internals/comparison_level.py index 61aa669711..026e135ff8 100644 --- a/splink/internals/comparison_level.py +++ b/splink/internals/comparison_level.py @@ -4,6 +4,7 @@ import math import re from copy import copy +from dataclasses import asdict, dataclass from statistics import median from textwrap import dedent from typing import Any, Optional, Union, cast @@ -116,6 +117,33 @@ def _default_u_values(num_levels: int) -> list[float]: return u_vals +@dataclass +class ComparisonLevelDetailedRecord: + sql_condition: str + label_for_charts: str + + has_tf_adjustments: bool + tf_adjustment_column: str | None + tf_adjustment_weight: float + + is_null_level: bool + + m_probability: float | None + u_probability: float | None + m_probability_description: str | None + u_probability_description: str | None + + bayes_factor: float | None + log2_bayes_factor: float + bayes_factor_description: str + + comparison_vector_value: int + max_comparison_vector_value: int + + def as_dict(self): + return asdict(self) + + class ComparisonLevel: """Each ComparisonLevel defines a gradation (category) of similarity within a `Comparison`. @@ -735,39 +763,30 @@ def _as_completed_dict(self): def _as_detailed_record( self, comparison_num_levels: int, comparison_levels: list[ComparisonLevel] - ) -> dict[str, Any]: + ) -> ComparisonLevelDetailedRecord: "A detailed representation of this level to describe it in charting outputs" - output: dict[str, Any] = {} - output["sql_condition"] = self.sql_condition - output["label_for_charts"] = self._label_for_charts_no_duplicates( - comparison_levels + return ComparisonLevelDetailedRecord( + sql_condition=self.sql_condition, + label_for_charts=self._label_for_charts_no_duplicates(comparison_levels), + has_tf_adjustments=self._has_tf_adjustments, + tf_adjustment_column=( + self._tf_adjustment_input_column.input_name + if self._has_tf_adjustments + else None + ), + tf_adjustment_weight=self._tf_adjustment_weight, + is_null_level=self.is_null_level, + m_probability=self.m_probability if not self.is_null_level else None, + u_probability=self.u_probability if not self.is_null_level else None, + m_probability_description=self._m_probability_description, + u_probability_description=self._u_probability_description, + bayes_factor=self._bayes_factor, + log2_bayes_factor=self._log2_bayes_factor, + bayes_factor_description=self._bayes_factor_description, + comparison_vector_value=self.comparison_vector_value, + max_comparison_vector_value=comparison_num_levels - 1, ) - if not self._is_null_level: - output["m_probability"] = self.m_probability - output["u_probability"] = self.u_probability - - output["m_probability_description"] = self._m_probability_description - output["u_probability_description"] = self._u_probability_description - - output["has_tf_adjustments"] = self._has_tf_adjustments - if self._has_tf_adjustments: - output["tf_adjustment_column"] = self._tf_adjustment_input_column.input_name - else: - output["tf_adjustment_column"] = None - output["tf_adjustment_weight"] = self._tf_adjustment_weight - - output["is_null_level"] = self.is_null_level - output["bayes_factor"] = self._bayes_factor - output["log2_bayes_factor"] = self._log2_bayes_factor - output["comparison_vector_value"] = self.comparison_vector_value - output["max_comparison_vector_value"] = comparison_num_levels - 1 - output["bayes_factor_description"] = self._bayes_factor_description - output["m_probability_description"] = self._m_probability_description - output["u_probability_description"] = self._u_probability_description - - return output - def _parameter_estimates_as_records( self, comparison_num_levels: int, comparison_levels: list[ComparisonLevel] ) -> list[dict[str, Any]]: @@ -786,9 +805,9 @@ def _parameter_estimates_as_records( else: record["estimated_probability_as_log_odds"] = None - record["sql_condition"] = cl_record["sql_condition"] - record["comparison_level_label"] = cl_record["label_for_charts"] - record["comparison_vector_value"] = cl_record["comparison_vector_value"] + record["sql_condition"] = cl_record.sql_condition + record["comparison_level_label"] = cl_record.label_for_charts + record["comparison_vector_value"] = cl_record.comparison_vector_value output_records.append(record) return output_records diff --git a/splink/internals/waterfall_chart.py b/splink/internals/waterfall_chart.py index a8a080957b..07ef8b3992 100644 --- a/splink/internals/waterfall_chart.py +++ b/splink/internals/waterfall_chart.py @@ -57,9 +57,9 @@ def _comparison_records( cl = c._get_comparison_level_by_comparison_vector_value(cv_value) waterfall_record = { field: value - for field, value in cl._as_detailed_record( - c._num_levels, c.comparison_levels - ).items() + for field, value in cl._as_detailed_record(c._num_levels, c.comparison_levels) + .as_dict() + .items() if field in [ "label_for_charts", From b59f5ccda9e84ff0d6144fa805c1778b4e5caf4c Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 13:18:59 +0000 Subject: [PATCH 03/24] comparison name in cl detailed rec --- splink/internals/comparison.py | 18 ++++++++---------- splink/internals/comparison_level.py | 1 + 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index 0f8d134937..95816ca2db 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -10,7 +10,11 @@ join_list_with_commas_final_and, ) -from .comparison_level import ComparisonLevel, _default_m_values, _default_u_values +from .comparison_level import ( + ComparisonLevel, + _default_m_values, + _default_u_values, +) # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports if TYPE_CHECKING: @@ -363,15 +367,9 @@ def _is_trained(self): def _as_detailed_records(self) -> list[dict[str, Any]]: records = [] for cl in self.comparison_levels: - record = {} - record["comparison_name"] = self.output_column_name - record = { - **record, - **cl._as_detailed_record( - self._num_levels, self.comparison_levels - ).as_dict(), - } - records.append(record) + record = cl._as_detailed_record(self._num_levels, self.comparison_levels) + record.comparison_name = self.output_column_name + records.append(record.as_dict()) return records @property diff --git a/splink/internals/comparison_level.py b/splink/internals/comparison_level.py index 026e135ff8..583268518f 100644 --- a/splink/internals/comparison_level.py +++ b/splink/internals/comparison_level.py @@ -139,6 +139,7 @@ class ComparisonLevelDetailedRecord: comparison_vector_value: int max_comparison_vector_value: int + comparison_name: str | None = None def as_dict(self): return asdict(self) From e8987351ff3b9d0999ff6e47e94d0e47cc11ed89 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 13:30:33 +0000 Subject: [PATCH 04/24] detailed record dataclass in comparison --- splink/internals/comparison.py | 7 ++++--- splink/internals/settings.py | 2 +- splink/internals/term_frequencies.py | 4 +++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index 95816ca2db..b8988a9934 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -12,6 +12,7 @@ from .comparison_level import ( ComparisonLevel, + ComparisonLevelDetailedRecord, _default_m_values, _default_u_values, ) @@ -364,12 +365,12 @@ def _is_trained(self): return self._all_m_are_trained and self._all_u_are_trained @property - def _as_detailed_records(self) -> list[dict[str, Any]]: + def _as_detailed_records(self) -> list[ComparisonLevelDetailedRecord]: records = [] for cl in self.comparison_levels: record = cl._as_detailed_record(self._num_levels, self.comparison_levels) record.comparison_name = self.output_column_name - records.append(record.as_dict()) + records.append(record) return records @property @@ -466,5 +467,5 @@ def match_weights_chart(self, as_dict=False): """Display a chart of comparison levels of the comparison""" from splink.internals.charts import comparison_match_weights_chart - records = self._as_detailed_records + records = list(map(lambda rec: rec.as_dict(), self._as_detailed_records)) return comparison_match_weights_chart(records, as_dict=as_dict) diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 9a4adea43b..0a5fcde4eb 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -113,7 +113,7 @@ def parameters_as_detailed_records(self): output = [] rr_match = self.probability_two_random_records_match for i, cc in enumerate(self.comparisons): - records = cc._as_detailed_records + records = list(map(lambda rec: rec.as_dict(), cc._as_detailed_records)) for r in records: r["probability_two_random_records_match"] = rr_match r["comparison_sort_order"] = i diff --git a/splink/internals/term_frequencies.py b/splink/internals/term_frequencies.py index da12b12699..85034f6474 100644 --- a/splink/internals/term_frequencies.py +++ b/splink/internals/term_frequencies.py @@ -239,7 +239,9 @@ def tf_adjustment_chart( # Data for chart comparison = linker._settings_obj._get_comparison_by_output_column_name(col) - comparison_records = comparison._as_detailed_records + comparison_records = list( + map(lambda rec: rec.as_dict(), comparison._as_detailed_records) + ) keys_to_retain = [ "comparison_vector_value", From e49a2e332ce3801e64c0d3c955dfbb7be3ca8fd6 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 14:47:20 +0000 Subject: [PATCH 05/24] specialised detailed record for settings --- splink/internals/comparison_level.py | 6 +-- splink/internals/settings.py | 79 +++++++++++++++++++--------- 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/splink/internals/comparison_level.py b/splink/internals/comparison_level.py index 583268518f..85dcb0208c 100644 --- a/splink/internals/comparison_level.py +++ b/splink/internals/comparison_level.py @@ -117,14 +117,14 @@ def _default_u_values(num_levels: int) -> list[float]: return u_vals -@dataclass +@dataclass(kw_only=True) class ComparisonLevelDetailedRecord: - sql_condition: str + sql_condition: str | None label_for_charts: str has_tf_adjustments: bool tf_adjustment_column: str | None - tf_adjustment_weight: float + tf_adjustment_weight: float | None is_null_level: bool diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 0a5fcde4eb..2d417e3f7c 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -8,7 +8,10 @@ from splink.internals.blocking import BlockingRule from splink.internals.charts import m_u_parameters_chart, match_weights_chart from splink.internals.comparison import Comparison -from splink.internals.comparison_level import ComparisonLevel +from splink.internals.comparison_level import ( + ComparisonLevel, + ComparisonLevelDetailedRecord, +) from splink.internals.dialects import SplinkDialect from splink.internals.input_column import InputColumn from splink.internals.misc import ( @@ -27,6 +30,26 @@ class ComparisonAndLevelDict(TypedDict): comparison: Comparison +@dataclass(kw_only=True) +class ModelParameterDetailedRecord(ComparisonLevelDetailedRecord): + probability_two_random_records_match: float + comparison_sort_order: int + + @classmethod + def from_cl_detailed_record( + cls, + cl_rec: ComparisonLevelDetailedRecord, + *, + probability_two_random_records_match: float, + comparison_sort_order: int, + ) -> ModelParameterDetailedRecord: + return cls( + **asdict(cl_rec), + probability_two_random_records_match=probability_two_random_records_match, + comparison_sort_order=comparison_sort_order, + ) + + @dataclass(frozen=True) class ColumnInfoSettings: match_weight_column_prefix: str @@ -109,14 +132,18 @@ def copy(self): return deepcopy(self) @property - def parameters_as_detailed_records(self): + def parameters_as_detailed_records(self) -> list[ModelParameterDetailedRecord]: output = [] rr_match = self.probability_two_random_records_match for i, cc in enumerate(self.comparisons): - records = list(map(lambda rec: rec.as_dict(), cc._as_detailed_records)) - for r in records: - r["probability_two_random_records_match"] = rr_match - r["comparison_sort_order"] = i + records = [ + ModelParameterDetailedRecord.from_cl_detailed_record( + r, + probability_two_random_records_match=rr_match, + comparison_sort_order=i, + ) + for r in cc._as_detailed_records + ] output.extend(records) prior_description = ( @@ -128,26 +155,26 @@ def parameters_as_detailed_records(self): ) # Finally add a record for probability_two_random_records_match - prop_record = { - "comparison_name": "probability_two_random_records_match", - "sql_condition": None, - "label_for_charts": "", - "m_probability": None, - "u_probability": None, - "m_probability_description": None, - "u_probability_description": None, - "has_tf_adjustments": False, - "tf_adjustment_column": None, - "tf_adjustment_weight": None, - "is_null_level": False, - "bayes_factor": prob_to_bayes_factor(rr_match), - "log2_bayes_factor": prob_to_match_weight(rr_match), - "comparison_vector_value": 0, - "max_comparison_vector_value": 0, - "bayes_factor_description": prior_description, - "probability_two_random_records_match": rr_match, - "comparison_sort_order": -1, - } + prop_record = ModelParameterDetailedRecord( + comparison_name="probability_two_random_records_match", + sql_condition=None, + label_for_charts="", + m_probability=None, + u_probability=None, + m_probability_description=None, + u_probability_description=None, + has_tf_adjustments=False, + tf_adjustment_column=None, + tf_adjustment_weight=None, + is_null_level=False, + bayes_factor=prob_to_bayes_factor(rr_match), + log2_bayes_factor=prob_to_match_weight(rr_match), + comparison_vector_value=0, + max_comparison_vector_value=0, + bayes_factor_description=prior_description, + probability_two_random_records_match=rr_match, + comparison_sort_order=-1, + ) output.insert(0, prop_record) return output From 37841a6bf4fb0077d6cf0e1d0ce0d09f479a96a2 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 15:26:31 +0000 Subject: [PATCH 06/24] em iteration record subclass --- splink/internals/em_training_session.py | 43 +++++++++++++++++++------ splink/internals/settings.py | 8 +++-- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index 866d7023ef..cbd25c9c12 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, List from splink.internals.blocking import BlockingRule, block_using_rules_sqls @@ -22,6 +23,7 @@ from splink.internals.settings import ( ComparisonAndLevelDict, CoreModelSettings, + ModelParameterDetailedRecord, Settings, TrainingSettings, ) @@ -39,6 +41,27 @@ from splink.internals.splink_dataframe import SplinkDataFrame +@dataclass(kw_only=True) +class ModelParameterIterationDetailedRecord(ModelParameterDetailedRecord): + iteration: int + + @classmethod + def from_settings_param_detailed_record( + cls, + cl_rec: ModelParameterDetailedRecord, + *, + probability_two_random_records_match: float, + iteration: int, + ) -> ModelParameterIterationDetailedRecord: + cl_rec.probability_two_random_records_match = ( + probability_two_random_records_match + ) + return cls( + **asdict(cl_rec), + iteration=iteration, + ) + + class EMTrainingSession: """Manages training models using the Expectation Maximisation algorithm, and holds statistics on the evolution of parameter estimates. Plots diagnostic charts @@ -320,20 +343,20 @@ def _blocking_adjusted_probability_two_random_records_match(self): return adjusted_prop_m @property - def _iteration_history_records(self): + def _iteration_history_records(self) -> list[ModelParameterIterationDetailedRecord]: output_records = [] for iteration, core_model_settings in enumerate( self._core_model_settings_history ): - records = core_model_settings.parameters_as_detailed_records - - for r in records: - r["iteration"] = iteration - # TODO: why lambda from current settings, not history? - r["probability_two_random_records_match"] = ( - self.core_model_settings.probability_two_random_records_match + records = [ + ModelParameterIterationDetailedRecord.from_settings_param_detailed_record( + r, + probability_two_random_records_match=self.core_model_settings.probability_two_random_records_match, + iteration=iteration, ) + for r in core_model_settings.parameters_as_detailed_records + ] output_records.extend(records) return output_records @@ -370,7 +393,7 @@ def match_weights_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = self._iteration_history_records + records = list(map(lambda rec: rec.as_dict(), self._iteration_history_records)) return match_weights_interactive_history_chart( records, blocking_rule=self._blocking_rule_for_training.blocking_rule_sql ) @@ -382,7 +405,7 @@ def m_u_values_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = self._iteration_history_records + records = list(map(lambda rec: rec.as_dict(), self._iteration_history_records)) return m_u_parameters_interactive_history_chart(records) def __repr__(self): diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 2d417e3f7c..40c6029796 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -595,12 +595,16 @@ def _as_completed_dict(self): } def match_weights_chart(self, as_dict=False): - records = self._parameters_as_detailed_records + records = list( + map(lambda rec: rec.as_dict(), self._parameters_as_detailed_records) + ) return match_weights_chart(records, as_dict=as_dict) def m_u_parameters_chart(self, as_dict=False): - records = self._parameters_as_detailed_records + records = list( + map(lambda rec: rec.as_dict(), self._parameters_as_detailed_records) + ) return m_u_parameters_chart(records, as_dict=as_dict) def _columns_without_estimated_parameters_message(self): From c5c01765da6d76f2bb6c83a9624b7b8972649cc2 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:16:56 +0000 Subject: [PATCH 07/24] comparison_match_weights_chart using dataclasses directly --- splink/internals/charts.py | 23 ++++++++++++++++++----- splink/internals/comparison.py | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index d48b1981de..6d66cf4f34 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -3,8 +3,10 @@ import json import math import os -from typing import TYPE_CHECKING, Any, Dict, Union +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Dict, Protocol, Union +from splink.internals.comparison_level import ComparisonLevelDetailedRecord from splink.internals.misc import read_resource from splink.internals.waterfall_chart import records_to_waterfall_data @@ -41,6 +43,14 @@ def altair_or_json( return chart_dict +class AsDictable(Protocol): + def as_dict(self) -> dict[str, Any]: ... + + +def list_items_as_dicts(lst: Iterable[AsDictable]) -> list[dict[str, Any]]: + return list(map(lambda item: item.as_dict(), lst)) + + iframe_message = """ To view in Jupyter you can use the following command: @@ -110,18 +120,21 @@ def match_weights_chart(records, as_dict=False): return altair_or_json(chart, as_dict=as_dict) -def comparison_match_weights_chart(records, as_dict=False): +def comparison_match_weights_chart( + records: list[ComparisonLevelDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "match_weights_interactive_history.json" chart = load_chart_definition(chart_path) # Remove iteration history since this is a static chart - del chart["vconcat"][0] + # TODO: some render issue if we remove empty top panel, so leave for now + # del chart["vconcat"][0] del chart["params"] del chart["transform"] chart["title"]["text"] = "Comparison summary" - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records + records = [r for r in records if r.comparison_vector_value != -1] + chart["data"]["values"] = list_items_as_dicts(records) return altair_or_json(chart, as_dict=as_dict) diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index b8988a9934..a3e9857499 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -467,5 +467,5 @@ def match_weights_chart(self, as_dict=False): """Display a chart of comparison levels of the comparison""" from splink.internals.charts import comparison_match_weights_chart - records = list(map(lambda rec: rec.as_dict(), self._as_detailed_records)) + records = self._as_detailed_records return comparison_match_weights_chart(records, as_dict=as_dict) From 9ca3d9f0e091b8647363ba194822d80edd07bccc Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:17:06 +0000 Subject: [PATCH 08/24] test comparison_match_weight_charts --- tests/test_charts.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_charts.py b/tests/test_charts.py index 455b6ac837..b85cb33114 100644 --- a/tests/test_charts.py +++ b/tests/test_charts.py @@ -142,6 +142,24 @@ def test_m_u_charts(dialect, test_helpers): linker.visualisations.match_weights_chart() +@mark_with_dialects_excluding() +def test_comparison_match_weight_charts(dialect, test_helpers): + settings = { + "link_type": "dedupe_only", + "comparisons": [ + cl.ExactMatch("gender"), + cl.ExactMatch("tm_partial"), + cl.LevenshteinAtThresholds("surname", [1]), + ], + } + helper = test_helpers[dialect] + + linker = helper.linker_with_registration(df, settings) + + comp = linker._settings_obj.comparisons[0] + comp.match_weights_chart() + + @mark_with_dialects_excluding() def test_parameter_estimate_charts(dialect, test_helpers): settings = { From 29ed4bfebf8beab97057851046e9e140e94fa55c Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:32:50 +0000 Subject: [PATCH 09/24] match weight charts using dataclass --- splink/internals/charts.py | 11 +++++++---- splink/internals/settings.py | 6 ++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 6d66cf4f34..99a5c14b2f 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -8,6 +8,7 @@ from splink.internals.comparison_level import ComparisonLevelDetailedRecord from splink.internals.misc import read_resource +from splink.internals.settings import ModelParameterDetailedRecord from splink.internals.waterfall_chart import records_to_waterfall_data if TYPE_CHECKING: @@ -96,7 +97,9 @@ def save_offline_chart( print(iframe_message.format(filename=filename)) # noqa: T201 -def match_weights_chart(records, as_dict=False): +def match_weights_chart( + records: list[ModelParameterDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "match_weights_interactive_history.json" chart = load_chart_definition(chart_path) @@ -104,15 +107,15 @@ def match_weights_chart(records, as_dict=False): del chart["params"] del chart["transform"] - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records + records = [r for r in records if r.comparison_vector_value != -1] bayes_factors = [ abs(l2bf) for r in records - if (l2bf := r["log2_bayes_factor"]) is not None and not math.isinf(l2bf) + if (l2bf := r.log2_bayes_factor) is not None and not math.isinf(l2bf) ] max_value = math.ceil(max(bayes_factors)) + chart["data"]["values"] = list_items_as_dicts(records) chart["vconcat"][0]["encoding"]["x"]["scale"]["domain"] = [-max_value, max_value] chart["vconcat"][1]["encoding"]["x"]["scale"]["domain"] = [-max_value, max_value] diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 40c6029796..d690d07c21 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -595,12 +595,10 @@ def _as_completed_dict(self): } def match_weights_chart(self, as_dict=False): - records = list( - map(lambda rec: rec.as_dict(), self._parameters_as_detailed_records) + return match_weights_chart( + self._parameters_as_detailed_records, as_dict=as_dict ) - return match_weights_chart(records, as_dict=as_dict) - def m_u_parameters_chart(self, as_dict=False): records = list( map(lambda rec: rec.as_dict(), self._parameters_as_detailed_records) From 3648eb30828ad7294602233345d6514beb0b8156 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:55:25 +0000 Subject: [PATCH 10/24] mu charts using dataclass --- splink/internals/charts.py | 10 ++++++---- splink/internals/settings.py | 5 ++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 99a5c14b2f..a1cfc7c544 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -141,7 +141,9 @@ def comparison_match_weights_chart( return altair_or_json(chart, as_dict=as_dict) -def m_u_parameters_chart(records, as_dict=False): +def m_u_parameters_chart( + records: list[ModelParameterDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "m_u_parameters_interactive_history.json" chart = load_chart_definition(chart_path) @@ -152,10 +154,10 @@ def m_u_parameters_chart(records, as_dict=False): records = [ r for r in records - if r["comparison_vector_value"] != -1 - and r["comparison_name"] != "probability_two_random_records_match" + if r.comparison_vector_value != -1 + and r.comparison_name != "probability_two_random_records_match" ] - chart["data"]["values"] = records + chart["data"]["values"] = list_items_as_dicts(records) return altair_or_json(chart, as_dict=as_dict) diff --git a/splink/internals/settings.py b/splink/internals/settings.py index d690d07c21..fce55ce66a 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -600,10 +600,9 @@ def match_weights_chart(self, as_dict=False): ) def m_u_parameters_chart(self, as_dict=False): - records = list( - map(lambda rec: rec.as_dict(), self._parameters_as_detailed_records) + return m_u_parameters_chart( + self._parameters_as_detailed_records, as_dict=as_dict ) - return m_u_parameters_chart(records, as_dict=as_dict) def _columns_without_estimated_parameters_message(self): message_lines = [] From 66f0c2019e389f8876aadbad7bbc5c3866f540b6 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:18:42 +0000 Subject: [PATCH 11/24] import only for type checking --- splink/internals/charts.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index a1cfc7c544..55c031fc89 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -4,19 +4,20 @@ import math import os from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Dict, Protocol, Union +from typing import TYPE_CHECKING, Any, Protocol, Union -from splink.internals.comparison_level import ComparisonLevelDetailedRecord from splink.internals.misc import read_resource -from splink.internals.settings import ModelParameterDetailedRecord from splink.internals.waterfall_chart import records_to_waterfall_data if TYPE_CHECKING: from altair import SchemaBase + + from splink.internals.comparison_level import ComparisonLevelDetailedRecord + from splink.internals.settings import ModelParameterDetailedRecord else: SchemaBase = None # type alias: -ChartReturnType = Union[Dict[Any, Any], SchemaBase] +ChartReturnType = Union[dict[Any, Any], SchemaBase] def load_chart_definition(filename): From a092b4d003e35e48acd32386d0073dec91d29441 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:56:10 +0000 Subject: [PATCH 12/24] avoid kw only not supported in 3.9, and as we only have 1 default arg, that's only instantiated in 1 place, it is not a big sacrifice to just set this value at the call site --- splink/internals/comparison_level.py | 5 +++-- splink/internals/em_training_session.py | 2 +- splink/internals/settings.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/splink/internals/comparison_level.py b/splink/internals/comparison_level.py index 85dcb0208c..0f4de99597 100644 --- a/splink/internals/comparison_level.py +++ b/splink/internals/comparison_level.py @@ -117,7 +117,7 @@ def _default_u_values(num_levels: int) -> list[float]: return u_vals -@dataclass(kw_only=True) +@dataclass class ComparisonLevelDetailedRecord: sql_condition: str | None label_for_charts: str @@ -139,7 +139,7 @@ class ComparisonLevelDetailedRecord: comparison_vector_value: int max_comparison_vector_value: int - comparison_name: str | None = None + comparison_name: str | None def as_dict(self): return asdict(self) @@ -786,6 +786,7 @@ def _as_detailed_record( bayes_factor_description=self._bayes_factor_description, comparison_vector_value=self.comparison_vector_value, max_comparison_vector_value=comparison_num_levels - 1, + comparison_name=None, ) def _parameter_estimates_as_records( diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index cbd25c9c12..1681972db1 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -41,7 +41,7 @@ from splink.internals.splink_dataframe import SplinkDataFrame -@dataclass(kw_only=True) +@dataclass class ModelParameterIterationDetailedRecord(ModelParameterDetailedRecord): iteration: int diff --git a/splink/internals/settings.py b/splink/internals/settings.py index fce55ce66a..5a807d84d7 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -30,7 +30,7 @@ class ComparisonAndLevelDict(TypedDict): comparison: Comparison -@dataclass(kw_only=True) +@dataclass class ModelParameterDetailedRecord(ComparisonLevelDetailedRecord): probability_two_random_records_match: float comparison_sort_order: int From bafa2443498453d42ee72ede77cc13bb47d5b5b8 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:16:12 +0000 Subject: [PATCH 13/24] iteration charts with dataclasses --- splink/internals/charts.py | 27 ++++++++++++++++--------- splink/internals/em_training_session.py | 7 +++---- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 55c031fc89..e3bcd392ae 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -13,6 +13,9 @@ from altair import SchemaBase from splink.internals.comparison_level import ComparisonLevelDetailedRecord + from splink.internals.em_training_session import ( + ModelParameterIterationDetailedRecord, + ) from splink.internals.settings import ModelParameterDetailedRecord else: SchemaBase = None @@ -170,39 +173,45 @@ def probability_two_random_records_match_iteration_chart(records, as_dict=False) return altair_or_json(chart, as_dict=as_dict) -def match_weights_interactive_history_chart(records, as_dict=False, blocking_rule=None): +def match_weights_interactive_history_chart( + records: list[ModelParameterIterationDetailedRecord], + as_dict: bool = False, + blocking_rule: str | None = None, +) -> ChartReturnType: chart_path = "match_weights_interactive_history.json" chart = load_chart_definition(chart_path) chart["title"]["subtitle"] = f"Training session blocked on {blocking_rule}" - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records + records = [r for r in records if r.comparison_vector_value != -1] max_iteration = 0 for r in records: - max_iteration = max(r["iteration"], max_iteration) + max_iteration = max(r.iteration, max_iteration) + chart["data"]["values"] = list_items_as_dicts(records) chart["params"][0]["bind"]["max"] = max_iteration chart["params"][0]["value"] = max_iteration return altair_or_json(chart, as_dict=as_dict) -def m_u_parameters_interactive_history_chart(records, as_dict=False): +def m_u_parameters_interactive_history_chart( + records: list[ModelParameterIterationDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "m_u_parameters_interactive_history.json" chart = load_chart_definition(chart_path) records = [ r for r in records - if r["comparison_vector_value"] != -1 - and r["comparison_name"] != "probability_two_random_records_match" + if r.comparison_vector_value != -1 + and r.comparison_name != "probability_two_random_records_match" ] - chart["data"]["values"] = records max_iteration = 0 for r in records: - max_iteration = max(r["iteration"], max_iteration) + max_iteration = max(r.iteration, max_iteration) + chart["data"]["values"] = list_items_as_dicts(records) chart["params"][0]["bind"]["max"] = max_iteration chart["params"][0]["value"] = max_iteration return altair_or_json(chart, as_dict=as_dict) diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index 1681972db1..f50af9153b 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -393,9 +393,9 @@ def match_weights_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = list(map(lambda rec: rec.as_dict(), self._iteration_history_records)) return match_weights_interactive_history_chart( - records, blocking_rule=self._blocking_rule_for_training.blocking_rule_sql + self._iteration_history_records, + blocking_rule=self._blocking_rule_for_training.blocking_rule_sql, ) def m_u_values_interactive_history_chart(self) -> ChartReturnType: @@ -405,8 +405,7 @@ def m_u_values_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = list(map(lambda rec: rec.as_dict(), self._iteration_history_records)) - return m_u_parameters_interactive_history_chart(records) + return m_u_parameters_interactive_history_chart(self._iteration_history_records) def __repr__(self): deactivated_cols = ", ".join( From 687e6a9d3e9e3616bb4274df234395e2cf1ea696 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 14:36:20 +0000 Subject: [PATCH 14/24] SplinkChart + MatchWeightsChart subclass --- splink/internals/charts.py | 112 ++++++++++++++++++++++++++++------- splink/internals/settings.py | 9 +-- 2 files changed, 95 insertions(+), 26 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index e3bcd392ae..d3552f7467 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -3,8 +3,9 @@ import json import math import os +from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Protocol, Union +from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, Union from splink.internals.misc import read_resource from splink.internals.waterfall_chart import records_to_waterfall_data @@ -19,6 +20,8 @@ from splink.internals.settings import ModelParameterDetailedRecord else: SchemaBase = None + + ComparisonLevelDetailedRecord = None # type alias: ChartReturnType = Union[dict[Any, Any], SchemaBase] @@ -52,8 +55,12 @@ class AsDictable(Protocol): def as_dict(self) -> dict[str, Any]: ... -def list_items_as_dicts(lst: Iterable[AsDictable]) -> list[dict[str, Any]]: - return list(map(lambda item: item.as_dict(), lst)) +def list_items_as_dicts( + lst: Iterable[AsDictable | dict[str, Any]], +) -> list[dict[str, Any]]: + return list( + map(lambda item: item if isinstance(item, dict) else item.as_dict(), lst) + ) iframe_message = """ @@ -101,30 +108,91 @@ def save_offline_chart( print(iframe_message.format(filename=filename)) # noqa: T201 -def match_weights_chart( - records: list[ModelParameterDetailedRecord], as_dict: bool = False -) -> ChartReturnType: - chart_path = "match_weights_interactive_history.json" - chart = load_chart_definition(chart_path) +class ChartRecord(Protocol): ... - # Remove iteration history since this is a static chart - del chart["params"] - del chart["transform"] - records = [r for r in records if r.comparison_vector_value != -1] +T = TypeVar("T", bound=ChartRecord) - bayes_factors = [ - abs(l2bf) - for r in records - if (l2bf := r.log2_bayes_factor) is not None and not math.isinf(l2bf) - ] - max_value = math.ceil(max(bayes_factors)) - chart["data"]["values"] = list_items_as_dicts(records) - chart["vconcat"][0]["encoding"]["x"]["scale"]["domain"] = [-max_value, max_value] - chart["vconcat"][1]["encoding"]["x"]["scale"]["domain"] = [-max_value, max_value] +class SplinkChart(ABC, Generic[T]): + def __init__(self, records: list[Any], as_dict: bool = False): + # TODO: as_dict only in methods rather than on object + self.raw_records = records + self.as_dict = as_dict - return altair_or_json(chart, as_dict=as_dict) + @property + @abstractmethod + def chart_spec_file(self) -> str: + pass + + @property + def chart_data(self): + # TODO: cache + return self.alter_data(self.raw_records) + + @property + def chart_spec(self) -> dict[str, Any]: + chart_spec = load_chart_definition(self.chart_spec_file) + chart_spec = self.alter_spec_directly(chart_spec) + chart_spec = self.alter_spec_from_data(chart_spec) + return chart_spec + + @property + def chart_dict(self) -> dict[str, Any]: + chart = self.chart_spec + chart["data"]["values"] = list_items_as_dicts(self.chart_data) + return chart + + @property + def chart(self) -> ChartReturnType: + # TODO: split into separate methods + # also save etc + return altair_or_json(self.chart_dict, as_dict=self.as_dict) + + @staticmethod + def alter_data(records): + return records + + @staticmethod + def alter_spec_directly(chart_spec): + return chart_spec + + def alter_spec_from_data(self, chart_spec): + return chart_spec + + +class MatchWeightsChart(SplinkChart[ComparisonLevelDetailedRecord]): + @property + def chart_spec_file(self) -> str: + return "match_weights_interactive_history.json" + + @staticmethod + def alter_data(records): + return [r for r in records if r.comparison_vector_value != -1] + + @staticmethod + def alter_spec_directly(chart_spec): + # Remove iteration history since this is a static chart + del chart_spec["params"] + del chart_spec["transform"] + return chart_spec + + def alter_spec_from_data(self, chart_spec): + bayes_factors = [ + abs(l2bf) + for r in self.chart_data + if (l2bf := r.log2_bayes_factor) is not None and not math.isinf(l2bf) + ] + max_value = math.ceil(max(bayes_factors)) + chart_spec["vconcat"][0]["encoding"]["x"]["scale"]["domain"] = [ + -max_value, + max_value, + ] + chart_spec["vconcat"][1]["encoding"]["x"]["scale"]["domain"] = [ + -max_value, + max_value, + ] + return chart_spec def comparison_match_weights_chart( diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 5a807d84d7..11232a092d 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -6,7 +6,7 @@ from typing import Any, List, Literal, TypedDict from splink.internals.blocking import BlockingRule -from splink.internals.charts import m_u_parameters_chart, match_weights_chart +from splink.internals.charts import MatchWeightsChart, m_u_parameters_chart from splink.internals.comparison import Comparison from splink.internals.comparison_level import ( ComparisonLevel, @@ -595,9 +595,10 @@ def _as_completed_dict(self): } def match_weights_chart(self, as_dict=False): - return match_weights_chart( - self._parameters_as_detailed_records, as_dict=as_dict - ) + return MatchWeightsChart( + self._parameters_as_detailed_records, + as_dict=as_dict, + ).chart def m_u_parameters_chart(self, as_dict=False): return m_u_parameters_chart( From 75ab2b37b28f721c5b340cf52bd18631d9116dbd Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 15:04:25 +0000 Subject: [PATCH 15/24] comparison match weights chart to SplinkChart --- splink/internals/charts.py | 30 ++++++++++++++++-------------- splink/internals/comparison.py | 4 ++-- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index d3552f7467..07e7607ec9 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -195,22 +195,24 @@ def alter_spec_from_data(self, chart_spec): return chart_spec -def comparison_match_weights_chart( - records: list[ComparisonLevelDetailedRecord], as_dict: bool = False -) -> ChartReturnType: - chart_path = "match_weights_interactive_history.json" - chart = load_chart_definition(chart_path) +class ComparisonMatchWeightsChart(SplinkChart[ComparisonLevelDetailedRecord]): + @property + def chart_spec_file(self) -> str: + return "match_weights_interactive_history.json" - # Remove iteration history since this is a static chart - # TODO: some render issue if we remove empty top panel, so leave for now - # del chart["vconcat"][0] - del chart["params"] - del chart["transform"] + @staticmethod + def alter_data(records): + return [r for r in records if r.comparison_vector_value != -1] - chart["title"]["text"] = "Comparison summary" - records = [r for r in records if r.comparison_vector_value != -1] - chart["data"]["values"] = list_items_as_dicts(records) - return altair_or_json(chart, as_dict=as_dict) + @staticmethod + def alter_spec_directly(chart_spec): + # Remove iteration history since this is a static chart + # TODO: some render issue if we remove empty top panel, so leave for now + # del chart["vconcat"][0] + del chart_spec["params"] + del chart_spec["transform"] + chart_spec["title"]["text"] = "Comparison summary" + return chart_spec def m_u_parameters_chart( diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index a3e9857499..00eef7f3ea 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -465,7 +465,7 @@ def human_readable_description(self): def match_weights_chart(self, as_dict=False): """Display a chart of comparison levels of the comparison""" - from splink.internals.charts import comparison_match_weights_chart + from splink.internals.charts import ComparisonMatchWeightsChart records = self._as_detailed_records - return comparison_match_weights_chart(records, as_dict=as_dict) + return ComparisonMatchWeightsChart(records, as_dict=as_dict).chart From 58536602520baf09d92c51e694cc77cfd2a119a9 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 15:47:26 +0000 Subject: [PATCH 16/24] some simpler charts to SplinkChart --- splink/internals/blocking_analysis.py | 6 +- splink/internals/charts.py | 83 +++++++++---------- splink/internals/completeness.py | 6 +- splink/internals/em_training_session.py | 4 +- .../linker_components/visualisations.py | 6 +- splink/internals/settings.py | 6 +- 6 files changed, 54 insertions(+), 57 deletions(-) diff --git a/splink/internals/blocking_analysis.py b/splink/internals/blocking_analysis.py index 4511faf95a..565eefc063 100644 --- a/splink/internals/blocking_analysis.py +++ b/splink/internals/blocking_analysis.py @@ -30,7 +30,7 @@ from splink.internals.blocking_rule_creator_utils import to_blocking_rule_creator from splink.internals.charts import ( ChartReturnType, - cumulative_blocking_rule_comparisons_generated, + CumulativeBlockingRuleComparisonsGeneratedChart, ) from splink.internals.database_api import DatabaseAPISubClass from splink.internals.duckdb.duckdb_helpers import record_dicts_from_relation @@ -721,7 +721,9 @@ def cumulative_comparisons_to_be_scored_from_blocking_rules_chart( ) ) - return cumulative_blocking_rule_comparisons_generated(cumulative_comparison_records) + return CumulativeBlockingRuleComparisonsGeneratedChart( + cumulative_comparison_records + ).chart def n_largest_blocks( diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 07e7607ec9..ece3fa035f 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -22,6 +22,8 @@ SchemaBase = None ComparisonLevelDetailedRecord = None + ModelParameterDetailedRecord = None + ModelParameterIterationDetailedRecord = None # type alias: ChartReturnType = Union[dict[Any, Any], SchemaBase] @@ -108,9 +110,13 @@ def save_offline_chart( print(iframe_message.format(filename=filename)) # noqa: T201 +# TODO: we can have more detailed subclasses to hint the fields needed per chart class ChartRecord(Protocol): ... +class PlaceholderRecord: ... # TODO: placeholder until we type remaining charts + + T = TypeVar("T", bound=ChartRecord) @@ -215,32 +221,32 @@ def alter_spec_directly(chart_spec): return chart_spec -def m_u_parameters_chart( - records: list[ModelParameterDetailedRecord], as_dict: bool = False -) -> ChartReturnType: - chart_path = "m_u_parameters_interactive_history.json" - chart = load_chart_definition(chart_path) - - # Remove iteration history since this is a static chart - del chart["params"] - del chart["transform"] +class MUParametersChart(SplinkChart[ModelParameterDetailedRecord]): + @property + def chart_spec_file(self) -> str: + return "m_u_parameters_interactive_history.json" - records = [ - r - for r in records - if r.comparison_vector_value != -1 - and r.comparison_name != "probability_two_random_records_match" - ] - chart["data"]["values"] = list_items_as_dicts(records) - return altair_or_json(chart, as_dict=as_dict) + @staticmethod + def alter_data(records): + return [ + r + for r in records + if r.comparison_vector_value != -1 + and r.comparison_name != "probability_two_random_records_match" + ] + @staticmethod + def alter_spec_directly(chart_spec): + # Remove iteration history since this is a static chart + del chart_spec["params"] + del chart_spec["transform"] + return chart_spec -def probability_two_random_records_match_iteration_chart(records, as_dict=False): - chart_path = "probability_two_random_records_match_iteration.json" - chart = load_chart_definition(chart_path) - chart["data"]["values"] = records - return altair_or_json(chart, as_dict=as_dict) +class ProbabilityTwoRandomRecordsMatchIterationChart(SplinkChart[PlaceholderRecord]): + @property + def chart_spec_file(self) -> str: + return "probability_two_random_records_match_iteration.json" def match_weights_interactive_history_chart( @@ -438,13 +444,10 @@ def match_weights_histogram(records, width=500, height=250, as_dict=False): return altair_or_json(chart, as_dict=as_dict) -def parameter_estimate_comparisons(records, as_dict=False): - chart_path = "parameter_estimate_comparisons.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - return altair_or_json(chart, as_dict=as_dict) +class ParameterEstimateComparisonsChart(SplinkChart[PlaceholderRecord]): + @property + def chart_spec_file(self) -> str: + return "parameter_estimate_comparisons.json" def unlinkables_chart( @@ -496,22 +499,16 @@ def unlinkables_chart( return altair_or_json(unlinkables_chart_def, as_dict=as_dict) -def completeness_chart(records, as_dict=False): - chart_path = "completeness.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - return altair_or_json(chart, as_dict=as_dict) - - -def cumulative_blocking_rule_comparisons_generated(records, as_dict=False): - chart_path = "blocking_rule_generated_comparisons.json" - chart = load_chart_definition(chart_path) +class CompletenessChart(SplinkChart[PlaceholderRecord]): + @property + def chart_spec_file(self) -> str: + return "completeness.json" - chart["data"]["values"] = records - return altair_or_json(chart, as_dict=as_dict) +class CumulativeBlockingRuleComparisonsGeneratedChart(SplinkChart[PlaceholderRecord]): + @property + def chart_spec_file(self) -> str: + return "blocking_rule_generated_comparisons.json" def _comparator_score_chart(similarity_records, distance_records, as_dict=False): diff --git a/splink/internals/completeness.py b/splink/internals/completeness.py index 111d9999b6..c7f92b3bad 100644 --- a/splink/internals/completeness.py +++ b/splink/internals/completeness.py @@ -4,9 +4,7 @@ from splink.internals.charts import ( ChartReturnType, -) -from splink.internals.charts import ( - completeness_chart as records_to_completeness_chart, + CompletenessChart, ) from splink.internals.database_api import DatabaseAPISubClass from splink.internals.input_column import InputColumn @@ -131,4 +129,4 @@ def completeness_chart( db_api = get_db_api_from_inputs(splink_dataframe_or_dataframes) splink_df_dict = splink_dataframes_to_dict(splink_dataframe_or_dataframes) records = completeness_data(splink_df_dict, db_api, cols, table_names_for_chart) - return records_to_completeness_chart(records) + return CompletenessChart(records).chart diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index f50af9153b..94840fd2b4 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -7,9 +7,9 @@ from splink.internals.blocking import BlockingRule, block_using_rules_sqls from splink.internals.charts import ( ChartReturnType, + ProbabilityTwoRandomRecordsMatchIterationChart, m_u_parameters_interactive_history_chart, match_weights_interactive_history_chart, - probability_two_random_records_match_iteration_chart, ) from splink.internals.comparison import Comparison from splink.internals.comparison_vector_values import ( @@ -384,7 +384,7 @@ def probability_two_random_records_match_iteration_chart(self) -> ChartReturnTyp An interactive Altair chart. """ records = self._lambda_history_records - return probability_two_random_records_match_iteration_chart(records) + return ProbabilityTwoRandomRecordsMatchIterationChart(records).chart def match_weights_interactive_history_chart(self) -> ChartReturnType: """ diff --git a/splink/internals/linker_components/visualisations.py b/splink/internals/linker_components/visualisations.py index efc7d08c89..b550cd30c4 100644 --- a/splink/internals/linker_components/visualisations.py +++ b/splink/internals/linker_components/visualisations.py @@ -4,8 +4,8 @@ from splink.internals.charts import ( ChartReturnType, + ParameterEstimateComparisonsChart, match_weights_histogram, - parameter_estimate_comparisons, waterfall_chart, ) from splink.internals.cluster_studio import ( @@ -187,8 +187,8 @@ def parameter_estimate_comparisons_chart( to_retain.append("u") records = [r for r in records if r["m_or_u"] in to_retain] - - return parameter_estimate_comparisons(records, as_dict) + # TODO: this logic into chart object + return ParameterEstimateComparisonsChart(records, as_dict).chart def tf_adjustment_chart( self, diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 11232a092d..244c66f5a6 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -6,7 +6,7 @@ from typing import Any, List, Literal, TypedDict from splink.internals.blocking import BlockingRule -from splink.internals.charts import MatchWeightsChart, m_u_parameters_chart +from splink.internals.charts import MatchWeightsChart, MUParametersChart from splink.internals.comparison import Comparison from splink.internals.comparison_level import ( ComparisonLevel, @@ -601,9 +601,9 @@ def match_weights_chart(self, as_dict=False): ).chart def m_u_parameters_chart(self, as_dict=False): - return m_u_parameters_chart( + return MUParametersChart( self._parameters_as_detailed_records, as_dict=as_dict - ) + ).chart def _columns_without_estimated_parameters_message(self): message_lines = [] From b6bac817186e81f6408e743f9929027cac05b617 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 16:23:16 +0000 Subject: [PATCH 17/24] mw histogram chart, and settable height/width --- splink/internals/charts.py | 43 +++++++++++++++---- .../linker_components/visualisations.py | 8 ++-- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index ece3fa035f..8b1a61d0b5 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -125,6 +125,8 @@ def __init__(self, records: list[Any], as_dict: bool = False): # TODO: as_dict only in methods rather than on object self.raw_records = records self.as_dict = as_dict + self.width: float | None = self.default_width + self.height: float | None = self.default_height @property @abstractmethod @@ -136,10 +138,19 @@ def chart_data(self): # TODO: cache return self.alter_data(self.raw_records) + @property + def default_width(self) -> float | None: + return None + + @property + def default_height(self) -> float | None: + return None + @property def chart_spec(self) -> dict[str, Any]: chart_spec = load_chart_definition(self.chart_spec_file) chart_spec = self.alter_spec_directly(chart_spec) + chart_spec = self.alter_spec_height_width(chart_spec) chart_spec = self.alter_spec_from_data(chart_spec) return chart_spec @@ -166,6 +177,20 @@ def alter_spec_directly(chart_spec): def alter_spec_from_data(self, chart_spec): return chart_spec + def alter_spec_height_width(self, chart_spec): + if width := self.width: + chart_spec["width"] = width + if height := self.height: + chart_spec["height"] = height + return chart_spec + + def set_width_height( + self, *, width: float | None = None, height: float | None = None + ) -> None: + self.width = width + self.height = height + # TODO: return self? + class MatchWeightsChart(SplinkChart[ComparisonLevelDetailedRecord]): @property @@ -432,16 +457,18 @@ def threshold_selection_tool(records, as_dict=False, add_metrics=[]): return altair_or_json(chart, as_dict=as_dict) -def match_weights_histogram(records, width=500, height=250, as_dict=False): - chart_path = "match_weight_histogram.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records +class MatchWeightsHistogramChart(SplinkChart[PlaceholderRecord]): + @property + def chart_spec_file(self) -> str: + return "match_weight_histogram.json" - chart["height"] = height - chart["width"] = width + @property + def default_width(self) -> float: + return 500 - return altair_or_json(chart, as_dict=as_dict) + @property + def default_height(self) -> float: + return 250 class ParameterEstimateComparisonsChart(SplinkChart[PlaceholderRecord]): diff --git a/splink/internals/linker_components/visualisations.py b/splink/internals/linker_components/visualisations.py index b550cd30c4..1d0456fcaa 100644 --- a/splink/internals/linker_components/visualisations.py +++ b/splink/internals/linker_components/visualisations.py @@ -4,8 +4,8 @@ from splink.internals.charts import ( ChartReturnType, + MatchWeightsHistogramChart, ParameterEstimateComparisonsChart, - match_weights_histogram, waterfall_chart, ) from splink.internals.cluster_studio import ( @@ -139,9 +139,9 @@ def match_weights_histogram( """ df = histogram_data(self._linker, df_predict, target_bins) recs = df.as_record_dict() - return match_weights_histogram( - recs, width=width, height=height, as_dict=as_dict - ) + chart = MatchWeightsHistogramChart(recs, as_dict=as_dict) + chart.set_width_height(width=width, height=height) + return chart.chart def parameter_estimate_comparisons_chart( self, include_m: bool = True, include_u: bool = False, as_dict: bool = False From 8b2b1aa57d6dc72088f70eeb8125e2b025d27364 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 16:46:27 +0000 Subject: [PATCH 18/24] mw interactive history as SplinkChart --- splink/internals/charts.py | 46 ++++++++++++++++--------- splink/internals/em_training_session.py | 8 ++--- 2 files changed, 33 insertions(+), 21 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 8b1a61d0b5..29ffbbc211 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -274,26 +274,38 @@ def chart_spec_file(self) -> str: return "probability_two_random_records_match_iteration.json" -def match_weights_interactive_history_chart( - records: list[ModelParameterIterationDetailedRecord], - as_dict: bool = False, - blocking_rule: str | None = None, -) -> ChartReturnType: - chart_path = "match_weights_interactive_history.json" - chart = load_chart_definition(chart_path) - - chart["title"]["subtitle"] = f"Training session blocked on {blocking_rule}" +class MatchWeightsInteractiveHistoryChart( + SplinkChart[ModelParameterIterationDetailedRecord] +): + def __init__( + self, + records: list[ModelParameterIterationDetailedRecord], + blocking_rule_text: str, + as_dict: bool = False, + ): + super().__init__(records, as_dict=as_dict) + self.blocking_rule_text = blocking_rule_text - records = [r for r in records if r.comparison_vector_value != -1] + @property + def chart_spec_file(self) -> str: + return "match_weights_interactive_history.json" - max_iteration = 0 - for r in records: - max_iteration = max(r.iteration, max_iteration) - chart["data"]["values"] = list_items_as_dicts(records) + @staticmethod + def alter_data(records): + return [r for r in records if r.comparison_vector_value != -1] - chart["params"][0]["bind"]["max"] = max_iteration - chart["params"][0]["value"] = max_iteration - return altair_or_json(chart, as_dict=as_dict) + def alter_spec_from_data(self, chart_spec): + records = self.chart_data + max_iteration = 0 + for r in records: + max_iteration = max(r.iteration, max_iteration) + + chart_spec["params"][0]["bind"]["max"] = max_iteration + chart_spec["params"][0]["value"] = max_iteration + chart_spec["title"]["subtitle"] = ( + f"Training session blocked on {self.blocking_rule_text}" + ) + return chart_spec def m_u_parameters_interactive_history_chart( diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index 94840fd2b4..21b9d143fc 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -7,9 +7,9 @@ from splink.internals.blocking import BlockingRule, block_using_rules_sqls from splink.internals.charts import ( ChartReturnType, + MatchWeightsInteractiveHistoryChart, ProbabilityTwoRandomRecordsMatchIterationChart, m_u_parameters_interactive_history_chart, - match_weights_interactive_history_chart, ) from splink.internals.comparison import Comparison from splink.internals.comparison_vector_values import ( @@ -393,10 +393,10 @@ def match_weights_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - return match_weights_interactive_history_chart( + return MatchWeightsInteractiveHistoryChart( self._iteration_history_records, - blocking_rule=self._blocking_rule_for_training.blocking_rule_sql, - ) + blocking_rule_text=self._blocking_rule_for_training.blocking_rule_sql, + ).chart def m_u_values_interactive_history_chart(self) -> ChartReturnType: """ From 59556b278a89d4240c06a44e2dfc590b368defae Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 17:03:41 +0000 Subject: [PATCH 19/24] better chart typing --- splink/internals/charts.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 29ffbbc211..3a090c1cb5 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -5,7 +5,7 @@ import os from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, Protocol, Sequence, TypeVar, Union from splink.internals.misc import read_resource from splink.internals.waterfall_chart import records_to_waterfall_data @@ -114,14 +114,11 @@ def save_offline_chart( class ChartRecord(Protocol): ... -class PlaceholderRecord: ... # TODO: placeholder until we type remaining charts - - T = TypeVar("T", bound=ChartRecord) class SplinkChart(ABC, Generic[T]): - def __init__(self, records: list[Any], as_dict: bool = False): + def __init__(self, records: Sequence[T], as_dict: bool = False): # TODO: as_dict only in methods rather than on object self.raw_records = records self.as_dict = as_dict @@ -268,7 +265,7 @@ def alter_spec_directly(chart_spec): return chart_spec -class ProbabilityTwoRandomRecordsMatchIterationChart(SplinkChart[PlaceholderRecord]): +class ProbabilityTwoRandomRecordsMatchIterationChart(SplinkChart[ChartRecord]): @property def chart_spec_file(self) -> str: return "probability_two_random_records_match_iteration.json" @@ -279,7 +276,7 @@ class MatchWeightsInteractiveHistoryChart( ): def __init__( self, - records: list[ModelParameterIterationDetailedRecord], + records: Sequence[ModelParameterIterationDetailedRecord], blocking_rule_text: str, as_dict: bool = False, ): @@ -469,7 +466,7 @@ def threshold_selection_tool(records, as_dict=False, add_metrics=[]): return altair_or_json(chart, as_dict=as_dict) -class MatchWeightsHistogramChart(SplinkChart[PlaceholderRecord]): +class MatchWeightsHistogramChart(SplinkChart[ChartRecord]): @property def chart_spec_file(self) -> str: return "match_weight_histogram.json" @@ -483,7 +480,7 @@ def default_height(self) -> float: return 250 -class ParameterEstimateComparisonsChart(SplinkChart[PlaceholderRecord]): +class ParameterEstimateComparisonsChart(SplinkChart[ChartRecord]): @property def chart_spec_file(self) -> str: return "parameter_estimate_comparisons.json" @@ -538,13 +535,13 @@ def unlinkables_chart( return altair_or_json(unlinkables_chart_def, as_dict=as_dict) -class CompletenessChart(SplinkChart[PlaceholderRecord]): +class CompletenessChart(SplinkChart[ChartRecord]): @property def chart_spec_file(self) -> str: return "completeness.json" -class CumulativeBlockingRuleComparisonsGeneratedChart(SplinkChart[PlaceholderRecord]): +class CumulativeBlockingRuleComparisonsGeneratedChart(SplinkChart[ChartRecord]): @property def chart_spec_file(self) -> str: return "blocking_rule_generated_comparisons.json" From 3507ee8c707690dcb6bd91219dcae42644c42bab Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 17:19:41 +0000 Subject: [PATCH 20/24] more chart functions -> SplinkCharts --- splink/internals/charts.py | 107 ++++++++++-------- splink/internals/em_training_session.py | 6 +- .../internals/linker_components/evaluation.py | 12 +- 3 files changed, 71 insertions(+), 54 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 3a090c1cb5..16980d235d 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -305,26 +305,31 @@ def alter_spec_from_data(self, chart_spec): return chart_spec -def m_u_parameters_interactive_history_chart( - records: list[ModelParameterIterationDetailedRecord], as_dict: bool = False -) -> ChartReturnType: - chart_path = "m_u_parameters_interactive_history.json" - chart = load_chart_definition(chart_path) - records = [ - r - for r in records - if r.comparison_vector_value != -1 - and r.comparison_name != "probability_two_random_records_match" - ] - - max_iteration = 0 - for r in records: - max_iteration = max(r.iteration, max_iteration) - - chart["data"]["values"] = list_items_as_dicts(records) - chart["params"][0]["bind"]["max"] = max_iteration - chart["params"][0]["value"] = max_iteration - return altair_or_json(chart, as_dict=as_dict) +class MUParametersInteractiveHistoryChart( + SplinkChart[ModelParameterIterationDetailedRecord] +): + @property + def chart_spec_file(self) -> str: + return "m_u_parameters_interactive_history.json" + + @staticmethod + def alter_data(records): + return [ + r + for r in records + if r.comparison_vector_value != -1 + and r.comparison_name != "probability_two_random_records_match" + ] + + def alter_spec_from_data(self, chart_spec): + records = self.chart_data + max_iteration = 0 + for r in records: + max_iteration = max(r.iteration, max_iteration) + + chart_spec["params"][0]["bind"]["max"] = max_iteration + chart_spec["params"][0]["value"] = max_iteration + return chart_spec def waterfall_chart( @@ -345,40 +350,50 @@ def waterfall_chart( return altair_or_json(chart, as_dict=as_dict) -def roc_chart(records, width=400, height=400, as_dict=False): - chart_path = "roc.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - # If 'curve_label' not in records, remove colour coding - # This is for if you want to compare roc curves - r = records[0] - if "curve_label" not in r.keys(): - del chart["encoding"]["color"] +class ROCChart(SplinkChart[ChartRecord]): + @property + def chart_spec_file(self) -> str: + return "roc.json" - chart["height"] = height - chart["width"] = width + @property + def default_width(self) -> float: + return 400 - return altair_or_json(chart, as_dict=as_dict) + @property + def default_height(self) -> float: + return 400 + def alter_spec_from_data(self, chart_spec): + records = self.chart_data + # If 'curve_label' not in records, remove colour coding + # This is for if you want to compare roc curves + r = records[0] + if "curve_label" not in r.keys(): + del chart_spec["encoding"]["color"] + return chart_spec -def precision_recall_chart(records, width=400, height=400, as_dict=False): - chart_path = "precision_recall.json" - chart = load_chart_definition(chart_path) - chart["data"]["values"] = records +class PrecisionRecallChart(SplinkChart[ChartRecord]): + @property + def chart_spec_file(self) -> str: + return "precision_recall.json" - # If 'curve_label' not in records, remove colour coding - # This is for if you want to compare roc curves - r = records[0] - if "curve_label" not in r.keys(): - del chart["encoding"]["color"] + @property + def default_width(self) -> float: + return 400 - chart["height"] = height - chart["width"] = width + @property + def default_height(self) -> float: + return 400 - return altair_or_json(chart, as_dict=as_dict) + def alter_spec_from_data(self, chart_spec): + records = self.chart_data + # If 'curve_label' not in records, remove colour coding + # This is for if you want to compare roc curves + r = records[0] + if "curve_label" not in r.keys(): + del chart_spec["encoding"]["color"] + return chart_spec def accuracy_chart(records, width=400, height=400, as_dict=False, add_metrics=[]): diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index 21b9d143fc..44bbfdf4c4 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -8,8 +8,8 @@ from splink.internals.charts import ( ChartReturnType, MatchWeightsInteractiveHistoryChart, + MUParametersInteractiveHistoryChart, ProbabilityTwoRandomRecordsMatchIterationChart, - m_u_parameters_interactive_history_chart, ) from splink.internals.comparison import Comparison from splink.internals.comparison_vector_values import ( @@ -405,7 +405,9 @@ def m_u_values_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - return m_u_parameters_interactive_history_chart(self._iteration_history_records) + return MUParametersInteractiveHistoryChart( + self._iteration_history_records + ).chart def __repr__(self): deactivated_cols = ", ".join( diff --git a/splink/internals/linker_components/evaluation.py b/splink/internals/linker_components/evaluation.py index 080df0a49b..78d69731ac 100644 --- a/splink/internals/linker_components/evaluation.py +++ b/splink/internals/linker_components/evaluation.py @@ -10,9 +10,9 @@ ) from splink.internals.charts import ( ChartReturnType, + PrecisionRecallChart, + ROCChart, accuracy_chart, - precision_recall_chart, - roc_chart, threshold_selection_tool, unlinkables_chart, ) @@ -172,9 +172,9 @@ def accuracy_analysis_from_labels_column( elif output_type == "accuracy": return accuracy_chart(recs, add_metrics=add_metrics) elif output_type == "roc": - return roc_chart(recs) + return ROCChart(recs).chart elif output_type == "precision_recall": - return precision_recall_chart(recs) + return PrecisionRecallChart(recs).chart elif output_type == "table": return df_truth_space else: @@ -285,9 +285,9 @@ def accuracy_analysis_from_labels_table( elif output_type == "accuracy": return accuracy_chart(recs, add_metrics=add_metrics) elif output_type == "roc": - return roc_chart(recs) + return ROCChart(recs).chart elif output_type == "precision_recall": - return precision_recall_chart(recs) + return PrecisionRecallChart(recs).chart elif output_type == "table": return df_truth_space else: From fd416b3b00b3d31f6d45ee900c9fab2ce35f0c0e Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Thu, 19 Feb 2026 09:53:11 +0000 Subject: [PATCH 21/24] threshold selection & accuracy to SplinkChart --- splink/internals/charts.py | 190 +++++++++++------- .../internals/linker_components/evaluation.py | 12 +- 2 files changed, 122 insertions(+), 80 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 16980d235d..2eefd27995 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -184,6 +184,8 @@ def alter_spec_height_width(self, chart_spec): def set_width_height( self, *, width: float | None = None, height: float | None = None ) -> None: + # TODO: do we need to be careful with charts that can't have these set + # such as threshold_selection_tool self.width = width self.height = height # TODO: return self? @@ -396,89 +398,129 @@ def alter_spec_from_data(self, chart_spec): return chart_spec -def accuracy_chart(records, width=400, height=400, as_dict=False, add_metrics=[]): - chart_path = "accuracy_chart.json" - chart = load_chart_definition(chart_path) +class AccuracyChart(SplinkChart[ChartRecord]): + def __init__( + self, + records: Sequence[ChartRecord], + add_metrics: Sequence[str] = [], # TODO + as_dict: bool = False, + ): + super().__init__(records, as_dict=as_dict) + # User-specified metrics to include + self.additional_metrics = add_metrics - # User-specified metrics to include - metrics = ["precision", "recall", *add_metrics] - chart["transform"][0]["fold"] = metrics - chart["transform"][1]["calculate"] = chart["transform"][1]["calculate"].replace( - "__metrics__", str(metrics) - ) - chart["layer"][0]["encoding"]["color"]["sort"] = metrics - chart["layer"][1]["layer"][1]["encoding"]["color"]["sort"] = metrics - - # Metric-label mapping - mapping = { - "precision": "Precision (PPV)", - "recall": "Recall (TPR)", - "specificity": "Specificity (TNR)", - "accuracy": "Accuracy", - "npv": "NPV", - "f1": "F1", - "f2": "F2", - "f0_5": "F0.5", - "p4": "P4", - "phi": "\u03c6 (MCC)", - } - chart["transform"][2]["calculate"] = chart["transform"][2]["calculate"].replace( - "__mapping__", str(mapping) - ) - chart["layer"][0]["encoding"]["color"]["legend"]["labelExpr"] = chart["layer"][0][ - "encoding" - ]["color"]["legend"]["labelExpr"].replace("__mapping__", str(mapping)) + @property + def chart_spec_file(self) -> str: + return "accuracy_chart.json" - chart["data"]["values"] = records + @property + def default_width(self) -> float: + return 400 - chart["height"] = height - chart["width"] = width + @property + def default_height(self) -> float: + return 400 - return altair_or_json(chart, as_dict=as_dict) + @property + def metrics(self): + return ["precision", "recall", *self.additional_metrics] + @staticmethod + def alter_spec_directly(chart_spec): + mapping = { + "precision": "Precision (PPV)", + "recall": "Recall (TPR)", + "specificity": "Specificity (TNR)", + "accuracy": "Accuracy", + "npv": "NPV", + "f1": "F1", + "f2": "F2", + "f0_5": "F0.5", + "p4": "P4", + "phi": "\u03c6 (MCC)", + } + chart_spec["transform"][2]["calculate"] = chart_spec["transform"][2][ + "calculate" + ].replace("__mapping__", str(mapping)) + chart_spec["layer"][0]["encoding"]["color"]["legend"]["labelExpr"] = chart_spec[ + "layer" + ][0]["encoding"]["color"]["legend"]["labelExpr"].replace( + "__mapping__", str(mapping) + ) -def threshold_selection_tool(records, as_dict=False, add_metrics=[]): - chart_path = "threshold_selection_tool.json" - chart = load_chart_definition(chart_path) + return chart_spec - # Remove extremes with low precision and recall - records = [d for d in records if d["precision"] > 0.5 and d["recall"] > 0.5] - - # User-specified metrics to include - metrics = ["precision", "recall", *add_metrics] - - chart["hconcat"][1]["transform"][0]["fold"] = metrics - chart["hconcat"][1]["transform"][1]["calculate"] = chart["hconcat"][1]["transform"][ - 1 - ]["calculate"].replace("__metrics__", str(metrics)) - chart["hconcat"][1]["layer"][0]["encoding"]["color"]["sort"] = metrics - chart["hconcat"][1]["layer"][1]["layer"][1]["encoding"]["color"]["sort"] = metrics - - # Metric-label mapping - mapping = { - "precision": "Precision (PPV)", - "recall": "Recall (TPR)", - "specificity": "Specificity (TNR)", - "accuracy": "Accuracy", - "npv": "NPV", - "f1": "F1", - "f2": "F2", - "f0_5": "F0.5", - "p4": "P4", - "phi": "\u03c6 (MCC)", - } - chart["hconcat"][1]["transform"][2]["calculate"] = chart["hconcat"][1]["transform"][ - 2 - ]["calculate"].replace("__mapping__", str(mapping)) - chart["hconcat"][1]["layer"][0]["encoding"]["color"]["legend"]["labelExpr"] = chart[ - "hconcat" - ][1]["layer"][0]["encoding"]["color"]["legend"]["labelExpr"].replace( - "__mapping__", str(mapping) - ) + def alter_spec_from_data(self, chart_spec): + metrics = self.metrics + chart_spec["transform"][0]["fold"] = metrics + chart_spec["transform"][1]["calculate"] = chart_spec["transform"][1][ + "calculate" + ].replace("__metrics__", str(metrics)) + chart_spec["layer"][0]["encoding"]["color"]["sort"] = metrics + chart_spec["layer"][1]["layer"][1]["encoding"]["color"]["sort"] = metrics + return chart_spec - chart["data"]["values"] = records - return altair_or_json(chart, as_dict=as_dict) +class ThresholdSelectionToolChart(SplinkChart[ChartRecord]): + def __init__( + self, + records: Sequence[ChartRecord], + add_metrics: Sequence[str] = [], # TODO + as_dict: bool = False, + ): + super().__init__(records, as_dict=as_dict) + # User-specified metrics to include + self.additional_metrics = add_metrics + + @property + def chart_spec_file(self) -> str: + return "threshold_selection_tool.json" + + @property + def metrics(self): + return ["precision", "recall", *self.additional_metrics] + + @staticmethod + def alter_data(records): + return [d for d in records if d["precision"] > 0.5 and d["recall"] > 0.5] + + @staticmethod + def alter_spec_directly(chart_spec): + mapping = { + "precision": "Precision (PPV)", + "recall": "Recall (TPR)", + "specificity": "Specificity (TNR)", + "accuracy": "Accuracy", + "npv": "NPV", + "f1": "F1", + "f2": "F2", + "f0_5": "F0.5", + "p4": "P4", + "phi": "\u03c6 (MCC)", + } + chart_spec["hconcat"][1]["transform"][2]["calculate"] = chart_spec["hconcat"][ + 1 + ]["transform"][2]["calculate"].replace("__mapping__", str(mapping)) + chart_spec["hconcat"][1]["layer"][0]["encoding"]["color"]["legend"][ + "labelExpr" + ] = chart_spec["hconcat"][1]["layer"][0]["encoding"]["color"]["legend"][ + "labelExpr" + ].replace("__mapping__", str(mapping)) + + return chart_spec + + def alter_spec_from_data(self, chart_spec): + metrics = self.metrics + chart_spec["hconcat"][1]["transform"][0]["fold"] = metrics + chart_spec["hconcat"][1]["transform"][1]["calculate"] = chart_spec["hconcat"][ + 1 + ]["transform"][1]["calculate"].replace("__metrics__", str(metrics)) + chart_spec["hconcat"][1]["layer"][0]["encoding"]["color"]["sort"] = metrics + chart_spec["hconcat"][1]["layer"][1]["layer"][1]["encoding"]["color"][ + "sort" + ] = metrics + + return chart_spec class MatchWeightsHistogramChart(SplinkChart[ChartRecord]): diff --git a/splink/internals/linker_components/evaluation.py b/splink/internals/linker_components/evaluation.py index 78d69731ac..ac24a56226 100644 --- a/splink/internals/linker_components/evaluation.py +++ b/splink/internals/linker_components/evaluation.py @@ -9,11 +9,11 @@ truth_space_table_from_labels_table, ) from splink.internals.charts import ( + AccuracyChart, ChartReturnType, PrecisionRecallChart, ROCChart, - accuracy_chart, - threshold_selection_tool, + ThresholdSelectionToolChart, unlinkables_chart, ) from splink.internals.labelling_tool import ( @@ -168,9 +168,9 @@ def accuracy_analysis_from_labels_column( recs = df_truth_space.as_record_dict() if output_type == "threshold_selection": - return threshold_selection_tool(recs, add_metrics=add_metrics) + return ThresholdSelectionToolChart(recs, add_metrics=add_metrics).chart elif output_type == "accuracy": - return accuracy_chart(recs, add_metrics=add_metrics) + return AccuracyChart(recs, add_metrics=add_metrics).chart elif output_type == "roc": return ROCChart(recs).chart elif output_type == "precision_recall": @@ -281,9 +281,9 @@ def accuracy_analysis_from_labels_table( recs = df_truth_space.as_record_dict() if output_type == "threshold_selection": - return threshold_selection_tool(recs, add_metrics=add_metrics) + return ThresholdSelectionToolChart(recs, add_metrics=add_metrics).chart elif output_type == "accuracy": - return accuracy_chart(recs, add_metrics=add_metrics) + return AccuracyChart(recs, add_metrics=add_metrics).chart elif output_type == "roc": return ROCChart(recs).chart elif output_type == "precision_recall": From b8f94256c839c206f7b98d11dc2d43e097335c36 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Thu, 19 Feb 2026 10:42:03 +0000 Subject: [PATCH 22/24] unlinkables -> SplinkChart --- splink/internals/charts.py | 78 ++++++++++--------- .../internals/linker_components/evaluation.py | 7 +- 2 files changed, 47 insertions(+), 38 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 2eefd27995..6446608437 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -5,7 +5,16 @@ import os from abc import ABC, abstractmethod from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Generic, Protocol, Sequence, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Literal, + Protocol, + Sequence, + TypeVar, + Union, +) from splink.internals.misc import read_resource from splink.internals.waterfall_chart import records_to_waterfall_data @@ -543,53 +552,52 @@ def chart_spec_file(self) -> str: return "parameter_estimate_comparisons.json" -def unlinkables_chart( - records, - x_col="match_weight", - source_dataset=None, - as_dict=False, -): - if x_col not in ["match_weight", "match_probability"]: - raise ValueError( - f"{x_col} must be 'match_weight' (default) or 'match_probability'." - ) +class UnlinkablesChart(SplinkChart[ChartRecord]): + def __init__( + self, + records: Sequence[ChartRecord], + x_col: Literal["match_weight", "match_probability"] = "match_weight", + source_dataset: str | None = None, + as_dict: bool = False, + ): + if x_col not in ["match_weight", "match_probability"]: + raise ValueError( + f"{x_col} must be 'match_weight' (default) or 'match_probability'." + ) + super().__init__(records, as_dict=as_dict) + self.x_col = x_col + self.source_dataset = source_dataset - chart_path = "unlinkables_chart_def.json" - unlinkables_chart_def = load_chart_definition(chart_path) - unlinkables_chart_def["data"]["values"] = records + @property + def chart_spec_file(self) -> str: + return "unlinkables_chart_def.json" - if x_col == "match_probability": - unlinkables_chart_def["layer"][0]["encoding"]["x"]["field"] = ( - "match_probability" - ) - unlinkables_chart_def["layer"][0]["encoding"]["x"]["axis"]["title"] = ( + def alter_spec_from_data(self, chart_spec): + if source_dataset := self.source_dataset: + chart_spec["title"]["text"] += f" - {source_dataset}" + if self.x_col == "match_weight": + return chart_spec + # if we have match_probability we need to update spec to match: + chart_spec["layer"][0]["encoding"]["x"]["field"] = "match_probability" + chart_spec["layer"][0]["encoding"]["x"]["axis"]["title"] = ( "Threshold match probability" ) - unlinkables_chart_def["layer"][0]["encoding"]["x"]["axis"]["format"] = ".2" + chart_spec["layer"][0]["encoding"]["x"]["axis"]["format"] = ".2" - unlinkables_chart_def["layer"][1]["encoding"]["x"]["field"] = ( - "match_probability" - ) - unlinkables_chart_def["layer"][1]["selection"]["selector112"]["fields"] = [ + chart_spec["layer"][1]["encoding"]["x"]["field"] = "match_probability" + chart_spec["layer"][1]["selection"]["selector112"]["fields"] = [ "match_probability", "cum_prop", ] - unlinkables_chart_def["layer"][2]["encoding"]["x"]["field"] = ( - "match_probability" - ) - unlinkables_chart_def["layer"][2]["encoding"]["x"]["axis"]["title"] = ( + chart_spec["layer"][2]["encoding"]["x"]["field"] = "match_probability" + chart_spec["layer"][2]["encoding"]["x"]["axis"]["title"] = ( "Threshold match probability" ) - unlinkables_chart_def["layer"][3]["encoding"]["x"]["field"] = ( - "match_probability" - ) + chart_spec["layer"][3]["encoding"]["x"]["field"] = "match_probability" - if source_dataset: - unlinkables_chart_def["title"]["text"] += f" - {source_dataset}" - - return altair_or_json(unlinkables_chart_def, as_dict=as_dict) + return chart_spec class CompletenessChart(SplinkChart[ChartRecord]): diff --git a/splink/internals/linker_components/evaluation.py b/splink/internals/linker_components/evaluation.py index ac24a56226..4d90fa29e3 100644 --- a/splink/internals/linker_components/evaluation.py +++ b/splink/internals/linker_components/evaluation.py @@ -14,7 +14,7 @@ PrecisionRecallChart, ROCChart, ThresholdSelectionToolChart, - unlinkables_chart, + UnlinkablesChart, ) from splink.internals.labelling_tool import ( generate_labelling_tool_comparisons, @@ -337,7 +337,7 @@ def prediction_errors_from_labels_column( def unlinkables_chart( self, - x_col: str = "match_weight", + x_col: Literal["match_weight", "match_probability"] = "match_weight", name_of_data_in_title: str | None = None, as_dict: bool = False, ) -> ChartReturnType: @@ -349,6 +349,7 @@ def unlinkables_chart( Args: x_col (str, optional): Column to use for the x-axis. + Must be either "match_weight" or "match_probability". Defaults to "match_weight". name_of_data_in_title (str, optional): Name of the source dataset to use for the title of the output chart. @@ -367,7 +368,7 @@ def unlinkables_chart( # Link our initial df on itself and calculate the % of unlinkable entries records = unlinkables_data(self._linker) - return unlinkables_chart(records, x_col, name_of_data_in_title, as_dict) + return UnlinkablesChart(records, x_col, name_of_data_in_title, as_dict).chart def labelling_tool_for_specific_record( self, From 7ea5018d5c1c3276885ca8a346b8d80a6421fbc8 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Thu, 19 Feb 2026 10:51:51 +0000 Subject: [PATCH 23/24] Ditch unneeded need VegaliteNoValidate workaround See https://github.com/moj-analytical-services/splink/pull/1315 --- splink/internals/charts.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 6446608437..1bb9db25de 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -102,9 +102,6 @@ def save_offline_chart( f"The path {filename} already exists. Please provide a different path." ) - if type(chart_dict).__name__ == "VegaliteNoValidate": - chart_dict = chart_dict.spec - template = read_resource("internals/files/templates/single_chart_template.html") fmt_dict = _load_external_libs() From 932f829811ab74b332abf0ca954ffabda35aa02f Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Thu, 19 Feb 2026 10:49:07 +0000 Subject: [PATCH 24/24] waterfall chart to SplinkChart --- splink/internals/charts.py | 37 +++++++++++-------- .../linker_components/visualisations.py | 15 +++++--- 2 files changed, 30 insertions(+), 22 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 1bb9db25de..31dcf3bd99 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -17,7 +17,6 @@ ) from splink.internals.misc import read_resource -from splink.internals.waterfall_chart import records_to_waterfall_data if TYPE_CHECKING: from altair import SchemaBase @@ -340,22 +339,28 @@ def alter_spec_from_data(self, chart_spec): return chart_spec -def waterfall_chart( - records, - settings_obj, - filter_nulls=True, - remove_sensitive_data=False, - as_dict=False, -): - data = records_to_waterfall_data(records, settings_obj, remove_sensitive_data) - chart_path = "match_weights_waterfall.json" - chart = load_chart_definition(chart_path) - chart["data"]["values"] = data - chart["params"][0]["bind"]["max"] = len(records) - 1 - if filter_nulls: - chart["transform"].insert(1, {"filter": "(datum.bayes_factor !== 1.0)"}) +class WaterfallChart(SplinkChart[ChartRecord]): + def __init__( + self, + records: Sequence[ChartRecord], + filter_nulls: bool = True, + as_dict: bool = False, + ): + super().__init__(records, as_dict=as_dict) + self.filter_nulls = filter_nulls - return altair_or_json(chart, as_dict=as_dict) + @property + def chart_spec_file(self) -> str: + return "match_weights_waterfall.json" + + def alter_spec_from_data(self, chart_spec): + records = self.chart_data + chart_spec["params"][0]["bind"]["max"] = len(records) - 1 + if self.filter_nulls: + chart_spec["transform"].insert( + 1, {"filter": "(datum.bayes_factor !== 1.0)"} + ) + return chart_spec class ROCChart(SplinkChart[ChartRecord]): diff --git a/splink/internals/linker_components/visualisations.py b/splink/internals/linker_components/visualisations.py index 1d0456fcaa..2867bf3029 100644 --- a/splink/internals/linker_components/visualisations.py +++ b/splink/internals/linker_components/visualisations.py @@ -6,7 +6,7 @@ ChartReturnType, MatchWeightsHistogramChart, ParameterEstimateComparisonsChart, - waterfall_chart, + WaterfallChart, ) from splink.internals.cluster_studio import ( SamplingMethods, @@ -26,6 +26,7 @@ from splink.internals.term_frequencies import ( tf_adjustment_chart, ) +from splink.internals.waterfall_chart import records_to_waterfall_data if TYPE_CHECKING: from splink.internals.linker import Linker @@ -288,13 +289,15 @@ def waterfall_chart( """ self._linker._raise_error_if_necessary_waterfall_columns_not_computed() - return waterfall_chart( - records, - self._linker._settings_obj, + data = records_to_waterfall_data( + records, self._linker._settings_obj, remove_sensitive_data + ) + + return WaterfallChart( + data, filter_nulls, - remove_sensitive_data, as_dict, - ) + ).chart def comparison_viewer_dashboard( self,