From 82846655ac32f4d7836d4586d4e4e1036f227a8c Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Sun, 15 Feb 2026 13:18:00 +0000 Subject: [PATCH 01/20] filtered neighbours gets persisted --- splink/internals/spark/database_api.py | 1 + 1 file changed, 1 insertion(+) diff --git a/splink/internals/spark/database_api.py b/splink/internals/spark/database_api.py index 7a55bdc04e..6dbbc20ac7 100644 --- a/splink/internals/spark/database_api.py +++ b/splink/internals/spark/database_api.py @@ -299,6 +299,7 @@ def _break_lineage_and_repartition(self, spark_df, templated_name, physical_name r"__splink__clusters_at_all_thresholds", r"__splink__clustering_output_final", r"__splink__stable_nodes_at_new_threshold", + r"__splink__filtered_neighbours.*", ] if re.fullmatch(r"|".join(regex_to_persist), templated_name): From f71f382fd5c76048d5ae2ccccf5f26d849639321 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Sun, 15 Feb 2026 13:32:57 +0000 Subject: [PATCH 02/20] 4.0.14 release and changelog --- CHANGELOG.md | 7 ++++++- pyproject.toml | 2 +- splink/__init__.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 44fa10003a..c5136fbc0c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,7 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## Unreleased +## [4.0.14] - 2026-02-12 + +### Changed + +* Two dataset link only exploding blocking rule optimisation by @RobinL in https://github.com/moj-analytical-services/splink/pull/2931 +* Filtered neighbours gets persisted by @RobinL in https://github.com/moj-analytical-services/splink/pull/2933 ## [4.0.13] - 2026-02-12 diff --git a/pyproject.toml b/pyproject.toml index 404cf95217..7c319b4ed9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "splink" -version = "4.0.13" +version = "4.0.14" description = "Fast probabilistic data linkage at scale" authors = [ { name = "Robin Linacre", email = "robinlinacre@hotmail.com" }, diff --git a/splink/__init__.py b/splink/__init__.py index 55f661ea94..56088d76f0 100644 --- a/splink/__init__.py +++ b/splink/__init__.py @@ -56,7 +56,7 @@ def __getattr__(name): raise AttributeError(f"module 'splink' has no attribute '{name}'") from None -__version__ = "4.0.13" +__version__ = "4.0.14" __all__ = [ From 8dc6c93dbc5f63deee34a11baead9877bada6f99 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 16 Feb 2026 07:44:36 +0000 Subject: [PATCH 03/20] faster join to blocked pairs --- splink/internals/comparison_vector_values.py | 22 +++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/splink/internals/comparison_vector_values.py b/splink/internals/comparison_vector_values.py index ae6341543c..6f6f09ef10 100644 --- a/splink/internals/comparison_vector_values.py +++ b/splink/internals/comparison_vector_values.py @@ -4,7 +4,9 @@ from typing import List, Optional from splink.internals.input_column import InputColumn -from splink.internals.unique_id_concat import _composite_unique_id_from_nodes_sql +from splink.internals.unique_id_concat import ( + _composite_unique_id_from_nodes_sql, +) logger = logging.getLogger(__name__) @@ -62,18 +64,28 @@ def compute_comparison_vector_values_from_id_pairs_sqls( uid_l_expr = _composite_unique_id_from_nodes_sql(unique_id_columns, "l") uid_r_expr = _composite_unique_id_from_nodes_sql(unique_id_columns, "r") + # The where condition shouldn't really do anything because + # it's already covered by the inner join. + # However, Where there are large numbers of unmatched records, the DuckDB query + # planner can struggle with the double inner join below. It should + # push the filters down to the input tables, but it doesn't always do this. + # This forces it. + # The first table selects the required columns from the input tables # and alises them as `col_l`, `col_r` etc # using the __splink__blocked_id_pairs as an associated (junction) table - # That is, it does the join, but doesn't compute the comparison vectors - sql = sql = f""" + sql = f""" select {select_cols_expr}, b.match_key - from {input_tablename_l} as l - inner join __splink__blocked_id_pairs as b + from __splink__blocked_id_pairs as b + inner join {input_tablename_l} as l on {uid_l_expr} = b.join_key_l inner join {input_tablename_r} as r on {uid_r_expr} = b.join_key_r + where + {uid_l_expr} in (select join_key_l from __splink__blocked_id_pairs) + or + {uid_r_expr} in (select join_key_r from __splink__blocked_id_pairs) """ sqls.append({"sql": sql, "output_table_name": "blocked_with_cols"}) From 997eddd4c2e2c74ec2a784b49fa6e968e025ca19 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 16 Feb 2026 08:40:58 +0000 Subject: [PATCH 04/20] faster join to blocked pairs --- splink/internals/comparison_vector_values.py | 33 +++++++++++++------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/splink/internals/comparison_vector_values.py b/splink/internals/comparison_vector_values.py index 6f6f09ef10..c4befd3528 100644 --- a/splink/internals/comparison_vector_values.py +++ b/splink/internals/comparison_vector_values.py @@ -6,6 +6,7 @@ from splink.internals.input_column import InputColumn from splink.internals.unique_id_concat import ( _composite_unique_id_from_nodes_sql, + _composite_unique_id_from_edges_sql, ) logger = logging.getLogger(__name__) @@ -61,15 +62,29 @@ def compute_comparison_vector_values_from_id_pairs_sqls( select_cols_expr = ", \n".join(columns_to_select_for_blocking) - uid_l_expr = _composite_unique_id_from_nodes_sql(unique_id_columns, "l") - uid_r_expr = _composite_unique_id_from_nodes_sql(unique_id_columns, "r") - - # The where condition shouldn't really do anything because - # it's already covered by the inner join. - # However, Where there are large numbers of unmatched records, the DuckDB query - # planner can struggle with the double inner join below. It should + # Where there are large numbers of unmatched records, the DuckDB query planner + # Can struggle with the double inner join below. It should # push the filters down to the input tables, but it doesn't always do this. # This forces it. + if input_tablename_l == input_tablename_r: + uid_expr = _composite_unique_id_from_nodes_sql(unique_id_columns) + sql = f""" + select * + from {input_tablename_l} + where + {uid_expr} in (select join_key_l from __splink__blocked_id_pairs) + or + {uid_expr} in (select join_key_r from __splink__blocked_id_pairs) + """ + + sqls.append( + {"sql": sql, "output_table_name": "__splink__df_concat_with_tf_filtered"} + ) + input_tablename_l = "__splink__df_concat_with_tf_filtered" + input_tablename_r = "__splink__df_concat_with_tf_filtered" + + uid_l_expr = _composite_unique_id_from_nodes_sql(unique_id_columns, "l") + uid_r_expr = _composite_unique_id_from_nodes_sql(unique_id_columns, "r") # The first table selects the required columns from the input tables # and alises them as `col_l`, `col_r` etc @@ -82,10 +97,6 @@ def compute_comparison_vector_values_from_id_pairs_sqls( on {uid_l_expr} = b.join_key_l inner join {input_tablename_r} as r on {uid_r_expr} = b.join_key_r - where - {uid_l_expr} in (select join_key_l from __splink__blocked_id_pairs) - or - {uid_r_expr} in (select join_key_r from __splink__blocked_id_pairs) """ sqls.append({"sql": sql, "output_table_name": "blocked_with_cols"}) From 853c0f25fef63e92c1e18f6282b6c1712b3f8b79 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Mon, 16 Feb 2026 08:42:19 +0000 Subject: [PATCH 05/20] clean up --- splink/internals/comparison_vector_values.py | 1 - 1 file changed, 1 deletion(-) diff --git a/splink/internals/comparison_vector_values.py b/splink/internals/comparison_vector_values.py index c4befd3528..7149646c7b 100644 --- a/splink/internals/comparison_vector_values.py +++ b/splink/internals/comparison_vector_values.py @@ -6,7 +6,6 @@ from splink.internals.input_column import InputColumn from splink.internals.unique_id_concat import ( _composite_unique_id_from_nodes_sql, - _composite_unique_id_from_edges_sql, ) logger = logging.getLogger(__name__) From 53021d059020793770fd95b4547821a4d997e270 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 17 Feb 2026 18:01:40 +0000 Subject: [PATCH 06/20] pass correct args --- splink/internals/comparison_vector_values.py | 17 ++++++++++++++--- splink/internals/linker_components/inference.py | 4 ++++ 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/splink/internals/comparison_vector_values.py b/splink/internals/comparison_vector_values.py index 7149646c7b..97b86bebf1 100644 --- a/splink/internals/comparison_vector_values.py +++ b/splink/internals/comparison_vector_values.py @@ -45,6 +45,8 @@ def compute_comparison_vector_values_from_id_pairs_sqls( source_dataset_input_column: Optional[InputColumn], unique_id_input_column: InputColumn, include_clerical_match_score: bool = False, + link_type: Optional[str] = None, + sql_dialect_str: Optional[str] = None, ) -> list[dict[str, str]]: """Compute the comparison vectors from __splink__blocked_id_pairs, the materialised dataframe of blocked pairwise record comparisons. @@ -62,10 +64,19 @@ def compute_comparison_vector_values_from_id_pairs_sqls( select_cols_expr = ", \n".join(columns_to_select_for_blocking) # Where there are large numbers of unmatched records, the DuckDB query planner - # Can struggle with the double inner join below. It should + # can struggle with the double inner join below. It should # push the filters down to the input tables, but it doesn't always do this. - # This forces it. - if input_tablename_l == input_tablename_r: + # This forces it. it is only really relevant in the link only case, + # where one dataset is much larger than the other + # This optimisation is here due to poor performance observed in + # the `uk_address_matcher` package + # TODO: Once DuckDB 1.5 is released, check this is still needed + # ref https://github.com/moj-analytical-services/uk_address_matcher/issues/226 + if ( + input_tablename_l == input_tablename_r + and link_type == "two_dataset_link_only" + and sql_dialect_str == "duckdb" + ): uid_expr = _composite_unique_id_from_nodes_sql(unique_id_columns) sql = f""" select * diff --git a/splink/internals/linker_components/inference.py b/splink/internals/linker_components/inference.py index 280b074705..e0b92d4e4b 100644 --- a/splink/internals/linker_components/inference.py +++ b/splink/internals/linker_components/inference.py @@ -139,6 +139,8 @@ def deterministic_link(self) -> SplinkDataFrame: input_tablename_r="__splink__df_concat_with_tf", source_dataset_input_column=self._linker._settings_obj.column_info_settings.source_dataset_input_column, unique_id_input_column=self._linker._settings_obj.column_info_settings.unique_id_input_column, + link_type=link_type, + sql_dialect_str=self._linker._sql_dialect_str, ) pipeline.enqueue_list_of_sqls(sqls) @@ -268,6 +270,8 @@ def predict( input_tablename_r="__splink__df_concat_with_tf", source_dataset_input_column=self._linker._settings_obj.column_info_settings.source_dataset_input_column, unique_id_input_column=self._linker._settings_obj.column_info_settings.unique_id_input_column, + link_type=link_type, + sql_dialect_str=self._linker._sql_dialect_str, ) pipeline.enqueue_list_of_sqls(sqls) From b599a0165423b4dc64ed8be9bfc04cbda828b848 Mon Sep 17 00:00:00 2001 From: Robin Linacre Date: Tue, 17 Feb 2026 21:04:25 +0000 Subject: [PATCH 07/20] 4_0_15 release --- CHANGELOG.md | 9 +++++++++ pyproject.toml | 2 +- splink/__init__.py | 2 +- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c5136fbc0c..b14d9ddca8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,15 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased + + +## [4.0.15] - 2026-02-17 + +### Changed + +Faster two_dataset_link_only joins when joining small table to large in duckdb by @RobinL in https://github.com/moj-analytical-services/splink/pull/2936 + ## [4.0.14] - 2026-02-12 ### Changed diff --git a/pyproject.toml b/pyproject.toml index 7c319b4ed9..5954967a15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "splink" -version = "4.0.14" +version = "4.0.15" description = "Fast probabilistic data linkage at scale" authors = [ { name = "Robin Linacre", email = "robinlinacre@hotmail.com" }, diff --git a/splink/__init__.py b/splink/__init__.py index 56088d76f0..903bd7ee0d 100644 --- a/splink/__init__.py +++ b/splink/__init__.py @@ -56,7 +56,7 @@ def __getattr__(name): raise AttributeError(f"module 'splink' has no attribute '{name}'") from None -__version__ = "4.0.14" +__version__ = "4.0.15" __all__ = [ From e626bf6f5ad2f66380ed13b1575d33c34f76cc65 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 11:04:47 +0000 Subject: [PATCH 08/20] remove orphaned chart gone since https://github.com/moj-analytical-services/splink/pull/2157 --- splink/internals/charts.py | 14 ---- .../files/chart_defs/missingness.json | 80 ------------------- 2 files changed, 94 deletions(-) delete mode 100644 splink/internals/files/chart_defs/missingness.json diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 9f1a4a31fa..d48b1981de 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -349,20 +349,6 @@ def parameter_estimate_comparisons(records, as_dict=False): return altair_or_json(chart, as_dict=as_dict) -def missingness_chart(records, as_dict=False): - chart_path = "missingness.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - record_count = records[0]["total_record_count"] - - for c in chart["layer"]: - c["title"] = f"Missingness per column out of {record_count:,.0f} records" - - return altair_or_json(chart, as_dict=as_dict) - - def unlinkables_chart( records, x_col="match_weight", diff --git a/splink/internals/files/chart_defs/missingness.json b/splink/internals/files/chart_defs/missingness.json deleted file mode 100644 index de18f660fb..0000000000 --- a/splink/internals/files/chart_defs/missingness.json +++ /dev/null @@ -1,80 +0,0 @@ -{ - "config": { - "view": { - "continuousWidth": 400, - "continuousHeight": 300 - }, - "axis": { - "labelFontSize": 11 - } - }, - "title": "", - "layer": [ - { - "mark": "bar", - "encoding": { - "color": { - "type": "quantitative", - "field": "null_proportion", - "legend": { - "format": ".0%", - "offset": 30 - }, - "scale": {"domain": [0, 1], - "range": "heatmap" - }, - "title": "Missingness" - }, - "tooltip": [ - { - "type": "nominal", - "field": "column_name", - "title": "Column" - }, - { - "type": "quantitative", - "field": "null_count", - "format": ",.0f", - "title": "Count of nulls" - }, - { - "type": "quantitative", - "field": "null_proportion", - "format": ".2%", - "title": "Percentage of nulls" - }, - { - "type": "quantitative", - "field": "total_record_count", - "format": ",.0f", - "title": "Total record count" - } - ], - "x": { - "type": "quantitative", - "scale": {"domain": [0, 1]}, - "axis": { - "labelAlign": "center", - "format": "%", - "title": "Percentage of nulls" - }, - "field": "null_proportion" - }, - "y": { - "type": "nominal", - "axis": { - "title": "" - }, - "field": "column_name", - "sort": "-x" - } - }, - "title": "" - } - ], - "data": { - "values": "", - "name": "data-0e7bce5a1d2f132e282789d6ef7780fe" - }, - "$schema": "https://vega.github.io/schema/vega-lite/v5.9.3.json" -} From bbf2568756850786dca4d943ace62813b3db9136 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 13:08:54 +0000 Subject: [PATCH 09/20] cl detailed record as dataclass --- splink/internals/comparison.py | 4 +- splink/internals/comparison_level.py | 85 +++++++++++++++++----------- splink/internals/waterfall_chart.py | 6 +- 3 files changed, 58 insertions(+), 37 deletions(-) diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index c7e69492d6..0f8d134937 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -367,7 +367,9 @@ def _as_detailed_records(self) -> list[dict[str, Any]]: record["comparison_name"] = self.output_column_name record = { **record, - **cl._as_detailed_record(self._num_levels, self.comparison_levels), + **cl._as_detailed_record( + self._num_levels, self.comparison_levels + ).as_dict(), } records.append(record) return records diff --git a/splink/internals/comparison_level.py b/splink/internals/comparison_level.py index 61aa669711..026e135ff8 100644 --- a/splink/internals/comparison_level.py +++ b/splink/internals/comparison_level.py @@ -4,6 +4,7 @@ import math import re from copy import copy +from dataclasses import asdict, dataclass from statistics import median from textwrap import dedent from typing import Any, Optional, Union, cast @@ -116,6 +117,33 @@ def _default_u_values(num_levels: int) -> list[float]: return u_vals +@dataclass +class ComparisonLevelDetailedRecord: + sql_condition: str + label_for_charts: str + + has_tf_adjustments: bool + tf_adjustment_column: str | None + tf_adjustment_weight: float + + is_null_level: bool + + m_probability: float | None + u_probability: float | None + m_probability_description: str | None + u_probability_description: str | None + + bayes_factor: float | None + log2_bayes_factor: float + bayes_factor_description: str + + comparison_vector_value: int + max_comparison_vector_value: int + + def as_dict(self): + return asdict(self) + + class ComparisonLevel: """Each ComparisonLevel defines a gradation (category) of similarity within a `Comparison`. @@ -735,39 +763,30 @@ def _as_completed_dict(self): def _as_detailed_record( self, comparison_num_levels: int, comparison_levels: list[ComparisonLevel] - ) -> dict[str, Any]: + ) -> ComparisonLevelDetailedRecord: "A detailed representation of this level to describe it in charting outputs" - output: dict[str, Any] = {} - output["sql_condition"] = self.sql_condition - output["label_for_charts"] = self._label_for_charts_no_duplicates( - comparison_levels + return ComparisonLevelDetailedRecord( + sql_condition=self.sql_condition, + label_for_charts=self._label_for_charts_no_duplicates(comparison_levels), + has_tf_adjustments=self._has_tf_adjustments, + tf_adjustment_column=( + self._tf_adjustment_input_column.input_name + if self._has_tf_adjustments + else None + ), + tf_adjustment_weight=self._tf_adjustment_weight, + is_null_level=self.is_null_level, + m_probability=self.m_probability if not self.is_null_level else None, + u_probability=self.u_probability if not self.is_null_level else None, + m_probability_description=self._m_probability_description, + u_probability_description=self._u_probability_description, + bayes_factor=self._bayes_factor, + log2_bayes_factor=self._log2_bayes_factor, + bayes_factor_description=self._bayes_factor_description, + comparison_vector_value=self.comparison_vector_value, + max_comparison_vector_value=comparison_num_levels - 1, ) - if not self._is_null_level: - output["m_probability"] = self.m_probability - output["u_probability"] = self.u_probability - - output["m_probability_description"] = self._m_probability_description - output["u_probability_description"] = self._u_probability_description - - output["has_tf_adjustments"] = self._has_tf_adjustments - if self._has_tf_adjustments: - output["tf_adjustment_column"] = self._tf_adjustment_input_column.input_name - else: - output["tf_adjustment_column"] = None - output["tf_adjustment_weight"] = self._tf_adjustment_weight - - output["is_null_level"] = self.is_null_level - output["bayes_factor"] = self._bayes_factor - output["log2_bayes_factor"] = self._log2_bayes_factor - output["comparison_vector_value"] = self.comparison_vector_value - output["max_comparison_vector_value"] = comparison_num_levels - 1 - output["bayes_factor_description"] = self._bayes_factor_description - output["m_probability_description"] = self._m_probability_description - output["u_probability_description"] = self._u_probability_description - - return output - def _parameter_estimates_as_records( self, comparison_num_levels: int, comparison_levels: list[ComparisonLevel] ) -> list[dict[str, Any]]: @@ -786,9 +805,9 @@ def _parameter_estimates_as_records( else: record["estimated_probability_as_log_odds"] = None - record["sql_condition"] = cl_record["sql_condition"] - record["comparison_level_label"] = cl_record["label_for_charts"] - record["comparison_vector_value"] = cl_record["comparison_vector_value"] + record["sql_condition"] = cl_record.sql_condition + record["comparison_level_label"] = cl_record.label_for_charts + record["comparison_vector_value"] = cl_record.comparison_vector_value output_records.append(record) return output_records diff --git a/splink/internals/waterfall_chart.py b/splink/internals/waterfall_chart.py index a8a080957b..07ef8b3992 100644 --- a/splink/internals/waterfall_chart.py +++ b/splink/internals/waterfall_chart.py @@ -57,9 +57,9 @@ def _comparison_records( cl = c._get_comparison_level_by_comparison_vector_value(cv_value) waterfall_record = { field: value - for field, value in cl._as_detailed_record( - c._num_levels, c.comparison_levels - ).items() + for field, value in cl._as_detailed_record(c._num_levels, c.comparison_levels) + .as_dict() + .items() if field in [ "label_for_charts", From 31c2d8c8cd7b46aa0228b0bc1d6ad9157d3136a0 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 13:18:59 +0000 Subject: [PATCH 10/20] comparison name in cl detailed rec --- splink/internals/comparison.py | 18 ++++++++---------- splink/internals/comparison_level.py | 1 + 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index 0f8d134937..95816ca2db 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -10,7 +10,11 @@ join_list_with_commas_final_and, ) -from .comparison_level import ComparisonLevel, _default_m_values, _default_u_values +from .comparison_level import ( + ComparisonLevel, + _default_m_values, + _default_u_values, +) # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports if TYPE_CHECKING: @@ -363,15 +367,9 @@ def _is_trained(self): def _as_detailed_records(self) -> list[dict[str, Any]]: records = [] for cl in self.comparison_levels: - record = {} - record["comparison_name"] = self.output_column_name - record = { - **record, - **cl._as_detailed_record( - self._num_levels, self.comparison_levels - ).as_dict(), - } - records.append(record) + record = cl._as_detailed_record(self._num_levels, self.comparison_levels) + record.comparison_name = self.output_column_name + records.append(record.as_dict()) return records @property diff --git a/splink/internals/comparison_level.py b/splink/internals/comparison_level.py index 026e135ff8..583268518f 100644 --- a/splink/internals/comparison_level.py +++ b/splink/internals/comparison_level.py @@ -139,6 +139,7 @@ class ComparisonLevelDetailedRecord: comparison_vector_value: int max_comparison_vector_value: int + comparison_name: str | None = None def as_dict(self): return asdict(self) From e3205ce0558f5660a4421bc40ce42a6ef6787696 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 13:30:33 +0000 Subject: [PATCH 11/20] detailed record dataclass in comparison --- splink/internals/comparison.py | 7 ++++--- splink/internals/settings.py | 2 +- splink/internals/term_frequencies.py | 4 +++- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index 95816ca2db..b8988a9934 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -12,6 +12,7 @@ from .comparison_level import ( ComparisonLevel, + ComparisonLevelDetailedRecord, _default_m_values, _default_u_values, ) @@ -364,12 +365,12 @@ def _is_trained(self): return self._all_m_are_trained and self._all_u_are_trained @property - def _as_detailed_records(self) -> list[dict[str, Any]]: + def _as_detailed_records(self) -> list[ComparisonLevelDetailedRecord]: records = [] for cl in self.comparison_levels: record = cl._as_detailed_record(self._num_levels, self.comparison_levels) record.comparison_name = self.output_column_name - records.append(record.as_dict()) + records.append(record) return records @property @@ -466,5 +467,5 @@ def match_weights_chart(self, as_dict=False): """Display a chart of comparison levels of the comparison""" from splink.internals.charts import comparison_match_weights_chart - records = self._as_detailed_records + records = list(map(lambda rec: rec.as_dict(), self._as_detailed_records)) return comparison_match_weights_chart(records, as_dict=as_dict) diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 9a4adea43b..0a5fcde4eb 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -113,7 +113,7 @@ def parameters_as_detailed_records(self): output = [] rr_match = self.probability_two_random_records_match for i, cc in enumerate(self.comparisons): - records = cc._as_detailed_records + records = list(map(lambda rec: rec.as_dict(), cc._as_detailed_records)) for r in records: r["probability_two_random_records_match"] = rr_match r["comparison_sort_order"] = i diff --git a/splink/internals/term_frequencies.py b/splink/internals/term_frequencies.py index da12b12699..85034f6474 100644 --- a/splink/internals/term_frequencies.py +++ b/splink/internals/term_frequencies.py @@ -239,7 +239,9 @@ def tf_adjustment_chart( # Data for chart comparison = linker._settings_obj._get_comparison_by_output_column_name(col) - comparison_records = comparison._as_detailed_records + comparison_records = list( + map(lambda rec: rec.as_dict(), comparison._as_detailed_records) + ) keys_to_retain = [ "comparison_vector_value", From 7eb4d05221c3495d487f997c6f46ee87c09e731e Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 14:47:20 +0000 Subject: [PATCH 12/20] specialised detailed record for settings --- splink/internals/comparison_level.py | 6 +-- splink/internals/settings.py | 79 +++++++++++++++++++--------- 2 files changed, 56 insertions(+), 29 deletions(-) diff --git a/splink/internals/comparison_level.py b/splink/internals/comparison_level.py index 583268518f..85dcb0208c 100644 --- a/splink/internals/comparison_level.py +++ b/splink/internals/comparison_level.py @@ -117,14 +117,14 @@ def _default_u_values(num_levels: int) -> list[float]: return u_vals -@dataclass +@dataclass(kw_only=True) class ComparisonLevelDetailedRecord: - sql_condition: str + sql_condition: str | None label_for_charts: str has_tf_adjustments: bool tf_adjustment_column: str | None - tf_adjustment_weight: float + tf_adjustment_weight: float | None is_null_level: bool diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 0a5fcde4eb..2d417e3f7c 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -8,7 +8,10 @@ from splink.internals.blocking import BlockingRule from splink.internals.charts import m_u_parameters_chart, match_weights_chart from splink.internals.comparison import Comparison -from splink.internals.comparison_level import ComparisonLevel +from splink.internals.comparison_level import ( + ComparisonLevel, + ComparisonLevelDetailedRecord, +) from splink.internals.dialects import SplinkDialect from splink.internals.input_column import InputColumn from splink.internals.misc import ( @@ -27,6 +30,26 @@ class ComparisonAndLevelDict(TypedDict): comparison: Comparison +@dataclass(kw_only=True) +class ModelParameterDetailedRecord(ComparisonLevelDetailedRecord): + probability_two_random_records_match: float + comparison_sort_order: int + + @classmethod + def from_cl_detailed_record( + cls, + cl_rec: ComparisonLevelDetailedRecord, + *, + probability_two_random_records_match: float, + comparison_sort_order: int, + ) -> ModelParameterDetailedRecord: + return cls( + **asdict(cl_rec), + probability_two_random_records_match=probability_two_random_records_match, + comparison_sort_order=comparison_sort_order, + ) + + @dataclass(frozen=True) class ColumnInfoSettings: match_weight_column_prefix: str @@ -109,14 +132,18 @@ def copy(self): return deepcopy(self) @property - def parameters_as_detailed_records(self): + def parameters_as_detailed_records(self) -> list[ModelParameterDetailedRecord]: output = [] rr_match = self.probability_two_random_records_match for i, cc in enumerate(self.comparisons): - records = list(map(lambda rec: rec.as_dict(), cc._as_detailed_records)) - for r in records: - r["probability_two_random_records_match"] = rr_match - r["comparison_sort_order"] = i + records = [ + ModelParameterDetailedRecord.from_cl_detailed_record( + r, + probability_two_random_records_match=rr_match, + comparison_sort_order=i, + ) + for r in cc._as_detailed_records + ] output.extend(records) prior_description = ( @@ -128,26 +155,26 @@ def parameters_as_detailed_records(self): ) # Finally add a record for probability_two_random_records_match - prop_record = { - "comparison_name": "probability_two_random_records_match", - "sql_condition": None, - "label_for_charts": "", - "m_probability": None, - "u_probability": None, - "m_probability_description": None, - "u_probability_description": None, - "has_tf_adjustments": False, - "tf_adjustment_column": None, - "tf_adjustment_weight": None, - "is_null_level": False, - "bayes_factor": prob_to_bayes_factor(rr_match), - "log2_bayes_factor": prob_to_match_weight(rr_match), - "comparison_vector_value": 0, - "max_comparison_vector_value": 0, - "bayes_factor_description": prior_description, - "probability_two_random_records_match": rr_match, - "comparison_sort_order": -1, - } + prop_record = ModelParameterDetailedRecord( + comparison_name="probability_two_random_records_match", + sql_condition=None, + label_for_charts="", + m_probability=None, + u_probability=None, + m_probability_description=None, + u_probability_description=None, + has_tf_adjustments=False, + tf_adjustment_column=None, + tf_adjustment_weight=None, + is_null_level=False, + bayes_factor=prob_to_bayes_factor(rr_match), + log2_bayes_factor=prob_to_match_weight(rr_match), + comparison_vector_value=0, + max_comparison_vector_value=0, + bayes_factor_description=prior_description, + probability_two_random_records_match=rr_match, + comparison_sort_order=-1, + ) output.insert(0, prop_record) return output From f7537d8f71ab40c556d56688b194ec99a5a4c054 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 15:26:31 +0000 Subject: [PATCH 13/20] em iteration record subclass --- splink/internals/em_training_session.py | 43 +++++++++++++++++++------ splink/internals/settings.py | 8 +++-- 2 files changed, 39 insertions(+), 12 deletions(-) diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index 866d7023ef..cbd25c9c12 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, List from splink.internals.blocking import BlockingRule, block_using_rules_sqls @@ -22,6 +23,7 @@ from splink.internals.settings import ( ComparisonAndLevelDict, CoreModelSettings, + ModelParameterDetailedRecord, Settings, TrainingSettings, ) @@ -39,6 +41,27 @@ from splink.internals.splink_dataframe import SplinkDataFrame +@dataclass(kw_only=True) +class ModelParameterIterationDetailedRecord(ModelParameterDetailedRecord): + iteration: int + + @classmethod + def from_settings_param_detailed_record( + cls, + cl_rec: ModelParameterDetailedRecord, + *, + probability_two_random_records_match: float, + iteration: int, + ) -> ModelParameterIterationDetailedRecord: + cl_rec.probability_two_random_records_match = ( + probability_two_random_records_match + ) + return cls( + **asdict(cl_rec), + iteration=iteration, + ) + + class EMTrainingSession: """Manages training models using the Expectation Maximisation algorithm, and holds statistics on the evolution of parameter estimates. Plots diagnostic charts @@ -320,20 +343,20 @@ def _blocking_adjusted_probability_two_random_records_match(self): return adjusted_prop_m @property - def _iteration_history_records(self): + def _iteration_history_records(self) -> list[ModelParameterIterationDetailedRecord]: output_records = [] for iteration, core_model_settings in enumerate( self._core_model_settings_history ): - records = core_model_settings.parameters_as_detailed_records - - for r in records: - r["iteration"] = iteration - # TODO: why lambda from current settings, not history? - r["probability_two_random_records_match"] = ( - self.core_model_settings.probability_two_random_records_match + records = [ + ModelParameterIterationDetailedRecord.from_settings_param_detailed_record( + r, + probability_two_random_records_match=self.core_model_settings.probability_two_random_records_match, + iteration=iteration, ) + for r in core_model_settings.parameters_as_detailed_records + ] output_records.extend(records) return output_records @@ -370,7 +393,7 @@ def match_weights_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = self._iteration_history_records + records = list(map(lambda rec: rec.as_dict(), self._iteration_history_records)) return match_weights_interactive_history_chart( records, blocking_rule=self._blocking_rule_for_training.blocking_rule_sql ) @@ -382,7 +405,7 @@ def m_u_values_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = self._iteration_history_records + records = list(map(lambda rec: rec.as_dict(), self._iteration_history_records)) return m_u_parameters_interactive_history_chart(records) def __repr__(self): diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 2d417e3f7c..40c6029796 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -595,12 +595,16 @@ def _as_completed_dict(self): } def match_weights_chart(self, as_dict=False): - records = self._parameters_as_detailed_records + records = list( + map(lambda rec: rec.as_dict(), self._parameters_as_detailed_records) + ) return match_weights_chart(records, as_dict=as_dict) def m_u_parameters_chart(self, as_dict=False): - records = self._parameters_as_detailed_records + records = list( + map(lambda rec: rec.as_dict(), self._parameters_as_detailed_records) + ) return m_u_parameters_chart(records, as_dict=as_dict) def _columns_without_estimated_parameters_message(self): From 201bfccd959192a921a8058ded7f06bd0c569955 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:16:56 +0000 Subject: [PATCH 14/20] comparison_match_weights_chart using dataclasses directly --- splink/internals/charts.py | 23 ++++++++++++++++++----- splink/internals/comparison.py | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index d48b1981de..6d66cf4f34 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -3,8 +3,10 @@ import json import math import os -from typing import TYPE_CHECKING, Any, Dict, Union +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Dict, Protocol, Union +from splink.internals.comparison_level import ComparisonLevelDetailedRecord from splink.internals.misc import read_resource from splink.internals.waterfall_chart import records_to_waterfall_data @@ -41,6 +43,14 @@ def altair_or_json( return chart_dict +class AsDictable(Protocol): + def as_dict(self) -> dict[str, Any]: ... + + +def list_items_as_dicts(lst: Iterable[AsDictable]) -> list[dict[str, Any]]: + return list(map(lambda item: item.as_dict(), lst)) + + iframe_message = """ To view in Jupyter you can use the following command: @@ -110,18 +120,21 @@ def match_weights_chart(records, as_dict=False): return altair_or_json(chart, as_dict=as_dict) -def comparison_match_weights_chart(records, as_dict=False): +def comparison_match_weights_chart( + records: list[ComparisonLevelDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "match_weights_interactive_history.json" chart = load_chart_definition(chart_path) # Remove iteration history since this is a static chart - del chart["vconcat"][0] + # TODO: some render issue if we remove empty top panel, so leave for now + # del chart["vconcat"][0] del chart["params"] del chart["transform"] chart["title"]["text"] = "Comparison summary" - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records + records = [r for r in records if r.comparison_vector_value != -1] + chart["data"]["values"] = list_items_as_dicts(records) return altair_or_json(chart, as_dict=as_dict) diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index b8988a9934..a3e9857499 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -467,5 +467,5 @@ def match_weights_chart(self, as_dict=False): """Display a chart of comparison levels of the comparison""" from splink.internals.charts import comparison_match_weights_chart - records = list(map(lambda rec: rec.as_dict(), self._as_detailed_records)) + records = self._as_detailed_records return comparison_match_weights_chart(records, as_dict=as_dict) From 8fefe1ed656d706c1b9b305aa5063442266cdcc0 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:17:06 +0000 Subject: [PATCH 15/20] test comparison_match_weight_charts --- tests/test_charts.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/tests/test_charts.py b/tests/test_charts.py index 455b6ac837..b85cb33114 100644 --- a/tests/test_charts.py +++ b/tests/test_charts.py @@ -142,6 +142,24 @@ def test_m_u_charts(dialect, test_helpers): linker.visualisations.match_weights_chart() +@mark_with_dialects_excluding() +def test_comparison_match_weight_charts(dialect, test_helpers): + settings = { + "link_type": "dedupe_only", + "comparisons": [ + cl.ExactMatch("gender"), + cl.ExactMatch("tm_partial"), + cl.LevenshteinAtThresholds("surname", [1]), + ], + } + helper = test_helpers[dialect] + + linker = helper.linker_with_registration(df, settings) + + comp = linker._settings_obj.comparisons[0] + comp.match_weights_chart() + + @mark_with_dialects_excluding() def test_parameter_estimate_charts(dialect, test_helpers): settings = { From 6b814c54ebde16495f8d0d7c5edc9528c127dee4 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:32:50 +0000 Subject: [PATCH 16/20] match weight charts using dataclass --- splink/internals/charts.py | 11 +++++++---- splink/internals/settings.py | 6 ++---- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 6d66cf4f34..99a5c14b2f 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -8,6 +8,7 @@ from splink.internals.comparison_level import ComparisonLevelDetailedRecord from splink.internals.misc import read_resource +from splink.internals.settings import ModelParameterDetailedRecord from splink.internals.waterfall_chart import records_to_waterfall_data if TYPE_CHECKING: @@ -96,7 +97,9 @@ def save_offline_chart( print(iframe_message.format(filename=filename)) # noqa: T201 -def match_weights_chart(records, as_dict=False): +def match_weights_chart( + records: list[ModelParameterDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "match_weights_interactive_history.json" chart = load_chart_definition(chart_path) @@ -104,15 +107,15 @@ def match_weights_chart(records, as_dict=False): del chart["params"] del chart["transform"] - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records + records = [r for r in records if r.comparison_vector_value != -1] bayes_factors = [ abs(l2bf) for r in records - if (l2bf := r["log2_bayes_factor"]) is not None and not math.isinf(l2bf) + if (l2bf := r.log2_bayes_factor) is not None and not math.isinf(l2bf) ] max_value = math.ceil(max(bayes_factors)) + chart["data"]["values"] = list_items_as_dicts(records) chart["vconcat"][0]["encoding"]["x"]["scale"]["domain"] = [-max_value, max_value] chart["vconcat"][1]["encoding"]["x"]["scale"]["domain"] = [-max_value, max_value] diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 40c6029796..d690d07c21 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -595,12 +595,10 @@ def _as_completed_dict(self): } def match_weights_chart(self, as_dict=False): - records = list( - map(lambda rec: rec.as_dict(), self._parameters_as_detailed_records) + return match_weights_chart( + self._parameters_as_detailed_records, as_dict=as_dict ) - return match_weights_chart(records, as_dict=as_dict) - def m_u_parameters_chart(self, as_dict=False): records = list( map(lambda rec: rec.as_dict(), self._parameters_as_detailed_records) From 6360bac31efec8e470a654ed216b5c231dd99ac2 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Tue, 17 Feb 2026 17:55:25 +0000 Subject: [PATCH 17/20] mu charts using dataclass --- splink/internals/charts.py | 10 ++++++---- splink/internals/settings.py | 5 ++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 99a5c14b2f..a1cfc7c544 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -141,7 +141,9 @@ def comparison_match_weights_chart( return altair_or_json(chart, as_dict=as_dict) -def m_u_parameters_chart(records, as_dict=False): +def m_u_parameters_chart( + records: list[ModelParameterDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "m_u_parameters_interactive_history.json" chart = load_chart_definition(chart_path) @@ -152,10 +154,10 @@ def m_u_parameters_chart(records, as_dict=False): records = [ r for r in records - if r["comparison_vector_value"] != -1 - and r["comparison_name"] != "probability_two_random_records_match" + if r.comparison_vector_value != -1 + and r.comparison_name != "probability_two_random_records_match" ] - chart["data"]["values"] = records + chart["data"]["values"] = list_items_as_dicts(records) return altair_or_json(chart, as_dict=as_dict) diff --git a/splink/internals/settings.py b/splink/internals/settings.py index d690d07c21..fce55ce66a 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -600,10 +600,9 @@ def match_weights_chart(self, as_dict=False): ) def m_u_parameters_chart(self, as_dict=False): - records = list( - map(lambda rec: rec.as_dict(), self._parameters_as_detailed_records) + return m_u_parameters_chart( + self._parameters_as_detailed_records, as_dict=as_dict ) - return m_u_parameters_chart(records, as_dict=as_dict) def _columns_without_estimated_parameters_message(self): message_lines = [] From 21a4ea52854b2bb19e4a8e6fce4f10e61a2a2f28 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:18:42 +0000 Subject: [PATCH 18/20] import only for type checking --- splink/internals/charts.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index a1cfc7c544..55c031fc89 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -4,19 +4,20 @@ import math import os from collections.abc import Iterable -from typing import TYPE_CHECKING, Any, Dict, Protocol, Union +from typing import TYPE_CHECKING, Any, Protocol, Union -from splink.internals.comparison_level import ComparisonLevelDetailedRecord from splink.internals.misc import read_resource -from splink.internals.settings import ModelParameterDetailedRecord from splink.internals.waterfall_chart import records_to_waterfall_data if TYPE_CHECKING: from altair import SchemaBase + + from splink.internals.comparison_level import ComparisonLevelDetailedRecord + from splink.internals.settings import ModelParameterDetailedRecord else: SchemaBase = None # type alias: -ChartReturnType = Union[Dict[Any, Any], SchemaBase] +ChartReturnType = Union[dict[Any, Any], SchemaBase] def load_chart_definition(filename): From acbec497d34d064974bb2c4c7e5a1994c04e6a3d Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 09:56:10 +0000 Subject: [PATCH 19/20] avoid kw only not supported in 3.9, and as we only have 1 default arg, that's only instantiated in 1 place, it is not a big sacrifice to just set this value at the call site --- splink/internals/comparison_level.py | 5 +++-- splink/internals/em_training_session.py | 2 +- splink/internals/settings.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/splink/internals/comparison_level.py b/splink/internals/comparison_level.py index 85dcb0208c..0f4de99597 100644 --- a/splink/internals/comparison_level.py +++ b/splink/internals/comparison_level.py @@ -117,7 +117,7 @@ def _default_u_values(num_levels: int) -> list[float]: return u_vals -@dataclass(kw_only=True) +@dataclass class ComparisonLevelDetailedRecord: sql_condition: str | None label_for_charts: str @@ -139,7 +139,7 @@ class ComparisonLevelDetailedRecord: comparison_vector_value: int max_comparison_vector_value: int - comparison_name: str | None = None + comparison_name: str | None def as_dict(self): return asdict(self) @@ -786,6 +786,7 @@ def _as_detailed_record( bayes_factor_description=self._bayes_factor_description, comparison_vector_value=self.comparison_vector_value, max_comparison_vector_value=comparison_num_levels - 1, + comparison_name=None, ) def _parameter_estimates_as_records( diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index cbd25c9c12..1681972db1 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -41,7 +41,7 @@ from splink.internals.splink_dataframe import SplinkDataFrame -@dataclass(kw_only=True) +@dataclass class ModelParameterIterationDetailedRecord(ModelParameterDetailedRecord): iteration: int diff --git a/splink/internals/settings.py b/splink/internals/settings.py index fce55ce66a..5a807d84d7 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -30,7 +30,7 @@ class ComparisonAndLevelDict(TypedDict): comparison: Comparison -@dataclass(kw_only=True) +@dataclass class ModelParameterDetailedRecord(ComparisonLevelDetailedRecord): probability_two_random_records_match: float comparison_sort_order: int From cbc429f8f5bd413353127f88f3257fbb098f7317 Mon Sep 17 00:00:00 2001 From: ADBond <48208438+ADBond@users.noreply.github.com> Date: Wed, 18 Feb 2026 11:16:12 +0000 Subject: [PATCH 20/20] iteration charts with dataclasses --- splink/internals/charts.py | 27 ++++++++++++++++--------- splink/internals/em_training_session.py | 7 +++---- 2 files changed, 21 insertions(+), 13 deletions(-) diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 55c031fc89..e3bcd392ae 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -13,6 +13,9 @@ from altair import SchemaBase from splink.internals.comparison_level import ComparisonLevelDetailedRecord + from splink.internals.em_training_session import ( + ModelParameterIterationDetailedRecord, + ) from splink.internals.settings import ModelParameterDetailedRecord else: SchemaBase = None @@ -170,39 +173,45 @@ def probability_two_random_records_match_iteration_chart(records, as_dict=False) return altair_or_json(chart, as_dict=as_dict) -def match_weights_interactive_history_chart(records, as_dict=False, blocking_rule=None): +def match_weights_interactive_history_chart( + records: list[ModelParameterIterationDetailedRecord], + as_dict: bool = False, + blocking_rule: str | None = None, +) -> ChartReturnType: chart_path = "match_weights_interactive_history.json" chart = load_chart_definition(chart_path) chart["title"]["subtitle"] = f"Training session blocked on {blocking_rule}" - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records + records = [r for r in records if r.comparison_vector_value != -1] max_iteration = 0 for r in records: - max_iteration = max(r["iteration"], max_iteration) + max_iteration = max(r.iteration, max_iteration) + chart["data"]["values"] = list_items_as_dicts(records) chart["params"][0]["bind"]["max"] = max_iteration chart["params"][0]["value"] = max_iteration return altair_or_json(chart, as_dict=as_dict) -def m_u_parameters_interactive_history_chart(records, as_dict=False): +def m_u_parameters_interactive_history_chart( + records: list[ModelParameterIterationDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "m_u_parameters_interactive_history.json" chart = load_chart_definition(chart_path) records = [ r for r in records - if r["comparison_vector_value"] != -1 - and r["comparison_name"] != "probability_two_random_records_match" + if r.comparison_vector_value != -1 + and r.comparison_name != "probability_two_random_records_match" ] - chart["data"]["values"] = records max_iteration = 0 for r in records: - max_iteration = max(r["iteration"], max_iteration) + max_iteration = max(r.iteration, max_iteration) + chart["data"]["values"] = list_items_as_dicts(records) chart["params"][0]["bind"]["max"] = max_iteration chart["params"][0]["value"] = max_iteration return altair_or_json(chart, as_dict=as_dict) diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index 1681972db1..f50af9153b 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -393,9 +393,9 @@ def match_weights_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = list(map(lambda rec: rec.as_dict(), self._iteration_history_records)) return match_weights_interactive_history_chart( - records, blocking_rule=self._blocking_rule_for_training.blocking_rule_sql + self._iteration_history_records, + blocking_rule=self._blocking_rule_for_training.blocking_rule_sql, ) def m_u_values_interactive_history_chart(self) -> ChartReturnType: @@ -405,8 +405,7 @@ def m_u_values_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = list(map(lambda rec: rec.as_dict(), self._iteration_history_records)) - return m_u_parameters_interactive_history_chart(records) + return m_u_parameters_interactive_history_chart(self._iteration_history_records) def __repr__(self): deactivated_cols = ", ".join(