diff --git a/CHANGELOG.md b/CHANGELOG.md index b81a0fd541..059463a828 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,19 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed salting mechanism as it is no longer required for parallelisation in DuckDB [#2849](https://github.com/moj-analytical-services/splink/pull/2849) - `pandas` and `numpy` are no longer required dependencies [#2883](https://github.com/moj-analytical-services/splink/pull/2883) +## [4.0.15] - 2026-02-17 + +### Changed + +* Faster two_dataset_link_only joins when joining small table to large in duckdb by @RobinL in https://github.com/moj-analytical-services/splink/pull/2936 + +## [4.0.14] - 2026-02-12 + +### Changed + +* Two dataset link only exploding blocking rule optimisation by @RobinL in https://github.com/moj-analytical-services/splink/pull/2931 +* Filtered neighbours gets persisted by @RobinL in https://github.com/moj-analytical-services/splink/pull/2933 + ## [4.0.13] - 2026-02-12 ### Fixed diff --git a/splink/internals/charts.py b/splink/internals/charts.py index 9f1a4a31fa..e3bcd392ae 100644 --- a/splink/internals/charts.py +++ b/splink/internals/charts.py @@ -3,17 +3,24 @@ import json import math import os -from typing import TYPE_CHECKING, Any, Dict, Union +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Protocol, Union from splink.internals.misc import read_resource from splink.internals.waterfall_chart import records_to_waterfall_data if TYPE_CHECKING: from altair import SchemaBase + + from splink.internals.comparison_level import ComparisonLevelDetailedRecord + from splink.internals.em_training_session import ( + ModelParameterIterationDetailedRecord, + ) + from splink.internals.settings import ModelParameterDetailedRecord else: SchemaBase = None # type alias: -ChartReturnType = Union[Dict[Any, Any], SchemaBase] +ChartReturnType = Union[dict[Any, Any], SchemaBase] def load_chart_definition(filename): @@ -41,6 
+48,14 @@ def altair_or_json( return chart_dict +class AsDictable(Protocol): + def as_dict(self) -> dict[str, Any]: ... + + +def list_items_as_dicts(lst: Iterable[AsDictable]) -> list[dict[str, Any]]: + return list(map(lambda item: item.as_dict(), lst)) + + iframe_message = """ To view in Jupyter you can use the following command: @@ -86,7 +101,9 @@ def save_offline_chart( print(iframe_message.format(filename=filename)) # noqa: T201 -def match_weights_chart(records, as_dict=False): +def match_weights_chart( + records: list[ModelParameterDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "match_weights_interactive_history.json" chart = load_chart_definition(chart_path) @@ -94,15 +111,15 @@ def match_weights_chart(records, as_dict=False): del chart["params"] del chart["transform"] - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records + records = [r for r in records if r.comparison_vector_value != -1] bayes_factors = [ abs(l2bf) for r in records - if (l2bf := r["log2_bayes_factor"]) is not None and not math.isinf(l2bf) + if (l2bf := r.log2_bayes_factor) is not None and not math.isinf(l2bf) ] max_value = math.ceil(max(bayes_factors)) + chart["data"]["values"] = list_items_as_dicts(records) chart["vconcat"][0]["encoding"]["x"]["scale"]["domain"] = [-max_value, max_value] chart["vconcat"][1]["encoding"]["x"]["scale"]["domain"] = [-max_value, max_value] @@ -110,22 +127,27 @@ def match_weights_chart(records, as_dict=False): return altair_or_json(chart, as_dict=as_dict) -def comparison_match_weights_chart(records, as_dict=False): +def comparison_match_weights_chart( + records: list[ComparisonLevelDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "match_weights_interactive_history.json" chart = load_chart_definition(chart_path) # Remove iteration history since this is a static chart - del chart["vconcat"][0] + # TODO: some render issue if we remove empty top panel, so leave 
for now + # del chart["vconcat"][0] del chart["params"] del chart["transform"] chart["title"]["text"] = "Comparison summary" - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records + records = [r for r in records if r.comparison_vector_value != -1] + chart["data"]["values"] = list_items_as_dicts(records) return altair_or_json(chart, as_dict=as_dict) -def m_u_parameters_chart(records, as_dict=False): +def m_u_parameters_chart( + records: list[ModelParameterDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "m_u_parameters_interactive_history.json" chart = load_chart_definition(chart_path) @@ -136,10 +158,10 @@ def m_u_parameters_chart(records, as_dict=False): records = [ r for r in records - if r["comparison_vector_value"] != -1 - and r["comparison_name"] != "probability_two_random_records_match" + if r.comparison_vector_value != -1 + and r.comparison_name != "probability_two_random_records_match" ] - chart["data"]["values"] = records + chart["data"]["values"] = list_items_as_dicts(records) return altair_or_json(chart, as_dict=as_dict) @@ -151,39 +173,45 @@ def probability_two_random_records_match_iteration_chart(records, as_dict=False) return altair_or_json(chart, as_dict=as_dict) -def match_weights_interactive_history_chart(records, as_dict=False, blocking_rule=None): +def match_weights_interactive_history_chart( + records: list[ModelParameterIterationDetailedRecord], + as_dict: bool = False, + blocking_rule: str | None = None, +) -> ChartReturnType: chart_path = "match_weights_interactive_history.json" chart = load_chart_definition(chart_path) chart["title"]["subtitle"] = f"Training session blocked on {blocking_rule}" - records = [r for r in records if r["comparison_vector_value"] != -1] - chart["data"]["values"] = records + records = [r for r in records if r.comparison_vector_value != -1] max_iteration = 0 for r in records: - max_iteration = max(r["iteration"], max_iteration) + 
max_iteration = max(r.iteration, max_iteration) + chart["data"]["values"] = list_items_as_dicts(records) chart["params"][0]["bind"]["max"] = max_iteration chart["params"][0]["value"] = max_iteration return altair_or_json(chart, as_dict=as_dict) -def m_u_parameters_interactive_history_chart(records, as_dict=False): +def m_u_parameters_interactive_history_chart( + records: list[ModelParameterIterationDetailedRecord], as_dict: bool = False +) -> ChartReturnType: chart_path = "m_u_parameters_interactive_history.json" chart = load_chart_definition(chart_path) records = [ r for r in records - if r["comparison_vector_value"] != -1 - and r["comparison_name"] != "probability_two_random_records_match" + if r.comparison_vector_value != -1 + and r.comparison_name != "probability_two_random_records_match" ] - chart["data"]["values"] = records max_iteration = 0 for r in records: - max_iteration = max(r["iteration"], max_iteration) + max_iteration = max(r.iteration, max_iteration) + chart["data"]["values"] = list_items_as_dicts(records) chart["params"][0]["bind"]["max"] = max_iteration chart["params"][0]["value"] = max_iteration return altair_or_json(chart, as_dict=as_dict) @@ -349,20 +377,6 @@ def parameter_estimate_comparisons(records, as_dict=False): return altair_or_json(chart, as_dict=as_dict) -def missingness_chart(records, as_dict=False): - chart_path = "missingness.json" - chart = load_chart_definition(chart_path) - - chart["data"]["values"] = records - - record_count = records[0]["total_record_count"] - - for c in chart["layer"]: - c["title"] = f"Missingness per column out of {record_count:,.0f} records" - - return altair_or_json(chart, as_dict=as_dict) - - def unlinkables_chart( records, x_col="match_weight", diff --git a/splink/internals/comparison.py b/splink/internals/comparison.py index c7e69492d6..a3e9857499 100644 --- a/splink/internals/comparison.py +++ b/splink/internals/comparison.py @@ -10,7 +10,12 @@ join_list_with_commas_final_and, ) -from .comparison_level 
import ComparisonLevel, _default_m_values, _default_u_values +from .comparison_level import ( + ComparisonLevel, + ComparisonLevelDetailedRecord, + _default_m_values, + _default_u_values, +) # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports if TYPE_CHECKING: @@ -360,15 +365,11 @@ def _is_trained(self): return self._all_m_are_trained and self._all_u_are_trained @property - def _as_detailed_records(self) -> list[dict[str, Any]]: + def _as_detailed_records(self) -> list[ComparisonLevelDetailedRecord]: records = [] for cl in self.comparison_levels: - record = {} - record["comparison_name"] = self.output_column_name - record = { - **record, - **cl._as_detailed_record(self._num_levels, self.comparison_levels), - } + record = cl._as_detailed_record(self._num_levels, self.comparison_levels) + record.comparison_name = self.output_column_name records.append(record) return records diff --git a/splink/internals/comparison_level.py b/splink/internals/comparison_level.py index 61aa669711..0f4de99597 100644 --- a/splink/internals/comparison_level.py +++ b/splink/internals/comparison_level.py @@ -4,6 +4,7 @@ import math import re from copy import copy +from dataclasses import asdict, dataclass from statistics import median from textwrap import dedent from typing import Any, Optional, Union, cast @@ -116,6 +117,34 @@ def _default_u_values(num_levels: int) -> list[float]: return u_vals +@dataclass +class ComparisonLevelDetailedRecord: + sql_condition: str | None + label_for_charts: str + + has_tf_adjustments: bool + tf_adjustment_column: str | None + tf_adjustment_weight: float | None + + is_null_level: bool + + m_probability: float | None + u_probability: float | None + m_probability_description: str | None + u_probability_description: str | None + + bayes_factor: float | None + log2_bayes_factor: float + bayes_factor_description: str + + comparison_vector_value: int + max_comparison_vector_value: int + comparison_name: str | None + + def 
as_dict(self): + return asdict(self) + + class ComparisonLevel: """Each ComparisonLevel defines a gradation (category) of similarity within a `Comparison`. @@ -735,39 +764,31 @@ def _as_completed_dict(self): def _as_detailed_record( self, comparison_num_levels: int, comparison_levels: list[ComparisonLevel] - ) -> dict[str, Any]: + ) -> ComparisonLevelDetailedRecord: "A detailed representation of this level to describe it in charting outputs" - output: dict[str, Any] = {} - output["sql_condition"] = self.sql_condition - output["label_for_charts"] = self._label_for_charts_no_duplicates( - comparison_levels + return ComparisonLevelDetailedRecord( + sql_condition=self.sql_condition, + label_for_charts=self._label_for_charts_no_duplicates(comparison_levels), + has_tf_adjustments=self._has_tf_adjustments, + tf_adjustment_column=( + self._tf_adjustment_input_column.input_name + if self._has_tf_adjustments + else None + ), + tf_adjustment_weight=self._tf_adjustment_weight, + is_null_level=self.is_null_level, + m_probability=self.m_probability if not self.is_null_level else None, + u_probability=self.u_probability if not self.is_null_level else None, + m_probability_description=self._m_probability_description, + u_probability_description=self._u_probability_description, + bayes_factor=self._bayes_factor, + log2_bayes_factor=self._log2_bayes_factor, + bayes_factor_description=self._bayes_factor_description, + comparison_vector_value=self.comparison_vector_value, + max_comparison_vector_value=comparison_num_levels - 1, + comparison_name=None, ) - if not self._is_null_level: - output["m_probability"] = self.m_probability - output["u_probability"] = self.u_probability - - output["m_probability_description"] = self._m_probability_description - output["u_probability_description"] = self._u_probability_description - - output["has_tf_adjustments"] = self._has_tf_adjustments - if self._has_tf_adjustments: - output["tf_adjustment_column"] = self._tf_adjustment_input_column.input_name 
- else: - output["tf_adjustment_column"] = None - output["tf_adjustment_weight"] = self._tf_adjustment_weight - - output["is_null_level"] = self.is_null_level - output["bayes_factor"] = self._bayes_factor - output["log2_bayes_factor"] = self._log2_bayes_factor - output["comparison_vector_value"] = self.comparison_vector_value - output["max_comparison_vector_value"] = comparison_num_levels - 1 - output["bayes_factor_description"] = self._bayes_factor_description - output["m_probability_description"] = self._m_probability_description - output["u_probability_description"] = self._u_probability_description - - return output - def _parameter_estimates_as_records( self, comparison_num_levels: int, comparison_levels: list[ComparisonLevel] ) -> list[dict[str, Any]]: @@ -786,9 +807,9 @@ def _parameter_estimates_as_records( else: record["estimated_probability_as_log_odds"] = None - record["sql_condition"] = cl_record["sql_condition"] - record["comparison_level_label"] = cl_record["label_for_charts"] - record["comparison_vector_value"] = cl_record["comparison_vector_value"] + record["sql_condition"] = cl_record.sql_condition + record["comparison_level_label"] = cl_record.label_for_charts + record["comparison_vector_value"] = cl_record.comparison_vector_value output_records.append(record) return output_records diff --git a/splink/internals/comparison_vector_values.py b/splink/internals/comparison_vector_values.py index ae6341543c..97b86bebf1 100644 --- a/splink/internals/comparison_vector_values.py +++ b/splink/internals/comparison_vector_values.py @@ -4,7 +4,9 @@ from typing import List, Optional from splink.internals.input_column import InputColumn -from splink.internals.unique_id_concat import _composite_unique_id_from_nodes_sql +from splink.internals.unique_id_concat import ( + _composite_unique_id_from_nodes_sql, +) logger = logging.getLogger(__name__) @@ -43,6 +45,8 @@ def compute_comparison_vector_values_from_id_pairs_sqls( source_dataset_input_column: 
Optional[InputColumn], unique_id_input_column: InputColumn, include_clerical_match_score: bool = False, + link_type: Optional[str] = None, + sql_dialect_str: Optional[str] = None, ) -> list[dict[str, str]]: """Compute the comparison vectors from __splink__blocked_id_pairs, the materialised dataframe of blocked pairwise record comparisons. @@ -59,18 +63,47 @@ def compute_comparison_vector_values_from_id_pairs_sqls( select_cols_expr = ", \n".join(columns_to_select_for_blocking) + # Where there are large numbers of unmatched records, the DuckDB query planner + # can struggle with the double inner join below. It should + # push the filters down to the input tables, but it doesn't always do this. + # This forces it. it is only really relevant in the link only case, + # where one dataset is much larger than the other + # This optimisation is here due to poor performance observed in + # the `uk_address_matcher` package + # TODO: Once DuckDB 1.5 is released, check this is still needed + # ref https://github.com/moj-analytical-services/uk_address_matcher/issues/226 + if ( + input_tablename_l == input_tablename_r + and link_type == "two_dataset_link_only" + and sql_dialect_str == "duckdb" + ): + uid_expr = _composite_unique_id_from_nodes_sql(unique_id_columns) + sql = f""" + select * + from {input_tablename_l} + where + {uid_expr} in (select join_key_l from __splink__blocked_id_pairs) + or + {uid_expr} in (select join_key_r from __splink__blocked_id_pairs) + """ + + sqls.append( + {"sql": sql, "output_table_name": "__splink__df_concat_with_tf_filtered"} + ) + input_tablename_l = "__splink__df_concat_with_tf_filtered" + input_tablename_r = "__splink__df_concat_with_tf_filtered" + uid_l_expr = _composite_unique_id_from_nodes_sql(unique_id_columns, "l") uid_r_expr = _composite_unique_id_from_nodes_sql(unique_id_columns, "r") # The first table selects the required columns from the input tables # and alises them as `col_l`, `col_r` etc # using the __splink__blocked_id_pairs as an 
associated (junction) table - # That is, it does the join, but doesn't compute the comparison vectors - sql = sql = f""" + sql = f""" select {select_cols_expr}, b.match_key - from {input_tablename_l} as l - inner join __splink__blocked_id_pairs as b + from __splink__blocked_id_pairs as b + inner join {input_tablename_l} as l on {uid_l_expr} = b.join_key_l inner join {input_tablename_r} as r on {uid_r_expr} = b.join_key_r diff --git a/splink/internals/em_training_session.py b/splink/internals/em_training_session.py index 866d7023ef..f50af9153b 100644 --- a/splink/internals/em_training_session.py +++ b/splink/internals/em_training_session.py @@ -1,6 +1,7 @@ from __future__ import annotations import logging +from dataclasses import asdict, dataclass from typing import TYPE_CHECKING, List from splink.internals.blocking import BlockingRule, block_using_rules_sqls @@ -22,6 +23,7 @@ from splink.internals.settings import ( ComparisonAndLevelDict, CoreModelSettings, + ModelParameterDetailedRecord, Settings, TrainingSettings, ) @@ -39,6 +41,27 @@ from splink.internals.splink_dataframe import SplinkDataFrame +@dataclass +class ModelParameterIterationDetailedRecord(ModelParameterDetailedRecord): + iteration: int + + @classmethod + def from_settings_param_detailed_record( + cls, + cl_rec: ModelParameterDetailedRecord, + *, + probability_two_random_records_match: float, + iteration: int, + ) -> ModelParameterIterationDetailedRecord: + cl_rec.probability_two_random_records_match = ( + probability_two_random_records_match + ) + return cls( + **asdict(cl_rec), + iteration=iteration, + ) + + class EMTrainingSession: """Manages training models using the Expectation Maximisation algorithm, and holds statistics on the evolution of parameter estimates. 
Plots diagnostic charts @@ -320,20 +343,20 @@ def _blocking_adjusted_probability_two_random_records_match(self): return adjusted_prop_m @property - def _iteration_history_records(self): + def _iteration_history_records(self) -> list[ModelParameterIterationDetailedRecord]: output_records = [] for iteration, core_model_settings in enumerate( self._core_model_settings_history ): - records = core_model_settings.parameters_as_detailed_records - - for r in records: - r["iteration"] = iteration - # TODO: why lambda from current settings, not history? - r["probability_two_random_records_match"] = ( - self.core_model_settings.probability_two_random_records_match + records = [ + ModelParameterIterationDetailedRecord.from_settings_param_detailed_record( + r, + probability_two_random_records_match=self.core_model_settings.probability_two_random_records_match, + iteration=iteration, ) + for r in core_model_settings.parameters_as_detailed_records + ] output_records.extend(records) return output_records @@ -370,9 +393,9 @@ def match_weights_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. """ - records = self._iteration_history_records return match_weights_interactive_history_chart( - records, blocking_rule=self._blocking_rule_for_training.blocking_rule_sql + self._iteration_history_records, + blocking_rule=self._blocking_rule_for_training.blocking_rule_sql, ) def m_u_values_interactive_history_chart(self) -> ChartReturnType: @@ -382,8 +405,7 @@ def m_u_values_interactive_history_chart(self) -> ChartReturnType: Returns: An interactive Altair chart. 
""" - records = self._iteration_history_records - return m_u_parameters_interactive_history_chart(records) + return m_u_parameters_interactive_history_chart(self._iteration_history_records) def __repr__(self): deactivated_cols = ", ".join( diff --git a/splink/internals/files/chart_defs/missingness.json b/splink/internals/files/chart_defs/missingness.json deleted file mode 100644 index de18f660fb..0000000000 --- a/splink/internals/files/chart_defs/missingness.json +++ /dev/null @@ -1,80 +0,0 @@ -{ - "config": { - "view": { - "continuousWidth": 400, - "continuousHeight": 300 - }, - "axis": { - "labelFontSize": 11 - } - }, - "title": "", - "layer": [ - { - "mark": "bar", - "encoding": { - "color": { - "type": "quantitative", - "field": "null_proportion", - "legend": { - "format": ".0%", - "offset": 30 - }, - "scale": {"domain": [0, 1], - "range": "heatmap" - }, - "title": "Missingness" - }, - "tooltip": [ - { - "type": "nominal", - "field": "column_name", - "title": "Column" - }, - { - "type": "quantitative", - "field": "null_count", - "format": ",.0f", - "title": "Count of nulls" - }, - { - "type": "quantitative", - "field": "null_proportion", - "format": ".2%", - "title": "Percentage of nulls" - }, - { - "type": "quantitative", - "field": "total_record_count", - "format": ",.0f", - "title": "Total record count" - } - ], - "x": { - "type": "quantitative", - "scale": {"domain": [0, 1]}, - "axis": { - "labelAlign": "center", - "format": "%", - "title": "Percentage of nulls" - }, - "field": "null_proportion" - }, - "y": { - "type": "nominal", - "axis": { - "title": "" - }, - "field": "column_name", - "sort": "-x" - } - }, - "title": "" - } - ], - "data": { - "values": "", - "name": "data-0e7bce5a1d2f132e282789d6ef7780fe" - }, - "$schema": "https://vega.github.io/schema/vega-lite/v5.9.3.json" -} diff --git a/splink/internals/linker_components/inference.py b/splink/internals/linker_components/inference.py index 5811034a47..bedc3c4500 100644 --- 
a/splink/internals/linker_components/inference.py +++ b/splink/internals/linker_components/inference.py @@ -109,6 +109,8 @@ def deterministic_link(self) -> SplinkDataFrame: input_tablename_r="__splink__df_concat_with_tf", source_dataset_input_column=settings.column_info_settings.source_dataset_input_column, unique_id_input_column=settings.column_info_settings.unique_id_input_column, + link_type=settings._link_type, + sql_dialect_str=self._linker._sql_dialect_str, ) pipeline.enqueue_list_of_sqls(sqls) @@ -328,6 +330,8 @@ def predict_chunk( input_tablename_r="__splink__df_concat_with_tf", source_dataset_input_column=self._linker._settings_obj.column_info_settings.source_dataset_input_column, unique_id_input_column=self._linker._settings_obj.column_info_settings.unique_id_input_column, + link_type=settings._link_type, + sql_dialect_str=self._linker._sql_dialect_str, ) pipeline.enqueue_list_of_sqls(sqls) diff --git a/splink/internals/settings.py b/splink/internals/settings.py index 9a4adea43b..5a807d84d7 100644 --- a/splink/internals/settings.py +++ b/splink/internals/settings.py @@ -8,7 +8,10 @@ from splink.internals.blocking import BlockingRule from splink.internals.charts import m_u_parameters_chart, match_weights_chart from splink.internals.comparison import Comparison -from splink.internals.comparison_level import ComparisonLevel +from splink.internals.comparison_level import ( + ComparisonLevel, + ComparisonLevelDetailedRecord, +) from splink.internals.dialects import SplinkDialect from splink.internals.input_column import InputColumn from splink.internals.misc import ( @@ -27,6 +30,26 @@ class ComparisonAndLevelDict(TypedDict): comparison: Comparison +@dataclass +class ModelParameterDetailedRecord(ComparisonLevelDetailedRecord): + probability_two_random_records_match: float + comparison_sort_order: int + + @classmethod + def from_cl_detailed_record( + cls, + cl_rec: ComparisonLevelDetailedRecord, + *, + probability_two_random_records_match: float, + 
comparison_sort_order: int, + ) -> ModelParameterDetailedRecord: + return cls( + **asdict(cl_rec), + probability_two_random_records_match=probability_two_random_records_match, + comparison_sort_order=comparison_sort_order, + ) + + @dataclass(frozen=True) class ColumnInfoSettings: match_weight_column_prefix: str @@ -109,14 +132,18 @@ def copy(self): return deepcopy(self) @property - def parameters_as_detailed_records(self): + def parameters_as_detailed_records(self) -> list[ModelParameterDetailedRecord]: output = [] rr_match = self.probability_two_random_records_match for i, cc in enumerate(self.comparisons): - records = cc._as_detailed_records - for r in records: - r["probability_two_random_records_match"] = rr_match - r["comparison_sort_order"] = i + records = [ + ModelParameterDetailedRecord.from_cl_detailed_record( + r, + probability_two_random_records_match=rr_match, + comparison_sort_order=i, + ) + for r in cc._as_detailed_records + ] output.extend(records) prior_description = ( @@ -128,26 +155,26 @@ def parameters_as_detailed_records(self): ) # Finally add a record for probability_two_random_records_match - prop_record = { - "comparison_name": "probability_two_random_records_match", - "sql_condition": None, - "label_for_charts": "", - "m_probability": None, - "u_probability": None, - "m_probability_description": None, - "u_probability_description": None, - "has_tf_adjustments": False, - "tf_adjustment_column": None, - "tf_adjustment_weight": None, - "is_null_level": False, - "bayes_factor": prob_to_bayes_factor(rr_match), - "log2_bayes_factor": prob_to_match_weight(rr_match), - "comparison_vector_value": 0, - "max_comparison_vector_value": 0, - "bayes_factor_description": prior_description, - "probability_two_random_records_match": rr_match, - "comparison_sort_order": -1, - } + prop_record = ModelParameterDetailedRecord( + comparison_name="probability_two_random_records_match", + sql_condition=None, + label_for_charts="", + m_probability=None, + 
u_probability=None, + m_probability_description=None, + u_probability_description=None, + has_tf_adjustments=False, + tf_adjustment_column=None, + tf_adjustment_weight=None, + is_null_level=False, + bayes_factor=prob_to_bayes_factor(rr_match), + log2_bayes_factor=prob_to_match_weight(rr_match), + comparison_vector_value=0, + max_comparison_vector_value=0, + bayes_factor_description=prior_description, + probability_two_random_records_match=rr_match, + comparison_sort_order=-1, + ) output.insert(0, prop_record) return output @@ -568,13 +595,14 @@ def _as_completed_dict(self): } def match_weights_chart(self, as_dict=False): - records = self._parameters_as_detailed_records - - return match_weights_chart(records, as_dict=as_dict) + return match_weights_chart( + self._parameters_as_detailed_records, as_dict=as_dict + ) def m_u_parameters_chart(self, as_dict=False): - records = self._parameters_as_detailed_records - return m_u_parameters_chart(records, as_dict=as_dict) + return m_u_parameters_chart( + self._parameters_as_detailed_records, as_dict=as_dict + ) def _columns_without_estimated_parameters_message(self): message_lines = [] diff --git a/splink/internals/term_frequencies.py b/splink/internals/term_frequencies.py index da12b12699..85034f6474 100644 --- a/splink/internals/term_frequencies.py +++ b/splink/internals/term_frequencies.py @@ -239,7 +239,9 @@ def tf_adjustment_chart( # Data for chart comparison = linker._settings_obj._get_comparison_by_output_column_name(col) - comparison_records = comparison._as_detailed_records + comparison_records = list( + map(lambda rec: rec.as_dict(), comparison._as_detailed_records) + ) keys_to_retain = [ "comparison_vector_value", diff --git a/splink/internals/waterfall_chart.py b/splink/internals/waterfall_chart.py index a8a080957b..07ef8b3992 100644 --- a/splink/internals/waterfall_chart.py +++ b/splink/internals/waterfall_chart.py @@ -57,9 +57,9 @@ def _comparison_records( cl = 
c._get_comparison_level_by_comparison_vector_value(cv_value) waterfall_record = { field: value - for field, value in cl._as_detailed_record( - c._num_levels, c.comparison_levels - ).items() + for field, value in cl._as_detailed_record(c._num_levels, c.comparison_levels) + .as_dict() + .items() if field in [ "label_for_charts", diff --git a/tests/test_charts.py b/tests/test_charts.py index 455b6ac837..b85cb33114 100644 --- a/tests/test_charts.py +++ b/tests/test_charts.py @@ -142,6 +142,24 @@ def test_m_u_charts(dialect, test_helpers): linker.visualisations.match_weights_chart() +@mark_with_dialects_excluding() +def test_comparison_match_weight_charts(dialect, test_helpers): + settings = { + "link_type": "dedupe_only", + "comparisons": [ + cl.ExactMatch("gender"), + cl.ExactMatch("tm_partial"), + cl.LevenshteinAtThresholds("surname", [1]), + ], + } + helper = test_helpers[dialect] + + linker = helper.linker_with_registration(df, settings) + + comp = linker._settings_obj.comparisons[0] + comp.match_weights_chart() + + @mark_with_dialects_excluding() def test_parameter_estimate_charts(dialect, test_helpers): settings = {