diff --git a/docs/changes/DM-53999.misc.rst b/docs/changes/DM-53999.misc.rst new file mode 100644 index 00000000..d85df61f --- /dev/null +++ b/docs/changes/DM-53999.misc.rst @@ -0,0 +1 @@ +Improved the handling and validation of ``Column`` overrides diff --git a/python/felis/datamodel.py b/python/felis/datamodel.py index 1a36008f..8dd7c3bb 100644 --- a/python/felis/datamodel.py +++ b/python/felis/datamodel.py @@ -541,6 +541,25 @@ def check_votable_xtype(self) -> Column: self.votable_xtype = "timestamp" return self + def _update_from_overrides(self, overrides: ColumnOverrides) -> None: + """Update the column attributes from the given overrides. + + Parameters + ---------- + overrides + The column overrides to apply; must not be `None`. + + Notes + ----- + Using ``model_fields_set`` allows updating only the fields that are + explicitly set in the `overrides` object. This prevents overwriting + existing column attributes which were not explicitly provided. + """ + if overrides.model_fields_set: + logger.debug("Applying overrides to column '%s': %s", self.id, overrides.model_fields_set) + for field in overrides.model_fields_set: + setattr(self, field, getattr(overrides, field)) + class Constraint(BaseObject): """Table constraint model.""" @@ -819,10 +838,15 @@ def serialize_columns(self, columns: list[ColumnRef | Column]) -> list[str]: class ColumnOverrides(BaseModel): """Allowed overrides for a referenced column. - All fields are optional; missing values mean "inherit from the base - column". + Notes + ----- + All of these fields are optional. Values of None may be explicitly set to + override the corresponding attribute in the referenced column but only + for certain fields (see validation in `_check_non_nullable_overrides`). 
""" + model_config = CONFIG.copy() + datatype: DataType | None = None """New datatype for the column.""" @@ -832,7 +856,7 @@ class ColumnOverrides(BaseModel): description: str | None = None """New description for the column.""" - nullable: bool = True + nullable: bool | None = None """New nullable flag for the column.""" tap_principal: int | None = Field(default=None, alias="tap:principal") @@ -841,6 +865,18 @@ class ColumnOverrides(BaseModel): tap_column_index: int | None = Field(default=None, alias="tap:column_index") """Override for the TAP_SCHEMA column index.""" + @model_validator(mode="before") + @classmethod + def _check_non_nullable_overrides(cls, data: Any) -> Any: + """Check that certain fields are not overridden to null.""" + if not isinstance(data, dict): + return data + non_nullable_fields = ("datatype", "length", "nullable", "tap_principal") + for name in non_nullable_fields: + if name in data and data[name] is None: + raise ValueError(f"The '{name}' field cannot be overridden to null") + return data + @field_serializer("datatype") def serialize_datatype(self, value: DataType | None) -> str | None: """Convert `DataType` to string when serializing to JSON/YAML. 
@@ -1331,21 +1367,6 @@ def _dereference_resource_columns(self: Schema, info: ValidationInfo) -> Schema: table.column_refs = {} return self - @classmethod - def _copy_overrides_to_column( - cls, column_ref: ColumnResourceRef, column_copy: Column - ) -> ColumnOverrides | None: - """Copy overrides from a column ref to a column.""" - if column_ref.overrides is not None: - overrides = column_ref.overrides - override_fields = overrides.model_fields_set - for field_name in override_fields: - if hasattr(column_copy, field_name): - # Use attribute assignment to avoid type conversion issues - # which can occur with using model_dump and model_copy - setattr(column_copy, field_name, getattr(overrides, field_name)) - return column_ref.overrides - @classmethod def _process_column_refs( cls, @@ -1394,8 +1415,7 @@ def _process_column_refs( f"from resource '{resource_schema.name}' and no ref_name provided" ) - # Create a copy of the base column and apply - # overrides + # Create a copy of the base column column_copy = base_column.model_copy() # Set the local name (key from the mapping) @@ -1406,10 +1426,10 @@ def _process_column_refs( # written out during serialization column_copy._is_resource_ref = True - # Apply overrides to the original column definition - overrides: ColumnOverrides | None = None - if column_ref is not None: - overrides = cls._copy_overrides_to_column(column_ref, column_copy) + # Apply overrides to the referenced column definition + overrides = column_ref.overrides if column_ref is not None else None + if overrides is not None: + column_copy._update_from_overrides(overrides) # Manually set the ID of the copied column as ID generation has # already occurred by now diff --git a/tests/test_datamodel.py b/tests/test_datamodel.py index 744f0398..cfa7ff85 100644 --- a/tests/test_datamodel.py +++ b/tests/test_datamodel.py @@ -22,6 +22,7 @@ import difflib import os import pathlib +import re import shutil import tempfile import unittest @@ -35,6 +36,7 @@ CheckConstraint, 
Column, ColumnGroup, + ColumnOverrides, Constraint, DataType, ForeignKeyConstraint, @@ -1551,5 +1553,205 @@ def test_tap_column_index_with_overrides(self) -> None: self.fail(f"Unexpected column name: {column.name}") +class ColumnOverridesTestCase(unittest.TestCase): + """Test application of overrides to a column, setting all allowed + fields. + """ + + def test_all_override_fields_exist_on_column(self) -> None: + """Ensure every ColumnOverrides field corresponds to an attribute on + Column. + """ + override_fields = set(ColumnOverrides.model_fields) + column_fields = set(Column.model_fields) + + missing = override_fields - column_fields + + self.assertFalse( + missing, + f"Column is missing attributes for override fields: {sorted(missing)}", + ) + + def test_overrides_all(self) -> None: + """Test updating all allowed column fields from overrides.""" + # Create a base column + base_column = Column( + name="base_column", + id="#base_column", + description="Base column", + datatype="char", + length=64, + nullable=False, + tap_principal=1, + tap_column_index=10, + ) + + # Override all allowed fields with different values + overrides = ColumnOverrides( + description="Ref column", + datatype="string", + length=256, + nullable=True, + tap_principal=0, + tap_column_index=100, + ) + + # Apply overrides + base_column._update_from_overrides(overrides) + + # Check that the attributes were updated correctly + self.assertEqual(base_column.description, "Ref column") + self.assertEqual(base_column.datatype, "string") + self.assertEqual(base_column.length, 256) + self.assertEqual(base_column.nullable, True) + self.assertEqual(base_column.tap_principal, 0) + self.assertEqual(base_column.tap_column_index, 100) + + def test_overrides_subset(self) -> None: + """Test updating a subset of allowed column fields from overrides.""" + # Create a base column + base_column = Column( + name="base_column", + id="#base_column", + description="Base column", + datatype="char", + length=64, + 
nullable=False, + tap_principal=1, + tap_column_index=10, + ) + + # Override only a subset of the allowed fields + overrides = ColumnOverrides( + description="Ref column", + tap_column_index=100, + ) + + # Apply overrides + base_column._update_from_overrides(overrides) + + # Check that the attributes were updated correctly + self.assertEqual(base_column.description, "Ref column") + self.assertEqual(base_column.datatype, "char") + self.assertEqual(base_column.length, 64) + self.assertEqual(base_column.nullable, False) + self.assertEqual(base_column.tap_principal, 1) + self.assertEqual(base_column.tap_column_index, 100) + + def test_overrides_default(self) -> None: + """Test that applying the default overrides is a no-op.""" + # Create a base column + base_column = Column( + name="base_column", + id="#base_column", + description="Base column", + datatype="char", + length=64, + nullable=False, + tap_principal=1, + tap_column_index=10, + ) + + # Apply overrides + base_column._update_from_overrides(ColumnOverrides()) + + # Check that the attributes remain unchanged + self.assertEqual(base_column.description, "Base column") + self.assertEqual(base_column.datatype, "char") + self.assertEqual(base_column.length, 64) + self.assertEqual(base_column.nullable, False) + self.assertEqual(base_column.tap_principal, 1) + self.assertEqual(base_column.tap_column_index, 10) + + def test_overrides_with_explicit_none_values(self) -> None: + """Test that passing explicit None values in overrides does update + the column attributes where allowed and raises errors if it is not. 
+ """ + # Create a base column + base_column = Column( + name="base_column", + id="#base_column", + description="Base column", + datatype="int", + length=64, + nullable=False, + tap_principal=1, + tap_column_index=10, + ) + + # Create overrides with explicit None values for nullable fields + overrides = ColumnOverrides( + description=None, + tap_column_index=None, + ) + + # Apply overrides + base_column._update_from_overrides(overrides) + + # Check that the attributes were updated to None where allowed + self.assertIsNone(base_column.description) + self.assertIsNone(base_column.tap_column_index) + + # Check that setting non-nullable fields to None raises a specific + # ValueError on ColumnOverrides creation + for non_nullable_field in ("datatype", "length", "nullable", "tap_principal"): + with self.assertRaisesRegex( + ValueError, + re.escape(f"The '{non_nullable_field}' field cannot be overridden to null"), + ): + ColumnOverrides(**{non_nullable_field: None}) + + def test_extra_fields_in_overrides(self) -> None: + """Test that extra fields in ColumnOverrides raise a + ValidationError. + """ + with self.assertRaises(ValidationError) as cm: + ColumnOverrides( + description="Test column", + extra_field="This should not be allowed", + ) + + self.assertIn("Extra inputs are not permitted", str(cm.exception)) + + def test_overrides_accept_alias_keys(self) -> None: + """Test that alias keys for TAP fields are accepted and populate the + corresponding model fields. + """ + overrides = ColumnOverrides(**{"tap:principal": 1, "tap:column_index": 42}) + + self.assertEqual(overrides.tap_principal, 1) + self.assertEqual(overrides.tap_column_index, 42) + + # Ensure these count as explicitly provided (for model_fields_set + # logic). 
+ self.assertIn("tap_principal", overrides.model_fields_set) + self.assertIn("tap_column_index", overrides.model_fields_set) + + def test_datatype_deserialize_and_serialize(self) -> None: + """Test that datatype is deserialized from a string to DataType and + serialized back to a string. + """ + overrides = ColumnOverrides(datatype="char") + + # Deserialization should yield a DataType instance (not a raw str). + self.assertIsInstance(overrides.datatype, DataType) + self.assertEqual(str(overrides.datatype), "char") + + # Serialization should produce a JSON-friendly string value. + dumped = overrides.model_dump(mode="json") + self.assertEqual(dumped["datatype"], "char") + + # None should remain None on serialization. + overrides_none = ColumnOverrides() + dumped_none = overrides_none.model_dump(mode="json") + self.assertIsNone(dumped_none["datatype"]) + + def test_non_nullable_overrides_data_is_none(self) -> None: + """Test that passing None to ``_check_non_nullable_overrides`` does not + raise an error. + """ + ColumnOverrides()._check_non_nullable_overrides(None) + + if __name__ == "__main__": unittest.main()